open Virt_mem_utils
-let debug = true
+let debug = false
(* An address. *)
type addr = int64
* NB! If mappings = [], ignore contents of this field. (This is
* enforced by the 'hm phantom type).
*)
- tree : (interval * mapping list, interval * mapping list) binary_tree;
+ tree : (interval * mapping option, interval * mapping option) binary_tree;
(* Word size, endianness.
* Phantom types enforce that these are set before being used.
let create () = {
mappings = [];
- tree = Leaf ((0L,0L),[]);
+ tree = Leaf ((0L,0L),None);
wordsize = W32;
endian = Bitstring.LittleEndian;
}
(* Build the segment tree from the list of mappings. This code
* is taken from virt-df. For an explanation of the process see:
* http://en.wikipedia.org/wiki/Segment_tree
+ *
+ * See also the 'get_mapping' function below which uses this tree
+ * to do fast lookups.
*)
let tree_of_mappings mappings =
(* Construct the list of distinct endpoints. *)
* empty list which will be used to store the mappings.
*)
let rec interval_tree = function
- | Leaf elint -> Leaf (elint, [])
+ | Leaf elint -> Leaf (elint, None)
| Node (left, (), right) ->
let left = interval_tree left in
let right = interval_tree right in
let (leftstart, _) = interval_of_node left in
let (_, rightend) = interval_of_node right in
let interval = leftstart, rightend in
- Node (left, (interval, []), right)
+ Node (left, (interval, None), right)
and interval_of_node = function
| Leaf (elint, _) -> elint
| Node (_, (interval, _), _) -> interval
match tree with
(* Test if we should insert into this leaf or node: *)
- | Leaf (interval, mappings) when interval <-< seginterval ->
- Leaf (interval, mapping :: mappings)
- | Node (left, (interval, mappings), right)
+ | Leaf (interval, None) when interval <-< seginterval ->
+ Leaf (interval, Some mapping)
+ | Leaf (interval, Some oldmapping) when interval <-< seginterval ->
+ let mapping =
+ if oldmapping.order > mapping.order then oldmapping else mapping in
+ Leaf (interval, Some mapping)
+
+ | Node (left, (interval, None), right) when interval <-< seginterval ->
+ Node (left, (interval, Some mapping), right)
+
+ | Node (left, (interval, Some oldmapping), right)
when interval <-< seginterval ->
- Node (left, (interval, mapping :: mappings), right)
+ let mapping =
+ if oldmapping.order > mapping.order then oldmapping else mapping in
+ Node (left, (interval, Some mapping), right)
| (Leaf _) as leaf -> leaf
tree in
if debug then (
- let printer ((sp, ep), mappings) =
+ let printer ((sp, ep), mapping) =
sprintf "[%Lx-%Lx] " sp ep ^
- String.concat ";"
- (List.map (fun { start = start; size = size } ->
- sprintf "%Lx+%Lx" start size)
- mappings)
+ match mapping with
+ | None -> "(none)"
+ | Some { start = start; size = size; order = order } ->
+ sprintf "%Lx..%Lx(%d)" start (start+^size-^1L) order
in
print_binary_tree printer printer tree
);
let t = create () in
add_string t str addr
-(* Look up an address and get the top-most mapping which contains it.
- * This uses the segment tree, so it's fast. The top-most mapping is
- * the one with the highest 'order' field.
+(* 'get_mapping' is the crucial, fast lookup function for address -> mapping.
+ * It searches the tree (hence fast) to work out the topmost mapping which
+ * applies to an address.
*
- * Warning: This 'hot' code was carefully optimized based on
- * feedback from 'gprof'. Avoid fiddling with it.
+ * Returns (rightend * mapping option)
+ * where 'mapping option' is the mapping (or None if it's a hole)
+ * and 'rightend' is the next address at which there is a different
+ * mapping/hole. In other words, this mapping result is good for
+ * addresses [addr .. rightend-1].
*)
let rec get_mapping addr = function
- | Leaf (_, []) -> None
- | Leaf (_, [mapping]) -> Some mapping
- | Leaf (_, mappings) -> Some (find_highest_order mappings)
+ (* Leaf: hand back the leaf's own mapping (None for a hole) together
+  * with the right end of the leaf's interval.
+  *
+  * NOTE(review): 'addr' is never compared against the leaf's interval,
+  * so an address outside the range covered by the tree still ends up
+  * at some leaf.  Presumably the outermost leaves are holes; confirm
+  * against the tree construction.
+  *)
+ | Leaf ((_, rightend), mapping) -> rightend, mapping
- (* Try to avoid expensive search if node mappings is empty: *)
+ (* Node with no mapping of its own: the answer is whatever the chosen
+  * child returns.  'leftend' is the right end of the left child's
+  * interval; addresses below it descend left, all others descend right.
+  *)
| Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
- (_, []),
+ (_, None),
right) ->
- let submapping =
+ let subrightend, submapping =
if addr < leftend then get_mapping addr left
else get_mapping addr right in
- submapping
+ subrightend, submapping
- (* ... or a singleton: *)
+ (* Node carrying a mapping: descend exactly as above, then combine.
+  * If the subtree reports a hole, this node's mapping is the answer;
+  * otherwise the mapping with the strictly higher 'order' wins, and
+  * the subtree's mapping wins a tie.  The right end returned is
+  * always the one reported by the subtree.
+  *)
| Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
- (_, [mapping]),
+ (_, Some mapping),
right) ->
- let submapping =
+ let subrightend, submapping =
if addr < leftend then get_mapping addr left
else get_mapping addr right in
(match submapping with
- | None -> Some mapping
+ | None -> subrightend, Some mapping
| Some submapping ->
+ subrightend,
Some (if mapping.order > submapping.order then mapping
else submapping)
)
- (* Normal recursive case: *)
- | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
- (_, mappings),
- right) ->
- let submapping =
- if addr < leftend then get_mapping addr left
- else get_mapping addr right in
- (match submapping with
- | None -> Some (find_highest_order mappings)
- | Some submapping -> Some (find_highest_order (submapping :: mappings))
- )
-
-and find_highest_order mappings =
- List.fold_left (
- fun mapping1 mapping2 ->
- if mapping1.order > mapping2.order then mapping1 else mapping2
- ) (List.hd mappings) (List.tl mappings)
+(* Report whether 'addr' is covered by some mapping.  The answer comes
+ * from a segment-tree lookup; a hole gives false.
+ *)
+let is_mapped { tree = tree } addr =
+  match get_mapping addr tree with
+  | _, Some _ -> true
+  | _, None -> false
(* Get a single byte. *)
let get_byte { tree = tree } addr =
(* Get the mapping which applies to this address: *)
- match get_mapping addr tree with
+ (* Only the mapping half of the lookup result is needed for a single
+  * byte, so the 'rightend' half is discarded.  The byte is returned as
+  * its character code; a hole raises Invalid_argument "get_byte".
+  *)
+ let _, mapping = get_mapping addr tree in
+ match mapping with
| Some { start = start; size = size; arr = arr } ->
let offset = Int64.to_int (addr -^ start) in
Char.code (Array1.get arr offset)
| None ->
invalid_arg "get_byte"
-(*
- let rec loop = function
- | [] -> invalid_arg "get_byte"
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size ->
- let offset = Int64.to_int (addr -^ start) in
- Char.code (Array1.get arr offset)
- | _ :: ms -> loop ms
+
+(* Get a range of bytes, possibly across several intervals.
+ *
+ * Returns a fresh string holding the 'len' bytes that start at 'addr'.
+ * Raises Invalid_argument "get_bytes" if any part of the range lies in
+ * a hole, or if the tree lookup yields an interval that does not extend
+ * past the current address (in which case no progress is possible).
+ *)
+let get_bytes { tree = tree } addr len =
+  let str = String.create len in
+
+  let rec loop addr pos len =
+    if len > 0 then (
+      let rightend, mapping = get_mapping addr tree in
+      match mapping with
+      | Some { start = start; arr = arr } ->
+          (* Offset within this mapping. *)
+          let offset = Int64.to_int (addr -^ start) in
+          (* Bytes left before we fall off the end of this interval. *)
+          let avail = Int64.to_int (rightend -^ addr) in
+          (* An interval ending at or before 'addr' would make every
+           * iteration copy nothing and look up the same interval
+           * again, so treat it like an unmapped address instead of
+           * spinning forever.
+           *)
+          if avail <= 0 then invalid_arg "get_bytes";
+          (* Copy up to the end of 'len' or of this interval,
+           * whichever comes first.
+           *)
+          let n = min len avail in
+          for i = 0 to n-1 do
+            String.unsafe_set str (pos + i) (Array1.get arr (offset + i))
+          done;
+          loop (addr +^ Int64.of_int n) (pos + n) (len - n)
+
+      | None ->
+          invalid_arg "get_bytes"
+    )
in
- loop mappings
-*)
+  loop addr 0 len;
+  str
-(*
+(* Read the four bytes at 'addr' and decode them as a 32 bit integer
+ * using the endianness reported by 'get_endian t'.  A failed bit match
+ * raises Invalid_argument "get_int32".
+ *)
+let get_int32 t addr =
+  let e = get_endian t in
+  let bits = Bitstring.bitstring_of_string (get_bytes t addr 4) in
+  bitmatch bits with
+  | { value : 32 : endian (e) } -> value
+  | { _ } -> invalid_arg "get_int32"
-(* Find in mappings and return first predicate match. *)
-let _find_map { mappings = mappings } pred =
- let rec loop = function
- | [] -> None
- | m :: ms ->
- match pred m with
- | Some n -> Some n
- | None -> loop ms
+(* Read the eight bytes at 'addr' and decode them as a 64 bit integer
+ * using the endianness reported by 'get_endian t'.  A failed bit match
+ * raises Invalid_argument "get_int64".
+ *)
+let get_int64 t addr =
+  let e = get_endian t in
+  let bits = Bitstring.bitstring_of_string (get_bytes t addr 8) in
+  bitmatch bits with
+  | { value : 64 : endian (e) } -> value
+  | { _ } -> invalid_arg "get_int64"
+
+(* A C 'int' is always read as a 32 bit integer. *)
+let get_C_int = get_int32
+
+(* A C 'long' follows the word size of 't': a 32 bit read widened to
+ * int64 for W32, a plain 64 bit read for W64.
+ *)
+let get_C_long t addr =
+  match get_wordsize t with
+  | W64 -> get_int64 t addr
+  | W32 -> Int64.of_int32 (get_int32 t addr)
+
+(* Take bytes until a condition is not met.  This is efficient
+ * in that we stay within the same mapping as long as we can.
+ *
+ * 'cond' is called as 'cond addr c' for each byte 'c' in address
+ * order; the scan stops at the first byte for which it returns false.
+ *
+ * If we hit a hole, raises Invalid_argument "dowhile".  The same
+ * exception is raised if the tree lookup yields an interval that does
+ * not extend past the current address, since the scan could not
+ * advance.
+ *)
+let dowhile { tree = tree } addr cond =
+  let rec loop addr =
+    let rightend, mapping = get_mapping addr tree in
+    match mapping with
+    | Some { start = start; arr = arr } ->
+        (* Offset within this mapping. *)
+        let offset = Int64.to_int (addr -^ start) in
+        (* Number of bytes before we fall off the end of this interval. *)
+        let n = Int64.to_int (rightend -^ addr) in
+        (* With n <= 0 the inner loop would report "fell off the end"
+         * straight away and the outer loop would look up the very same
+         * interval again, forever.  Fail like a hole instead.
+         *)
+        if n <= 0 then invalid_arg "dowhile";
+
+        let rec loop2 addr offset n =
+          if n > 0 then (
+            let c = Array1.get arr offset in
+            if cond addr c then
+              loop2 (addr +^ 1L) (offset + 1) (n - 1)
+            else
+              false (* stop now, finish outer loop too *)
+          )
+          else true (* fell off the end, so continue outer loop *)
+        in
+        if loop2 addr offset n then
+          loop (addr +^ Int64.of_int n)
+
+    | None ->
+        invalid_arg "dowhile"
in
- loop mappings
+  loop addr
+
+(* Get a string, ending at ASCII NUL character.
+ *
+ * The NUL itself is not part of the result.  Any Invalid_argument
+ * raised while scanning (for instance by 'dowhile' on a hole) is
+ * re-raised as Invalid_argument "get_string".
+ *)
+let get_string t addr =
+  (* A Buffer grows as needed, so the characters are accumulated in
+   * order directly, with no intermediate list to reverse and copy.
+   *)
+  let buf = Buffer.create 64 in
+  try
+    dowhile t addr (
+      fun _ c ->
+        if c <> '\000' then (
+          Buffer.add_char buf c;
+          true
+        ) else false
+    );
+    Buffer.contents buf
+  with
+    Invalid_argument _ -> invalid_arg "get_string"
+
+(* True if scanning from 'addr' reaches a NUL byte; false if the scan
+ * raises Invalid_argument first.
+ *)
+let is_string t addr =
+  let not_nul _ c = c <> '\000' in
+  try
+    dowhile t addr not_nul;
+    true
+  with
+    Invalid_argument _ -> false
+
+(* True if the NUL-terminated string at 'addr' is a C identifier: a
+ * letter or underscore followed by any number of letters, digits or
+ * underscores.  Returns false for the empty string, for a string whose
+ * first character is not a letter or underscore, and when the scan
+ * raises Invalid_argument before reaching the NUL.
+ *)
+let is_C_identifier t addr =
+  let is_start c =
+    c = '_' || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') in
+  let is_rest c = is_start c || (c >= '0' && c <= '9') in
+  let first = ref true in
+  let ok = ref true in
+  try
+    dowhile t addr (
+      fun _ c ->
+        if !first then (
+          first := false;
+          (* A bad first character must clear the verdict as well as
+           * stop the scan; otherwise strings such as "1abc" or ""
+           * would be accepted.
+           *)
+          if not (is_start c) then ok := false;
+          !ok
+        )
+        else if c = '\000' then
+          false (* terminator reached: stop, verdict stands *)
+        else (
+          if not (is_rest c) then ok := false;
+          !ok
+        )
+    );
+    !ok
+  with
+    Invalid_argument _ -> false
(* The following functions are actually written in C
* because memmem(3) is likely to be much faster than anything
int option = "virt_mem_mmap_find_in"
(* Generic find function. *)
-let _find t start align str =
- _find_map t (
- fun { start = mstart; size = msize; arr = arr } ->
- if mstart >= start then (
- (* Check this mapping from the beginning. *)
- match _find_in 0 align str arr with
- | Some offset -> Some (mstart +^ Int64.of_int offset)
- | None -> None
- )
- else if mstart < start && start <= mstart+^msize then (
- (* Check this mapping from somewhere in the middle. *)
- let offset = Int64.to_int (start -^ mstart) in
- match _find_in offset align str arr with
- | Some offset -> Some (mstart +^ Int64.of_int offset)
+let _find { tree = tree } start align str =
+ let rec loop addr =
+ let rightend, mapping = get_mapping addr tree in
+ match mapping with
+ (* Inside this arm 'start' is the mapping's base address; it shadows
+  * the search-start parameter of the same name.
+  *)
+ | Some { start = start; size = size; arr = arr } ->
+ (* Offset within this mapping. *)
+ let offset = Int64.to_int (addr -^ start) in
+
+ (* NOTE(review): a None from _find_in ends the whole search here;
+  * mappings that lie after this one are never tried.  Confirm that
+  * this is intended.
+  *)
+ (match _find_in offset align str arr with
| None -> None
- )
- else None
- )
+ | Some offset -> Some (start +^ Int64.of_int offset)
+ )
+
+ | None ->
+ (* Find functions all silently skip holes, so: *)
+ (* NOTE(review): this only makes progress if 'rightend' is beyond
+  * 'addr' and the lookup eventually returns a mapping; confirm how
+  * the search terminates when nothing is mapped above 'addr'.
+  *)
+ loop rightend
+ in
+ loop start
let find t ?(start=0L) str =
_find t start 1 str
let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
Bitstring.string_of_bitstring bs
-and addr_of_string t str =
- let bits = bits_of_wordsize (get_wordsize t) in
- let e = get_endian t in
- let bs = Bitstring.bitstring_of_string str in
- bitmatch bs with
- | { addr : bits : endian (e) } -> addr
- | { _ } -> invalid_arg "addr_of_string"
-
-(* Take bytes until a condition is not met. This is efficient in that
- * we stay within the same mapping as long as we can.
- *)
-let dowhile { mappings = mappings } addr cond =
- let rec get_next_mapping addr = function
- | [] -> invalid_arg "dowhile"
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size ->
- let offset = Int64.to_int (addr -^ start) in
- let len = Int64.to_int size - offset in
- arr, offset, len
- | _ :: ms -> get_next_mapping addr ms
- in
- let rec loop addr =
- let arr, offset, len = get_next_mapping addr mappings in
- let rec loop2 i =
- if i < len then (
- let c = Array1.get arr (offset+i) in
- if cond c then loop2 (i+1)
- ) else
- loop (addr +^ Int64.of_int len)
- in
- loop2 0
- in
- loop addr
-
-let get_bytes t addr len =
- let str = String.create len in
- let i = ref 0 in
- try
- dowhile t addr (
- fun c ->
- str.[!i] <- c;
- incr i;
- !i < len
- );
- str
- with
- Invalid_argument _ -> invalid_arg "get_bytes"
-
-let get_int32 t addr =
- let e = get_endian t in
- let str = get_bytes t addr 4 in
- let bs = Bitstring.bitstring_of_string str in
- bitmatch bs with
- | { addr : 32 : endian (e) } -> addr
- | { _ } -> invalid_arg "follow_pointer"
-
-let get_int64 t addr =
- let e = get_endian t in
- let str = get_bytes t addr 8 in
- let bs = Bitstring.bitstring_of_string str in
- bitmatch bs with
- | { addr : 64 : endian (e) } -> addr
- | { _ } -> invalid_arg "follow_pointer"
-
-let get_C_int = get_int32
-
-let get_C_long t addr =
- let ws = get_wordsize t in
- match ws with
- | W32 -> Int64.of_int32 (get_int32 t addr)
- | W64 -> get_int64 t addr
-
-let get_string t addr =
- let chars = ref [] in
- try
- dowhile t addr (
- fun c ->
- if c <> '\000' then (
- chars := c :: !chars;
- true
- ) else false
- );
- let chars = List.rev !chars in
- let len = List.length chars in
- let str = String.create len in
- let i = ref 0 in
- List.iter (fun c -> str.[!i] <- c; incr i) chars;
- str
- with
- Invalid_argument _ -> invalid_arg "get_string"
-
-let is_string t addr =
- try dowhile t addr (fun c -> c <> '\000'); true
- with Invalid_argument _ -> false
-
-let is_C_identifier t addr =
- let i = ref 0 in
- let r = ref true in
- try
- dowhile t addr (
- fun c ->
- let b =
- if !i = 0 then (
- c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
- ) else (
- if c = '\000' then false
- else (
- if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
- c >= '0' && c <= '9' then
- true
- else (
- r := false;
- false
- )
- )
- ) in
- incr i;
- b
- );
- !r
- with
- Invalid_argument _ -> false
-
-let is_mapped { mappings = mappings } addr =
- let rec loop = function
- | [] -> false
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size -> true
- | _ :: ms -> loop ms
- in
- loop mappings
-
let follow_pointer t addr =
let ws = get_wordsize t in
let e = get_endian t in
let ws = get_wordsize t in
let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
(addr +^ mask) &^ (Int64.lognot mask)
-
-let map { mappings = mappings } f =
- List.map (fun { start = start; size = size } -> f start size) mappings
-
-let iter t f =
- ignore (map t (fun start size -> let () = f start size in ()))
-
-let nr_mappings { mappings = mappings } = List.length mappings
-*)