Extracted kernel structures for device addressing in ifconfig.
[virt-mem.git] / lib / virt_mem_mmap.ml
index 7401e17..c323c1a 100644 (file)
  *)
 
 open Unix
+open Printf
 open Bigarray
 
 open Virt_mem_utils
 
-(* Simple implementation at the moment: Store a list of mappings,
- * sorted by start address.  We assume that mappings do not overlap.
- * We can change the implementation later if we need to.  In most cases
- * there will only be a small number of mappings (probably 1).
- *)
-type ('a,'b) t = {
-  mappings : mapping list;
-  wordsize : wordsize option;
-  endian : Bitmatch.endian option;
-}
-and mapping = {
+let debug = false
+
+(* An address. *)
+type addr = int64
+
+(* A range of addresses (start and start+size). *)
+type interval = addr * addr
+
+(* A mapping. *)
+type mapping = {
   start : addr;
   size : addr;
   (* Bigarray mmap(2)'d region with byte addressing: *)
   arr : (char,int8_unsigned_elt,c_layout) Array1.t;
+  (* The order that the mappings were added, 0 for the first mapping,
+   * 1 for the second mapping, etc.
+   *)
+  order : int;
 }
 
-and addr = int64
+(* A memory map. *)
+type ('ws,'e,'hm) t = {
+  (* List of mappings, kept in reverse order they were added (new
+   * mappings are added at the head of this list).
+   *)
+  mappings : mapping list;
+
+  (* Segment tree for fast access to a mapping at a particular address.
+   * This is rebuilt each time a new mapping is added.
+   * NB! If mappings = [], ignore contents of this field.  (This is
+   * enforced by the 'hm phantom type).
+   *)
+  tree : (interval * mapping option, interval * mapping option) binary_tree;
+
+  (* Word size, endianness.
+   * Phantom types enforce that these are set before being used.
+   *)
+  wordsize : wordsize;
+  endian : Bitstring.endian;
+}
 
 let create () = {
   mappings = [];
-  wordsize = None;
-  endian = None
+  tree = Leaf ((0L,0L),None);
+  wordsize = W32;
+  endian = Bitstring.LittleEndian;
 }
 
-let set_wordsize t ws = { t with wordsize = Some ws }
+let set_wordsize t ws = { t with wordsize = ws }
 
-let set_endian t e = { t with endian = Some e }
+let set_endian t e = { t with endian = e }
 
-let get_wordsize t = Option.get t.wordsize
+let get_wordsize t = t.wordsize
 
-let get_endian t = Option.get t.endian
+let get_endian t = t.endian
 
-let sort_mappings mappings =
-  let cmp { start = s1 } { start = s2 } = compare s1 s2 in
-  List.sort cmp mappings
+(* Build the segment tree from the list of mappings.  This code
+ * is taken from virt-df.  For an explanation of the process see:
+ * http://en.wikipedia.org/wiki/Segment_tree
+ *
+ * See also the 'get_mapping' function below which uses this tree
+ * to do fast lookups.
+ *)
+let tree_of_mappings mappings =
+  (* Construct the list of distinct endpoints. *)
+  let eps =
+    List.map
+      (fun { start = start; size = size } -> [start; start +^ size])
+      mappings in
+  let eps = sort_uniq (List.concat eps) in
+
+  (* Construct the elementary intervals. *)
+  let elints =
+    let elints, lastpoint =
+      List.fold_left (
+       fun (elints, prevpoint) point ->
+         ((point, point) :: (prevpoint, point) :: elints), point
+      ) ([], 0L) eps in
+    let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
+    List.rev elints in
+
+  if debug then (
+    eprintf "elementary intervals (%d in total):\n" (List.length elints);
+    List.iter (
+      fun (startpoint, endpoint) ->
+       eprintf "  %Lx %Lx\n" startpoint endpoint
+    ) elints
+  );
+
+  (* Construct the binary tree of elementary intervals. *)
+  let tree =
+    (* Each elementary interval becomes a leaf. *)
+    let elints = List.map (fun elint -> Leaf elint) elints in
+    (* Recursively build this into a binary tree. *)
+    let rec make_layer = function
+      | [] -> []
+      | ([_] as x) -> x
+      (* Turn pairs of leaves at the bottom level into nodes. *)
+      | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
+         let xs = make_layer xs in
+         Node (a, (), b) :: xs
+      (* Turn pairs of nodes at higher levels into nodes. *)
+      | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
+         let xs = make_layer xs in
+         Node (left, (), right) :: xs
+      | Leaf _ :: _ -> assert false (* never happens??? (I think) *)
+    in
+    let rec loop = function
+      | [] -> assert false
+      | [x] -> x
+      | xs -> loop (make_layer xs)
+    in
+    loop elints in
 
-let add_file ({ mappings = mappings } as t) fd addr =
-  if addr &^ 7L <> 0L then
-    invalid_arg "add_file: mapping address must be aligned to 8 bytes";
+  if debug then (
+    let leaf_printer (startpoint, endpoint) =
+      sprintf "%Lx-%Lx" startpoint endpoint
+    in
+    let node_printer () = "" in
+    print_binary_tree leaf_printer node_printer tree
+  );
+
+  (* Insert the mappings into the tree one by one. *)
+  let tree =
+    (* For each node/leaf in the tree, add its interval and an
+     * empty list which will be used to store the mappings.
+     *)
+    let rec interval_tree = function
+      | Leaf elint -> Leaf (elint, None)
+      | Node (left, (), right) ->
+         let left = interval_tree left in
+         let right = interval_tree right in
+         let (leftstart, _) = interval_of_node left in
+         let (_, rightend) = interval_of_node right in
+         let interval = leftstart, rightend in
+         Node (left, (interval, None), right)
+    and interval_of_node = function
+      | Leaf (elint, _) -> elint
+      | Node (_, (interval, _), _) -> interval
+    in
+
+    let tree = interval_tree tree in
+    (* This should always be true: *)
+    assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
+
+    (* "Contained in" operator.
+     * 'a <-< b' iff 'a' is a subinterval of 'b'.
+     *      |<---- a ---->|
+     * |<----------- b ----------->|
+     *)
+    let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
+
+    (* "Intersects" operator.
+     * 'a /\ b' iff intervals 'a' and 'b' overlap, eg:
+     *      |<---- a ---->|
+     *                |<----------- b ----------->|
+     *)
+    let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in
+
+    let rec insert_mapping tree mapping =
+      let { start = start; size = size } = mapping in
+      let seginterval = start, start +^ size in
+
+      match tree with
+      (* Test if we should insert into this leaf or node: *)
+      | Leaf (interval, None) when interval <-< seginterval ->
+         Leaf (interval, Some mapping)
+      | Leaf (interval, Some oldmapping) when interval <-< seginterval ->
+         let mapping =
+           if oldmapping.order > mapping.order then oldmapping else mapping in
+         Leaf (interval, Some mapping)
+
+      | Node (left, (interval, None), right) when interval <-< seginterval ->
+         Node (left, (interval, Some mapping), right)
+
+      | Node (left, (interval, Some oldmapping), right)
+         when interval <-< seginterval ->
+         let mapping =
+           if oldmapping.order > mapping.order then oldmapping else mapping in
+         Node (left, (interval, Some mapping), right)
+
+      | (Leaf _) as leaf -> leaf
+
+      (* Else, should we insert into left or right subtrees? *)
+      | Node (left, i, right) ->
+         let left =
+           if seginterval /\ interval_of_node left then
+             insert_mapping left mapping
+           else
+             left in
+         let right =
+           if seginterval /\ interval_of_node right then
+             insert_mapping right mapping
+           else
+             right in
+         Node (left, i, right)
+    in
+    let tree = List.fold_left insert_mapping tree mappings in
+    tree in
+
+  if debug then (
+    let printer ((sp, ep), mapping) =
+      sprintf "[%Lx-%Lx] " sp ep ^
+       match mapping with
+       | None -> "(none)"
+       | Some { start = start; size = size; order = order } ->
+           sprintf "%Lx..%Lx(%d)" start (start+^size-^1L) order
+    in
+    print_binary_tree printer printer tree
+  );
+
+  tree
+
+let add_mapping ({ mappings = mappings } as t) start size arr =
+  let order = List.length mappings in
+  let mapping = { start = start; size = size; arr = arr; order = order } in
+  let mappings = mapping :: mappings in
+  let tree = tree_of_mappings mappings in
+  { t with mappings = mappings; tree = tree }
+
+let add_file t fd addr =
   let size = (fstat fd).st_size in
   (* mmap(2) the file using Bigarray module. *)
   let arr = Array1.map_file fd char c_layout false size in
-  (* Create the mapping entry and keep the mappings sorted by start addr. *)
-  let mappings =
-    { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
-  let mappings = sort_mappings mappings in
-  { t with mappings = mappings }
+  (* Create the mapping entry. *)
+  add_mapping t addr (Int64.of_int size) arr
+
+let add_string ({ mappings = mappings } as t) str addr =
+  let size = String.length str in
+  (* Copy the string data to a Bigarray. *)
+  let arr = Array1.create char c_layout size in
+  for i = 0 to size-1 do
+    Array1.set arr i (String.unsafe_get str i)
+  done;
+  (* Create the mapping entry. *)
+  add_mapping t addr (Int64.of_int size) arr
 
 let of_file fd addr =
   let t = create () in
   add_file t fd addr
 
-(* Find in mappings and return first predicate match. *)
-let _find_map { mappings = mappings } pred =
-  let rec loop = function
-    | [] -> None
-    | m :: ms ->
-       match pred m with
-       | Some n -> Some n
-       | None -> loop ms
+let of_string str addr =
+  let t = create () in
+  add_string t str addr
+
+(* 'get_mapping' is the crucial, fast lookup function for address -> mapping.
+ * It searches the tree (hence fast) to work out the topmost mapping which
+ * applies to an address.
+ *
+ * Returns (rightend * mapping option)
+ * where 'mapping option' is the mapping (or None if it's a hole)
+ *   and 'rightend' is the next address at which there is a different
+ *       mapping/hole.  In other words, this mapping result is good for
+ *       addresses [addr .. rightend-1].
+ *)
+let rec get_mapping addr = function
+  | Leaf ((_, rightend), mapping) -> rightend, mapping
+
+  | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+         (_, None),
+         right) ->
+      let subrightend, submapping =
+       if addr < leftend then get_mapping addr left
+       else get_mapping addr right in
+      subrightend, submapping
+
+  | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+         (_, Some mapping),
+         right) ->
+      let subrightend, submapping =
+       if addr < leftend then get_mapping addr left
+       else get_mapping addr right in
+      (match submapping with
+       | None -> subrightend, Some mapping
+       | Some submapping ->
+          subrightend,
+          Some (if mapping.order > submapping.order then mapping
+                else submapping)
+      )
+
+(* Use the tree to quickly check if an address is mapped (returns false
+ * if it's a hole).
+ *)
+let is_mapped { mappings = mappings; tree = tree } addr =
+  (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *)
+  match mappings with
+  | [] -> false
+  | _ ->
+      let _, mapping = get_mapping addr tree in
+      mapping <> None
+
+(* Get a single byte. *)
+let get_byte { tree = tree } addr =
+  (* Get the mapping which applies to this address: *)
+  let _, mapping = get_mapping addr tree in
+  match mapping with
+  | Some { start = start; size = size; arr = arr } ->
+      let offset = Int64.to_int (addr -^ start) in
+      Char.code (Array1.get arr offset)
+  | None ->
+      invalid_arg "get_byte"
+
+(* Get a range of bytes, possibly across several intervals. *)
+let get_bytes { tree = tree } addr len =
+  let str = String.create len in
+
+  let rec loop addr pos len =
+    if len > 0 then (
+      let rightend, mapping = get_mapping addr tree in
+      match mapping with
+      | Some { start = start; size = size; arr = arr } ->
+         (* Offset within this mapping. *)
+         let offset = Int64.to_int (addr -^ start) in
+         (* Number of bytes to read before we either get to the end
+          * of our 'len' or we fall off the end of this interval.
+          *)
+         let n = min len (Int64.to_int (rightend -^ addr)) in
+         for i = 0 to n-1 do
+           String.unsafe_set str (pos + i) (Array1.get arr (offset + i))
+         done;
+         let len = len - n in
+         loop (addr +^ Int64.of_int n) (pos + n) len
+
+      | None ->
+         invalid_arg "get_bytes"
+    )
+  in
+  loop addr 0 len;
+
+  str
+
+let get_int32 t addr =
+  let e = get_endian t in
+  let str = get_bytes t addr 4 in
+  let bs = Bitstring.bitstring_of_string str in
+  bitmatch bs with
+  | { addr : 32 : endian (e) } -> addr
+  | { _ } -> invalid_arg "get_int32"
+
+let get_int64 t addr =
+  let e = get_endian t in
+  let str = get_bytes t addr 8 in
+  let bs = Bitstring.bitstring_of_string str in
+  bitmatch bs with
+  | { addr : 64 : endian (e) } -> addr
+  | { _ } -> invalid_arg "get_int64"
+
+let get_C_int = get_int32
+
+let get_C_long t addr =
+  let ws = get_wordsize t in
+  match ws with
+  | W32 -> Int64.of_int32 (get_int32 t addr)
+  | W64 -> get_int64 t addr
+
+(* Take bytes until a condition is not met.  This is efficient
+ * in that we stay within the same mapping as long as we can.
+ *
+ * If we hit a hole, raises Invalid_argument "dowhile".
+ *)
+let dowhile { tree = tree } addr cond =
+  let rec loop addr =
+    let rightend, mapping = get_mapping addr tree in
+    match mapping with
+    | Some { start = start; size = size; arr = arr } ->
+       (* Offset within this mapping. *)
+       let offset = Int64.to_int (addr -^ start) in
+       (* Number of bytes before we fall off the end of this interval. *)
+       let n = Int64.to_int (rightend -^ addr) in
+
+       let rec loop2 addr offset n =
+         if n > 0 then (
+           let c = Array1.get arr offset in
+           if cond addr c then
+             loop2 (addr +^ 1L) (offset + 1) (n - 1)
+           else
+             false (* stop now, finish outer loop too *)
+         )
+         else true (* fell off the end, so continue outer loop *)
+       in
+       if loop2 addr offset n then
+         loop (addr +^ Int64.of_int n)
+
+    | None ->
+       invalid_arg "dowhile"
   in
-  loop mappings
+  loop addr
+
+let is_mapped_range ({ mappings = mappings } as t) addr size =
+  match mappings with
+  (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *)
+  | [] -> false
+  | _ ->
+      (* Quick and dirty.  It's possible to make a much faster
+       * implementation of this which doesn't call the closure for every
+       * byte.
+       *)
+      let size = ref size in
+      try dowhile t addr (fun _ _ -> decr size; !size > 0); true
+      with Invalid_argument "dowhile" -> false
+
+(* Get a string, ending at ASCII NUL character. *)
+let get_string t addr =
+  let chars = ref [] in
+  try
+    dowhile t addr (
+      fun _ c ->
+       if c <> '\000' then (
+         chars := c :: !chars;
+         true
+       ) else false
+    );
+    let chars = List.rev !chars in
+    let len = List.length chars in
+    let str = String.create len in
+    let i = ref 0 in
+    List.iter (fun c -> String.unsafe_set str !i c; incr i) chars;
+    str
+  with
+    Invalid_argument _ -> invalid_arg "get_string"
+
+let is_string t addr =
+  try dowhile t addr (fun _ c -> c <> '\000'); true
+  with Invalid_argument _ -> false
+
+let is_C_identifier t addr =
+  let i = ref 0 in
+  let r = ref true in
+  try
+    dowhile t addr (
+      fun _ c ->
+       let b =
+         if !i = 0 then (
+           c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
+         ) else (
+           if c = '\000' then false
+           else (
+             if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
+               c >= '0' && c <= '9' then
+                 true
+             else (
+               r := false;
+               false
+             )
+           )
+         ) in
+       incr i;
+       b
+    );
+    !r
+  with
+    Invalid_argument _ -> false
 
+(* The following functions are actually written in C
+ * because memmem(3) is likely to be much faster than anything
+ * we could write in OCaml.
+ *
+ * Also OCaml bigarrays are specifically designed to be accessed
+ * easily from C:
+ *   http://caml.inria.fr/pub/docs/manual-ocaml/manual043.html
+ *)
+(*
 (* Array+offset = string? *)
 let string_at arr offset str strlen =
   let j = ref offset in
@@ -125,26 +524,30 @@ let _find_in start align str arr =
     loop ()
   )
   else Some start
+*)
+external _find_in :
+  int -> int -> string -> (char,int8_unsigned_elt,c_layout) Array1.t ->
+  int option = "virt_mem_mmap_find_in"
 
 (* Generic find function. *)
-let _find t start align str =
-  _find_map t (
-    fun { start = mstart; size = msize; arr = arr } ->
-      if mstart >= start then (
-       (* Check this mapping from the beginning. *)
-       match _find_in 0 align str arr with
-       | Some offset -> Some (mstart +^ Int64.of_int offset)
-       | None -> None
-      )
-      else if mstart < start && start <= mstart+^msize then (
-       (* Check this mapping from somewhere in the middle. *)
-       let offset = Int64.to_int (start -^ mstart) in
-       match _find_in offset align str arr with
-       | Some offset -> Some (mstart +^ Int64.of_int offset)
+let _find { tree = tree } start align str =
+  let rec loop addr =
+    let rightend, mapping = get_mapping addr tree in
+    match mapping with
+    | Some { start = start; size = size; arr = arr } ->
+       (* Offset within this mapping. *)
+       let offset = Int64.to_int (addr -^ start) in
+
+       (match _find_in offset align str arr with
        | None -> None
-      )
-      else None
-  )
+       | Some offset -> Some (start +^ Int64.of_int offset)
+       )
+
+    | None ->
+       (* Find functions all silently skip holes, so: *)
+       loop rightend
+  in
+  loop start
 
 let find t ?(start=0L) str =
   _find t start 1 str
@@ -181,9 +584,9 @@ and string_of_addr t addr =
   let bits = bits_of_wordsize (get_wordsize t) in
   let e = get_endian t in
   let bs = BITSTRING { addr : bits : endian (e) } in
-  Bitmatch.string_of_bitstring bs
+  Bitstring.string_of_bitstring bs
 *)
-(* XXX bitmatch is missing 'construct_int64_le_unsigned' so we
+(* XXX bitstring is missing 'construct_int64_le_unsigned' so we
  * have to force this to 32 bits for the moment.
  *)
 and string_of_addr t addr =
@@ -191,124 +594,14 @@ and string_of_addr t addr =
   assert (bits = 32);
   let e = get_endian t in
   let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
-  Bitmatch.string_of_bitstring bs
-
-and addr_of_string t str =
-  let bits = bits_of_wordsize (get_wordsize t) in
-  let e = get_endian t in
-  let bs = Bitmatch.bitstring_of_string str in
-  bitmatch bs with
-  | { addr : bits : endian (e) } -> addr
-  | { _ } -> invalid_arg "addr_of_string"
-
-let get_byte { mappings = mappings } addr =
-  let rec loop = function
-    | [] -> invalid_arg "get_byte"
-    | { start = start; size = size; arr = arr } :: _
-       when start <= addr && addr < size ->
-       let offset = Int64.to_int (addr -^ start) in
-       Char.code (Array1.get arr offset)
-    | _ :: ms -> loop ms
-  in
-  loop mappings
-
-(* Take bytes until a condition is not met.  This is efficient in that
- * we stay within the same mapping as long as we can.
- *)
-let dowhile { mappings = mappings } addr cond =
-  let rec get_next_mapping addr = function
-    | [] -> invalid_arg "dowhile"
-    | { start = start; size = size; arr = arr } :: _
-       when start <= addr && addr < start +^ size ->
-       let offset = Int64.to_int (addr -^ start) in
-       let len = Int64.to_int size - offset in
-       arr, offset, len
-    | _ :: ms -> get_next_mapping addr ms
-  in
-  let rec loop addr =
-    let arr, offset, len = get_next_mapping addr mappings in
-    let rec loop2 i =
-      if i < len then (
-       let c = Array1.get arr (offset+i) in
-       if cond c then loop2 (i+1)
-      ) else
-       loop (addr +^ Int64.of_int len)
-    in
-    loop2 0
-  in
-  loop addr
-
-let get_bytes t addr len =
-  let str = String.create len in
-  let i = ref 0 in
-  try
-    dowhile t addr (
-      fun c ->
-       str.[!i] <- c;
-       incr i;
-       !i < len
-    );
-    str
-  with
-    Invalid_argument _ -> invalid_arg "get_bytes"
-
-let get_string t addr =
-  let chars = ref [] in
-  try
-    dowhile t addr (
-      fun c ->
-       if c <> '\000' then (
-         chars := c :: !chars;
-         true
-       ) else false
-    );
-    let chars = List.rev !chars in
-    let len = List.length chars in
-    let str = String.create len in
-    let i = ref 0 in
-    List.iter (fun c -> str.[!i] <- c; incr i) chars;
-    str
-  with
-    Invalid_argument _ -> invalid_arg "get_string"
-
-let is_string t addr =
-  try dowhile t addr (fun c -> c <> '\000'); true
-  with Invalid_argument _ -> false
-
-let is_C_identifier t addr =
-  let i = ref 0 in
-  let r = ref true in
-  try
-    dowhile t addr (
-      fun c ->
-       let b =
-         if !i = 0 then (
-           c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
-         ) else (
-           if c = '\000' then false
-           else (
-             if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
-               c >= '0' && c <= '9' then
-                 true
-             else (
-               r := false;
-               false
-             )
-           )
-         ) in
-       incr i;
-       b
-    );
-    !r
-  with
-    Invalid_argument _ -> false
+  Bitstring.string_of_bitstring bs
 
 let follow_pointer t addr =
   let ws = get_wordsize t in
   let e = get_endian t in
   let bits = bits_of_wordsize ws in
   let str = get_bytes t addr (bytes_of_wordsize ws) in
-  let bs = Bitmatch.bitstring_of_string str in
+  let bs = Bitstring.bitstring_of_string str in
   bitmatch bs with
   | { addr : bits : endian (e) } -> addr
   | { _ } -> invalid_arg "follow_pointer"
@@ -320,3 +613,8 @@ let succ_long t addr =
 let pred_long t addr =
   let ws = get_wordsize t in
   addr -^ Int64.of_int (bytes_of_wordsize ws)
+
+let align t addr =
+  let ws = get_wordsize t in
+  let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
+  (addr +^ mask) &^ (Int64.lognot mask)