Extracted kernel structures for device addressing in ifconfig.
[virt-mem.git] / lib / virt_mem_mmap.ml
index b9013de..c323c1a 100644 (file)
@@ -26,7 +26,7 @@ open Bigarray
 
 open Virt_mem_utils
 
-let debug = true
+let debug = false
 
 (* An address. *)
 type addr = int64
@@ -58,7 +58,7 @@ type ('ws,'e,'hm) t = {
    * NB! If mappings = [], ignore contents of this field.  (This is
    * enforced by the 'hm phantom type).
    *)
-  tree : (interval * mapping list, interval * mapping list) binary_tree;
+  tree : (interval * mapping option, interval * mapping option) binary_tree;
 
   (* Word size, endianness.
    * Phantom types enforce that these are set before being used.
@@ -69,7 +69,7 @@ type ('ws,'e,'hm) t = {
 
 let create () = {
   mappings = [];
-  tree = Leaf ((0L,0L),[]);
+  tree = Leaf ((0L,0L),None);
   wordsize = W32;
   endian = Bitstring.LittleEndian;
 }
@@ -85,6 +85,9 @@ let get_endian t = t.endian
 (* Build the segment tree from the list of mappings.  This code
  * is taken from virt-df.  For an explanation of the process see:
  * http://en.wikipedia.org/wiki/Segment_tree
+ *
+ * See also the 'get_mapping' function below which uses this tree
+ * to do fast lookups.
  *)
 let tree_of_mappings mappings =
   (* Construct the list of distinct endpoints. *)
@@ -151,14 +154,14 @@ let tree_of_mappings mappings =
      * empty list which will be used to store the mappings.
      *)
     let rec interval_tree = function
-      | Leaf elint -> Leaf (elint, [])
+      | Leaf elint -> Leaf (elint, None)
       | Node (left, (), right) ->
          let left = interval_tree left in
          let right = interval_tree right in
          let (leftstart, _) = interval_of_node left in
          let (_, rightend) = interval_of_node right in
          let interval = leftstart, rightend in
-         Node (left, (interval, []), right)
+         Node (left, (interval, None), right)
     and interval_of_node = function
       | Leaf (elint, _) -> elint
       | Node (_, (interval, _), _) -> interval
@@ -188,11 +191,21 @@ let tree_of_mappings mappings =
 
       match tree with
       (* Test if we should insert into this leaf or node: *)
-      | Leaf (interval, mappings) when interval <-< seginterval ->
-         Leaf (interval, mapping :: mappings)
-      | Node (left, (interval, mappings), right)
+      | Leaf (interval, None) when interval <-< seginterval ->
+         Leaf (interval, Some mapping)
+      | Leaf (interval, Some oldmapping) when interval <-< seginterval ->
+         let mapping =
+           if oldmapping.order > mapping.order then oldmapping else mapping in
+         Leaf (interval, Some mapping)
+
+      | Node (left, (interval, None), right) when interval <-< seginterval ->
+         Node (left, (interval, Some mapping), right)
+
+      | Node (left, (interval, Some oldmapping), right)
          when interval <-< seginterval ->
-         Node (left, (interval, mapping :: mappings), right)
+         let mapping =
+           if oldmapping.order > mapping.order then oldmapping else mapping in
+         Node (left, (interval, Some mapping), right)
 
       | (Leaf _) as leaf -> leaf
 
@@ -214,12 +227,12 @@ let tree_of_mappings mappings =
     tree in
 
   if debug then (
-    let printer ((sp, ep), mappings) =
+    let printer ((sp, ep), mapping) =
       sprintf "[%Lx-%Lx] " sp ep ^
-       String.concat ";"
-       (List.map (fun { start = start; size = size } ->
-                    sprintf "%Lx+%Lx" start size)
-          mappings)
+       match mapping with
+       | None -> "(none)"
+       | Some { start = start; size = size; order = order } ->
+           sprintf "%Lx..%Lx(%d)" start (start+^size-^1L) order
     in
     print_binary_tree printer printer tree
   );
@@ -258,93 +271,213 @@ let of_string str addr =
   let t = create () in
   add_string t str addr
 
-(* Look up an address and get the top-most mapping which contains it.
- * This uses the segment tree, so it's fast.  The top-most mapping is
- * the one with the highest 'order' field.
+(* 'get_mapping' is the crucial, fast lookup function for address -> mapping.
+ * It searches the tree (hence fast) to work out the topmost mapping which
+ * applies to an address.
  *
- * Warning: This 'hot' code was carefully optimized based on
- * feedback from 'gprof'.  Avoid fiddling with it.
+ * Returns (rightend * mapping option)
+ * where 'mapping option' is the mapping (or None if it's a hole)
+ *   and 'rightend' is the next address at which there is a different
+ *       mapping/hole.  In other words, this mapping result is good for
+ *       addresses [addr .. rightend-1].
  *)
 let rec get_mapping addr = function
-  | Leaf (_, []) -> None
-  | Leaf (_, [mapping]) -> Some mapping
-  | Leaf (_, mappings) -> Some (find_highest_order mappings)
+  | Leaf ((_, rightend), mapping) -> rightend, mapping
 
-  (* Try to avoid expensive search if node mappings is empty: *)
   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
-         (_, []),
+         (_, None),
          right) ->
-      let submapping =
+      let subrightend, submapping =
        if addr < leftend then get_mapping addr left
        else get_mapping addr right in
-      submapping
+      subrightend, submapping
 
-  (* ... or a singleton: *)
   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
-         (_, [mapping]),
+         (_, Some mapping),
          right) ->
-      let submapping =
+      let subrightend, submapping =
        if addr < leftend then get_mapping addr left
        else get_mapping addr right in
       (match submapping with
-       | None -> Some mapping
+       | None -> subrightend, Some mapping
        | Some submapping ->
+          subrightend,
           Some (if mapping.order > submapping.order then mapping
                 else submapping)
       )
 
-  (* Normal recursive case: *)
-  | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
-         (_, mappings),
-         right) ->
-      let submapping =
-       if addr < leftend then get_mapping addr left
-       else get_mapping addr right in
-      (match submapping with
-       | None -> Some (find_highest_order mappings)
-       | Some submapping -> Some (find_highest_order (submapping :: mappings))
-      )
-
-and find_highest_order mappings =
-  List.fold_left (
-    fun mapping1 mapping2 ->
-      if mapping1.order > mapping2.order then mapping1 else mapping2
-  ) (List.hd mappings) (List.tl mappings)
+(* Use the tree to quickly check if an address is mapped (returns false
+ * if it's a hole).
+ *)
+let is_mapped { mappings = mappings; tree = tree } addr =
+  (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *)
+  match mappings with
+  | [] -> false
+  | _ ->
+      let _, mapping = get_mapping addr tree in
+      mapping <> None
 
 (* Get a single byte. *)
 let get_byte { tree = tree } addr =
   (* Get the mapping which applies to this address: *)
-  match get_mapping addr tree with
+  let _, mapping = get_mapping addr tree in
+  match mapping with
   | Some { start = start; size = size; arr = arr } ->
       let offset = Int64.to_int (addr -^ start) in
       Char.code (Array1.get arr offset)
   | None ->
       invalid_arg "get_byte"
-(*
-  let rec loop = function
-    | [] -> invalid_arg "get_byte"
-    | { start = start; size = size; arr = arr } :: _
-       when start <= addr && addr < start +^ size ->
-       let offset = Int64.to_int (addr -^ start) in
-       Char.code (Array1.get arr offset)
-    | _ :: ms -> loop ms
+
+(* Get a range of bytes, possibly across several intervals. *)
+let get_bytes { tree = tree } addr len =
+  let str = String.create len in
+
+  let rec loop addr pos len =
+    if len > 0 then (
+      let rightend, mapping = get_mapping addr tree in
+      match mapping with
+      | Some { start = start; size = size; arr = arr } ->
+         (* Offset within this mapping. *)
+         let offset = Int64.to_int (addr -^ start) in
+         (* Number of bytes to read before we either get to the end
+          * of our 'len' or we fall off the end of this interval.
+          *)
+         let n = min len (Int64.to_int (rightend -^ addr)) in
+         for i = 0 to n-1 do
+           String.unsafe_set str (pos + i) (Array1.get arr (offset + i))
+         done;
+         let len = len - n in
+         loop (addr +^ Int64.of_int n) (pos + n) len
+
+      | None ->
+         invalid_arg "get_bytes"
+    )
   in
-  loop mappings
-*)
+  loop addr 0 len;
 
+  str
 
-(*
+let get_int32 t addr =
+  let e = get_endian t in
+  let str = get_bytes t addr 4 in
+  let bs = Bitstring.bitstring_of_string str in
+  bitmatch bs with
+  | { addr : 32 : endian (e) } -> addr
+  | { _ } -> invalid_arg "get_int32"
 
-(* Find in mappings and return first predicate match. *)
-let _find_map { mappings = mappings } pred =
-  let rec loop = function
-    | [] -> None
-    | m :: ms ->
-       match pred m with
-       | Some n -> Some n
-       | None -> loop ms
+let get_int64 t addr =
+  let e = get_endian t in
+  let str = get_bytes t addr 8 in
+  let bs = Bitstring.bitstring_of_string str in
+  bitmatch bs with
+  | { addr : 64 : endian (e) } -> addr
+  | { _ } -> invalid_arg "get_int64"
+
+let get_C_int = get_int32
+
+let get_C_long t addr =
+  let ws = get_wordsize t in
+  match ws with
+  | W32 -> Int64.of_int32 (get_int32 t addr)
+  | W64 -> get_int64 t addr
+
+(* Take bytes until a condition is not met.  This is efficient
+ * in that we stay within the same mapping as long as we can.
+ *
+ * If we hit a hole, raises Invalid_argument "dowhile".
+ *)
+let dowhile { tree = tree } addr cond =
+  let rec loop addr =
+    let rightend, mapping = get_mapping addr tree in
+    match mapping with
+    | Some { start = start; size = size; arr = arr } ->
+       (* Offset within this mapping. *)
+       let offset = Int64.to_int (addr -^ start) in
+       (* Number of bytes before we fall off the end of this interval. *)
+       let n = Int64.to_int (rightend -^ addr) in
+
+       let rec loop2 addr offset n =
+         if n > 0 then (
+           let c = Array1.get arr offset in
+           if cond addr c then
+             loop2 (addr +^ 1L) (offset + 1) (n - 1)
+           else
+             false (* stop now, finish outer loop too *)
+         )
+         else true (* fell off the end, so continue outer loop *)
+       in
+       if loop2 addr offset n then
+         loop (addr +^ Int64.of_int n)
+
+    | None ->
+       invalid_arg "dowhile"
   in
-  loop mappings
+  loop addr
+
+let is_mapped_range ({ mappings = mappings } as t) addr size =
+  match mappings with
+  (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *)
+  | [] -> false
+  | _ ->
+      (* Quick and dirty.  It's possible to make a much faster
+       * implementation of this which doesn't call the closure for every
+       * byte.
+       *)
+      let size = ref size in
+      try dowhile t addr (fun _ _ -> decr size; !size > 0); true
+      with Invalid_argument "dowhile" -> false
+
+(* Get a string, ending at ASCII NUL character. *)
+let get_string t addr =
+  let chars = ref [] in
+  try
+    dowhile t addr (
+      fun _ c ->
+       if c <> '\000' then (
+         chars := c :: !chars;
+         true
+       ) else false
+    );
+    let chars = List.rev !chars in
+    let len = List.length chars in
+    let str = String.create len in
+    let i = ref 0 in
+    List.iter (fun c -> String.unsafe_set str !i c; incr i) chars;
+    str
+  with
+    Invalid_argument _ -> invalid_arg "get_string"
+
+let is_string t addr =
+  try dowhile t addr (fun _ c -> c <> '\000'); true
+  with Invalid_argument _ -> false
+
+let is_C_identifier t addr =
+  let i = ref 0 in
+  let r = ref true in
+  try
+    dowhile t addr (
+      fun _ c ->
+       let b =
+         if !i = 0 then (
+           c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
+         ) else (
+           if c = '\000' then false
+           else (
+             if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
+               c >= '0' && c <= '9' then
+                 true
+             else (
+               r := false;
+               false
+             )
+           )
+         ) in
+       incr i;
+       b
+    );
+    !r
+  with
+    Invalid_argument _ -> false
 
 (* The following functions are actually written in C
  * because memmem(3) is likely to be much faster than anything
@@ -397,24 +530,24 @@ external _find_in :
   int option = "virt_mem_mmap_find_in"
 
 (* Generic find function. *)
-let _find t start align str =
-  _find_map t (
-    fun { start = mstart; size = msize; arr = arr } ->
-      if mstart >= start then (
-       (* Check this mapping from the beginning. *)
-       match _find_in 0 align str arr with
-       | Some offset -> Some (mstart +^ Int64.of_int offset)
-       | None -> None
-      )
-      else if mstart < start && start <= mstart+^msize then (
-       (* Check this mapping from somewhere in the middle. *)
-       let offset = Int64.to_int (start -^ mstart) in
-       match _find_in offset align str arr with
-       | Some offset -> Some (mstart +^ Int64.of_int offset)
+let _find { tree = tree } start align str =
+  let rec loop addr =
+    let rightend, mapping = get_mapping addr tree in
+    match mapping with
+    | Some { start = start; size = size; arr = arr } ->
+       (* Offset within this mapping. *)
+       let offset = Int64.to_int (addr -^ start) in
+
+       (match _find_in offset align str arr with
        | None -> None
-      )
-      else None
-  )
+       | Some offset -> Some (start +^ Int64.of_int offset)
+       )
+
+    | None ->
+       (* Find functions all silently skip holes, so: *)
+       loop rightend
+  in
+  loop start
 
 let find t ?(start=0L) str =
   _find t start 1 str
@@ -463,138 +596,6 @@ and string_of_addr t addr =
   let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
   Bitstring.string_of_bitstring bs
 
-and addr_of_string t str =
-  let bits = bits_of_wordsize (get_wordsize t) in
-  let e = get_endian t in
-  let bs = Bitstring.bitstring_of_string str in
-  bitmatch bs with
-  | { addr : bits : endian (e) } -> addr
-  | { _ } -> invalid_arg "addr_of_string"
-
-(* Take bytes until a condition is not met.  This is efficient in that
- * we stay within the same mapping as long as we can.
- *)
-let dowhile { mappings = mappings } addr cond =
-  let rec get_next_mapping addr = function
-    | [] -> invalid_arg "dowhile"
-    | { start = start; size = size; arr = arr } :: _
-       when start <= addr && addr < start +^ size ->
-       let offset = Int64.to_int (addr -^ start) in
-       let len = Int64.to_int size - offset in
-       arr, offset, len
-    | _ :: ms -> get_next_mapping addr ms
-  in
-  let rec loop addr =
-    let arr, offset, len = get_next_mapping addr mappings in
-    let rec loop2 i =
-      if i < len then (
-       let c = Array1.get arr (offset+i) in
-       if cond c then loop2 (i+1)
-      ) else
-       loop (addr +^ Int64.of_int len)
-    in
-    loop2 0
-  in
-  loop addr
-
-let get_bytes t addr len =
-  let str = String.create len in
-  let i = ref 0 in
-  try
-    dowhile t addr (
-      fun c ->
-       str.[!i] <- c;
-       incr i;
-       !i < len
-    );
-    str
-  with
-    Invalid_argument _ -> invalid_arg "get_bytes"
-
-let get_int32 t addr =
-  let e = get_endian t in
-  let str = get_bytes t addr 4 in
-  let bs = Bitstring.bitstring_of_string str in
-  bitmatch bs with
-  | { addr : 32 : endian (e) } -> addr
-  | { _ } -> invalid_arg "follow_pointer"
-
-let get_int64 t addr =
-  let e = get_endian t in
-  let str = get_bytes t addr 8 in
-  let bs = Bitstring.bitstring_of_string str in
-  bitmatch bs with
-  | { addr : 64 : endian (e) } -> addr
-  | { _ } -> invalid_arg "follow_pointer"
-
-let get_C_int = get_int32
-
-let get_C_long t addr =
-  let ws = get_wordsize t in
-  match ws with
-  | W32 -> Int64.of_int32 (get_int32 t addr)
-  | W64 -> get_int64 t addr
-
-let get_string t addr =
-  let chars = ref [] in
-  try
-    dowhile t addr (
-      fun c ->
-       if c <> '\000' then (
-         chars := c :: !chars;
-         true
-       ) else false
-    );
-    let chars = List.rev !chars in
-    let len = List.length chars in
-    let str = String.create len in
-    let i = ref 0 in
-    List.iter (fun c -> str.[!i] <- c; incr i) chars;
-    str
-  with
-    Invalid_argument _ -> invalid_arg "get_string"
-
-let is_string t addr =
-  try dowhile t addr (fun c -> c <> '\000'); true
-  with Invalid_argument _ -> false
-
-let is_C_identifier t addr =
-  let i = ref 0 in
-  let r = ref true in
-  try
-    dowhile t addr (
-      fun c ->
-       let b =
-         if !i = 0 then (
-           c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
-         ) else (
-           if c = '\000' then false
-           else (
-             if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
-               c >= '0' && c <= '9' then
-                 true
-             else (
-               r := false;
-               false
-             )
-           )
-         ) in
-       incr i;
-       b
-    );
-    !r
-  with
-    Invalid_argument _ -> false
-
-let is_mapped { mappings = mappings } addr =
-  let rec loop = function
-    | [] -> false
-    | { start = start; size = size; arr = arr } :: _
-       when start <= addr && addr < start +^ size -> true
-    | _ :: ms -> loop ms
-  in
-  loop mappings
-
 let follow_pointer t addr =
   let ws = get_wordsize t in
   let e = get_endian t in
@@ -617,12 +618,3 @@ let align t addr =
   let ws = get_wordsize t in
   let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
   (addr +^ mask) &^ (Int64.lognot mask)
-
-let map { mappings = mappings } f =
-  List.map (fun { start = start; size = size } -> f start size) mappings
-
-let iter t f =
-  ignore (map t (fun start size -> let () = f start size in ()))
-
-let nr_mappings { mappings = mappings } = List.length mappings
-*)