Updated PO files.
[virt-mem.git] / lib / virt_mem_mmap.ml
index 3b84a26..b9013de 100644 (file)
  *)
 
 open Unix
+open Printf
 open Bigarray
 
 open Virt_mem_utils
 
-(* Simple implementation at the moment: Store a list of mappings,
- * sorted by start address.  We assume that mappings do not overlap.
- * We can change the implementation later if we need to.  In most cases
- * there will only be a small number of mappings (probably 1).
- *)
-type ('a,'b) t = {
-  mappings : mapping list;
-  wordsize : wordsize option;
-  endian : Bitstring.endian option;
-}
-and mapping = {
+let debug = true
+
+(* An address. *)
+type addr = int64
+
+(* A range of addresses (start and start+size). *)
+type interval = addr * addr
+
+(* A mapping. *)
+type mapping = {
   start : addr;
   size : addr;
   (* Bigarray mmap(2)'d region with byte addressing: *)
   arr : (char,int8_unsigned_elt,c_layout) Array1.t;
+  (* The order that the mappings were added, 0 for the first mapping,
+   * 1 for the second mapping, etc.
+   *)
+  order : int;
 }
 
-and addr = int64
+(* A memory map. *)
+type ('ws,'e,'hm) t = {
+  (* List of mappings, kept in reverse order they were added (new
+   * mappings are added at the head of this list).
+   *)
+  mappings : mapping list;
+
+  (* Segment tree for fast access to a mapping at a particular address.
+   * This is rebuilt each time a new mapping is added.
+   * NB! If mappings = [], ignore contents of this field.  (This is
+   * enforced by the 'hm phantom type).
+   *)
+  tree : (interval * mapping list, interval * mapping list) binary_tree;
+
+  (* Word size, endianness.
+   * Phantom types enforce that these are set before being used.
+   *)
+  wordsize : wordsize;
+  endian : Bitstring.endian;
+}
 
 let create () = {
   mappings = [];
-  wordsize = None;
-  endian = None
+  tree = Leaf ((0L,0L),[]);
+  wordsize = W32;
+  endian = Bitstring.LittleEndian;
 }
 
-let set_wordsize t ws = { t with wordsize = Some ws }
+let set_wordsize t ws = { t with wordsize = ws }
+
+let set_endian t e = { t with endian = e }
+
+let get_wordsize t = t.wordsize
+
+let get_endian t = t.endian
+
+(* Build the segment tree from the list of mappings.  This code
+ * is taken from virt-df.  For an explanation of the process see:
+ * http://en.wikipedia.org/wiki/Segment_tree
+ *)
+let tree_of_mappings mappings =
+  (* Construct the list of distinct endpoints. *)
+  let eps =
+    List.map
+      (fun { start = start; size = size } -> [start; start +^ size])
+      mappings in
+  let eps = sort_uniq (List.concat eps) in
+
+  (* Construct the elementary intervals. *)
+  let elints =
+    let elints, lastpoint =
+      List.fold_left (
+       fun (elints, prevpoint) point ->
+         ((point, point) :: (prevpoint, point) :: elints), point
+      ) ([], 0L) eps in
+    let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
+    List.rev elints in
+
+  if debug then (
+    eprintf "elementary intervals (%d in total):\n" (List.length elints);
+    List.iter (
+      fun (startpoint, endpoint) ->
+       eprintf "  %Lx %Lx\n" startpoint endpoint
+    ) elints
+  );
+
+  (* Construct the binary tree of elementary intervals. *)
+  let tree =
+    (* Each elementary interval becomes a leaf. *)
+    let elints = List.map (fun elint -> Leaf elint) elints in
+    (* Recursively build this into a binary tree. *)
+    let rec make_layer = function
+      | [] -> []
+      | ([_] as x) -> x
+      (* Turn pairs of leaves at the bottom level into nodes. *)
+      | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
+         let xs = make_layer xs in
+         Node (a, (), b) :: xs
+      (* Turn pairs of nodes at higher levels into nodes. *)
+      | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
+         let xs = make_layer xs in
+         Node (left, (), right) :: xs
+      | Leaf _ :: _ -> assert false (* never happens??? (I think) *)
+    in
+    let rec loop = function
+      | [] -> assert false
+      | [x] -> x
+      | xs -> loop (make_layer xs)
+    in
+    loop elints in
 
-let set_endian t e = { t with endian = Some e }
+  if debug then (
+    let leaf_printer (startpoint, endpoint) =
+      sprintf "%Lx-%Lx" startpoint endpoint
+    in
+    let node_printer () = "" in
+    print_binary_tree leaf_printer node_printer tree
+  );
+
+  (* Insert the mappings into the tree one by one. *)
+  let tree =
+    (* For each node/leaf in the tree, add its interval and an
+     * empty list which will be used to store the mappings.
+     *)
+    let rec interval_tree = function
+      | Leaf elint -> Leaf (elint, [])
+      | Node (left, (), right) ->
+         let left = interval_tree left in
+         let right = interval_tree right in
+         let (leftstart, _) = interval_of_node left in
+         let (_, rightend) = interval_of_node right in
+         let interval = leftstart, rightend in
+         Node (left, (interval, []), right)
+    and interval_of_node = function
+      | Leaf (elint, _) -> elint
+      | Node (_, (interval, _), _) -> interval
+    in
 
-let get_wordsize t = Option.get t.wordsize
+    let tree = interval_tree tree in
+    (* This should always be true: *)
+    assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
+
+    (* "Contained in" operator.
+     * 'a <-< b' iff 'a' is a subinterval of 'b'.
+     *      |<---- a ---->|
+     * |<----------- b ----------->|
+     *)
+    let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
+
+    (* "Intersects" operator.
+     * 'a /\ b' iff intervals 'a' and 'b' overlap, eg:
+     *      |<---- a ---->|
+     *                |<----------- b ----------->|
+     *)
+    let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in
+
+    let rec insert_mapping tree mapping =
+      let { start = start; size = size } = mapping in
+      let seginterval = start, start +^ size in
+
+      match tree with
+      (* Test if we should insert into this leaf or node: *)
+      | Leaf (interval, mappings) when interval <-< seginterval ->
+         Leaf (interval, mapping :: mappings)
+      | Node (left, (interval, mappings), right)
+         when interval <-< seginterval ->
+         Node (left, (interval, mapping :: mappings), right)
+
+      | (Leaf _) as leaf -> leaf
+
+      (* Else, should we insert into left or right subtrees? *)
+      | Node (left, i, right) ->
+         let left =
+           if seginterval /\ interval_of_node left then
+             insert_mapping left mapping
+           else
+             left in
+         let right =
+           if seginterval /\ interval_of_node right then
+             insert_mapping right mapping
+           else
+             right in
+         Node (left, i, right)
+    in
+    let tree = List.fold_left insert_mapping tree mappings in
+    tree in
+
+  if debug then (
+    let printer ((sp, ep), mappings) =
+      sprintf "[%Lx-%Lx] " sp ep ^
+       String.concat ";"
+       (List.map (fun { start = start; size = size } ->
+                    sprintf "%Lx+%Lx" start size)
+          mappings)
+    in
+    print_binary_tree printer printer tree
+  );
 
-let get_endian t = Option.get t.endian
+  tree
 
-let sort_mappings mappings =
-  let cmp { start = s1 } { start = s2 } = compare s1 s2 in
-  List.sort cmp mappings
+let add_mapping ({ mappings = mappings } as t) start size arr =
+  let order = List.length mappings in
+  let mapping = { start = start; size = size; arr = arr; order = order } in
+  let mappings = mapping :: mappings in
+  let tree = tree_of_mappings mappings in
+  { t with mappings = mappings; tree = tree }
 
-let add_file ({ mappings = mappings } as t) fd addr =
-  if addr &^ 7L <> 0L then
-    invalid_arg "add_file: mapping address must be aligned to 8 bytes";
+let add_file t fd addr =
   let size = (fstat fd).st_size in
   (* mmap(2) the file using Bigarray module. *)
   let arr = Array1.map_file fd char c_layout false size in
-  (* Create the mapping entry and keep the mappings sorted by start addr. *)
-  let mappings =
-    { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
-  let mappings = sort_mappings mappings in
-  { t with mappings = mappings }
-
-let of_file fd addr =
-  let t = create () in
-  add_file t fd addr
+  (* Create the mapping entry. *)
+  add_mapping t addr (Int64.of_int size) arr
 
 let add_string ({ mappings = mappings } as t) str addr =
-  if addr &^ 7L <> 0L then
-    invalid_arg "add_file: mapping address must be aligned to 8 bytes";
   let size = String.length str in
   (* Copy the string data to a Bigarray. *)
   let arr = Array1.create char c_layout size in
   for i = 0 to size-1 do
     Array1.set arr i (String.unsafe_get str i)
   done;
-  (* Create the mapping entry and keep the mappings sorted by start addr. *)
-  let mappings =
-    { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
-  let mappings = sort_mappings mappings in
-  { t with mappings = mappings }
+  (* Create the mapping entry. *)
+  add_mapping t addr (Int64.of_int size) arr
+
+let of_file fd addr =
+  let t = create () in
+  add_file t fd addr
 
 let of_string str addr =
   let t = create () in
   add_string t str addr
 
+(* Look up an address and get the top-most mapping which contains it.
+ * This uses the segment tree, so it's fast.  The top-most mapping is
+ * the one with the highest 'order' field.
+ *
+ * Warning: This 'hot' code was carefully optimized based on
+ * feedback from 'gprof'.  Avoid fiddling with it.
+ *)
+let rec get_mapping addr = function
+  | Leaf (_, []) -> None
+  | Leaf (_, [mapping]) -> Some mapping
+  | Leaf (_, mappings) -> Some (find_highest_order mappings)
+
+  (* Try to avoid expensive search if node mappings is empty: *)
+  | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+         (_, []),
+         right) ->
+      let submapping =
+       if addr < leftend then get_mapping addr left
+       else get_mapping addr right in
+      submapping
+
+  (* ... or a singleton: *)
+  | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+         (_, [mapping]),
+         right) ->
+      let submapping =
+       if addr < leftend then get_mapping addr left
+       else get_mapping addr right in
+      (match submapping with
+       | None -> Some mapping
+       | Some submapping ->
+          Some (if mapping.order > submapping.order then mapping
+                else submapping)
+      )
+
+  (* Normal recursive case: *)
+  | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+         (_, mappings),
+         right) ->
+      let submapping =
+       if addr < leftend then get_mapping addr left
+       else get_mapping addr right in
+      (match submapping with
+       | None -> Some (find_highest_order mappings)
+       | Some submapping -> Some (find_highest_order (submapping :: mappings))
+      )
+
+and find_highest_order mappings =
+  List.fold_left (
+    fun mapping1 mapping2 ->
+      if mapping1.order > mapping2.order then mapping1 else mapping2
+  ) (List.hd mappings) (List.tl mappings)
+
+(* Get a single byte. *)
+let get_byte { tree = tree } addr =
+  (* Get the mapping which applies to this address: *)
+  match get_mapping addr tree with
+  | Some { start = start; size = size; arr = arr } ->
+      let offset = Int64.to_int (addr -^ start) in
+      Char.code (Array1.get arr offset)
+  | None ->
+      invalid_arg "get_byte"
+(*
+  let rec loop = function
+    | [] -> invalid_arg "get_byte"
+    | { start = start; size = size; arr = arr } :: _
+       when start <= addr && addr < start +^ size ->
+       let offset = Int64.to_int (addr -^ start) in
+       Char.code (Array1.get arr offset)
+    | _ :: ms -> loop ms
+  in
+  loop mappings
+*)
+
+
+(*
+
 (* Find in mappings and return first predicate match. *)
 let _find_map { mappings = mappings } pred =
   let rec loop = function
@@ -233,17 +471,6 @@ and addr_of_string t str =
   | { addr : bits : endian (e) } -> addr
   | { _ } -> invalid_arg "addr_of_string"
 
-let get_byte { mappings = mappings } addr =
-  let rec loop = function
-    | [] -> invalid_arg "get_byte"
-    | { start = start; size = size; arr = arr } :: _
-       when start <= addr && addr < start +^ size ->
-       let offset = Int64.to_int (addr -^ start) in
-       Char.code (Array1.get arr offset)
-    | _ :: ms -> loop ms
-  in
-  loop mappings
-
 (* Take bytes until a condition is not met.  This is efficient in that
  * we stay within the same mapping as long as we can.
  *)
@@ -390,3 +617,12 @@ let align t addr =
   let ws = get_wordsize t in
   let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
   (addr +^ mask) &^ (Int64.lognot mask)
+
+let map { mappings = mappings } f =
+  List.map (fun { start = start; size = size } -> f start size) mappings
+
+let iter t f =
+  ignore (map t (fun start size -> let () = f start size in ()))
+
+let nr_mappings { mappings = mappings } = List.length mappings
+*)