*)
open Unix
+open Printf
open Bigarray
open Virt_mem_utils
-(* Simple implementation at the moment: Store a list of mappings,
- * sorted by start address. We assume that mappings do not overlap.
- * We can change the implementation later if we need to. In most cases
- * there will only be a small number of mappings (probably 1).
- *)
-type ('a,'b) t = {
- mappings : mapping list;
- wordsize : wordsize option;
- endian : Bitmatch.endian option;
-}
-and mapping = {
+let debug = false
+
+(* An address. *)
+type addr = int64
+
+(* A range of addresses (start and start+size). *)
+type interval = addr * addr
+
+(* A mapping. *)
+type mapping = {
start : addr;
size : addr;
(* Bigarray mmap(2)'d region with byte addressing: *)
arr : (char,int8_unsigned_elt,c_layout) Array1.t;
+ (* The order that the mappings were added, 0 for the first mapping,
+ * 1 for the second mapping, etc.
+ *)
+ order : int;
}
-and addr = int64
+(* A memory map. *)
+type ('ws,'e,'hm) t = {
+ (* List of mappings, kept in reverse order they were added (new
+ * mappings are added at the head of this list).
+ *)
+ mappings : mapping list;
+
+ (* Segment tree for fast access to a mapping at a particular address.
+ * This is rebuilt each time a new mapping is added.
+ * NB! If mappings = [], ignore contents of this field. (This is
+ * enforced by the 'hm phantom type).
+ *)
+ tree : (interval * mapping option, interval * mapping option) binary_tree;
+
+ (* Word size, endianness.
+ * Phantom types enforce that these are set before being used.
+ *)
+ wordsize : wordsize;
+ endian : Bitstring.endian;
+}
let create () = {
mappings = [];
- wordsize = None;
- endian = None
+ tree = Leaf ((0L,0L),None);
+ wordsize = W32;
+ endian = Bitstring.LittleEndian;
}
-let set_wordsize t ws = { t with wordsize = Some ws }
+let set_wordsize t ws = { t with wordsize = ws }
+
+let set_endian t e = { t with endian = e }
+
+let get_wordsize t = t.wordsize
+
+let get_endian t = t.endian
+
+(* Build the segment tree from the list of mappings. This code
+ * is taken from virt-df. For an explanation of the process see:
+ * http://en.wikipedia.org/wiki/Segment_tree
+ *
+ * See also the 'get_mapping' function below which uses this tree
+ * to do fast lookups.
+ *)
+let tree_of_mappings mappings =
+ (* Construct the list of distinct endpoints. *)
+ let eps =
+ List.map
+ (fun { start = start; size = size } -> [start; start +^ size])
+ mappings in
+ let eps = sort_uniq (List.concat eps) in
+
+ (* Construct the elementary intervals. *)
+ let elints =
+ let elints, lastpoint =
+ List.fold_left (
+ fun (elints, prevpoint) point ->
+ ((point, point) :: (prevpoint, point) :: elints), point
+ ) ([], 0L) eps in
+ let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
+ List.rev elints in
+
+ if debug then (
+ eprintf "elementary intervals (%d in total):\n" (List.length elints);
+ List.iter (
+ fun (startpoint, endpoint) ->
+ eprintf " %Lx %Lx\n" startpoint endpoint
+ ) elints
+ );
+
+ (* Construct the binary tree of elementary intervals. *)
+ let tree =
+ (* Each elementary interval becomes a leaf. *)
+ let elints = List.map (fun elint -> Leaf elint) elints in
+ (* Recursively build this into a binary tree. *)
+ let rec make_layer = function
+ | [] -> []
+ | ([_] as x) -> x
+ (* Turn pairs of leaves at the bottom level into nodes. *)
+ | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
+ let xs = make_layer xs in
+ Node (a, (), b) :: xs
+ (* Turn pairs of nodes at higher levels into nodes. *)
+ | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
+ let xs = make_layer xs in
+ Node (left, (), right) :: xs
+ | Leaf _ :: _ -> assert false (* never happens??? (I think) *)
+ in
+ let rec loop = function
+ | [] -> assert false
+ | [x] -> x
+ | xs -> loop (make_layer xs)
+ in
+ loop elints in
-let set_endian t e = { t with endian = Some e }
+ if debug then (
+ let leaf_printer (startpoint, endpoint) =
+ sprintf "%Lx-%Lx" startpoint endpoint
+ in
+ let node_printer () = "" in
+ print_binary_tree leaf_printer node_printer tree
+ );
+
+ (* Insert the mappings into the tree one by one. *)
+ let tree =
+ (* For each node/leaf in the tree, add its interval and an
+ * empty list which will be used to store the mappings.
+ *)
+ let rec interval_tree = function
+ | Leaf elint -> Leaf (elint, None)
+ | Node (left, (), right) ->
+ let left = interval_tree left in
+ let right = interval_tree right in
+ let (leftstart, _) = interval_of_node left in
+ let (_, rightend) = interval_of_node right in
+ let interval = leftstart, rightend in
+ Node (left, (interval, None), right)
+ and interval_of_node = function
+ | Leaf (elint, _) -> elint
+ | Node (_, (interval, _), _) -> interval
+ in
-let get_wordsize t = Option.get t.wordsize
+ let tree = interval_tree tree in
+ (* This should always be true: *)
+ assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
+
+ (* "Contained in" operator.
+ * 'a <-< b' iff 'a' is a subinterval of 'b'.
+ * |<---- a ---->|
+ * |<----------- b ----------->|
+ *)
+ let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
+
+ (* "Intersects" operator.
+ * 'a /\ b' iff intervals 'a' and 'b' overlap, eg:
+ * |<---- a ---->|
+ * |<----------- b ----------->|
+ *)
+ let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in
+
+ let rec insert_mapping tree mapping =
+ let { start = start; size = size } = mapping in
+ let seginterval = start, start +^ size in
+
+ match tree with
+ (* Test if we should insert into this leaf or node: *)
+ | Leaf (interval, None) when interval <-< seginterval ->
+ Leaf (interval, Some mapping)
+ | Leaf (interval, Some oldmapping) when interval <-< seginterval ->
+ let mapping =
+ if oldmapping.order > mapping.order then oldmapping else mapping in
+ Leaf (interval, Some mapping)
+
+ | Node (left, (interval, None), right) when interval <-< seginterval ->
+ Node (left, (interval, Some mapping), right)
+
+ | Node (left, (interval, Some oldmapping), right)
+ when interval <-< seginterval ->
+ let mapping =
+ if oldmapping.order > mapping.order then oldmapping else mapping in
+ Node (left, (interval, Some mapping), right)
+
+ | (Leaf _) as leaf -> leaf
+
+ (* Else, should we insert into left or right subtrees? *)
+ | Node (left, i, right) ->
+ let left =
+ if seginterval /\ interval_of_node left then
+ insert_mapping left mapping
+ else
+ left in
+ let right =
+ if seginterval /\ interval_of_node right then
+ insert_mapping right mapping
+ else
+ right in
+ Node (left, i, right)
+ in
+ let tree = List.fold_left insert_mapping tree mappings in
+ tree in
+
+ if debug then (
+ let printer ((sp, ep), mapping) =
+ sprintf "[%Lx-%Lx] " sp ep ^
+ match mapping with
+ | None -> "(none)"
+ | Some { start = start; size = size; order = order } ->
+ sprintf "%Lx..%Lx(%d)" start (start+^size-^1L) order
+ in
+ print_binary_tree printer printer tree
+ );
-let get_endian t = Option.get t.endian
+ tree
-let sort_mappings mappings =
- let cmp { start = s1 } { start = s2 } = compare s1 s2 in
- List.sort cmp mappings
+let add_mapping ({ mappings = mappings } as t) start size arr =
+ let order = List.length mappings in
+ let mapping = { start = start; size = size; arr = arr; order = order } in
+ let mappings = mapping :: mappings in
+ let tree = tree_of_mappings mappings in
+ { t with mappings = mappings; tree = tree }
-let add_file ({ mappings = mappings } as t) fd addr =
- if addr &^ 7L <> 0L then
- invalid_arg "add_file: mapping address must be aligned to 8 bytes";
+let add_file t fd addr =
let size = (fstat fd).st_size in
(* mmap(2) the file using Bigarray module. *)
let arr = Array1.map_file fd char c_layout false size in
- (* Create the mapping entry and keep the mappings sorted by start addr. *)
- let mappings =
- { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
- let mappings = sort_mappings mappings in
- { t with mappings = mappings }
-
-let of_file fd addr =
- let t = create () in
- add_file t fd addr
+ (* Create the mapping entry. *)
+ add_mapping t addr (Int64.of_int size) arr
let add_string ({ mappings = mappings } as t) str addr =
- if addr &^ 7L <> 0L then
- invalid_arg "add_file: mapping address must be aligned to 8 bytes";
let size = String.length str in
(* Copy the string data to a Bigarray. *)
let arr = Array1.create char c_layout size in
for i = 0 to size-1 do
Array1.set arr i (String.unsafe_get str i)
done;
- (* Create the mapping entry and keep the mappings sorted by start addr. *)
- let mappings =
- { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
- let mappings = sort_mappings mappings in
- { t with mappings = mappings }
+ (* Create the mapping entry. *)
+ add_mapping t addr (Int64.of_int size) arr
+
+let of_file fd addr =
+ let t = create () in
+ add_file t fd addr
let of_string str addr =
let t = create () in
add_string t str addr
-(* Find in mappings and return first predicate match. *)
-let _find_map { mappings = mappings } pred =
- let rec loop = function
- | [] -> None
- | m :: ms ->
- match pred m with
- | Some n -> Some n
- | None -> loop ms
+(* 'get_mapping' is the crucial, fast lookup function for address -> mapping.
+ * It searches the tree (hence fast) to work out the topmost mapping which
+ * applies to an address.
+ *
+ * Returns (rightend * mapping option)
+ * where 'mapping option' is the mapping (or None if it's a hole)
+ * and 'rightend' is the next address at which there is a different
+ * mapping/hole. In other words, this mapping result is good for
+ * addresses [addr .. rightend-1].
+ *)
+let rec get_mapping addr = function
+ | Leaf ((_, rightend), mapping) -> rightend, mapping
+
+ | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+ (_, None),
+ right) ->
+ let subrightend, submapping =
+ if addr < leftend then get_mapping addr left
+ else get_mapping addr right in
+ subrightend, submapping
+
+ | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+ (_, Some mapping),
+ right) ->
+ let subrightend, submapping =
+ if addr < leftend then get_mapping addr left
+ else get_mapping addr right in
+ (match submapping with
+ | None -> subrightend, Some mapping
+ | Some submapping ->
+ subrightend,
+ Some (if mapping.order > submapping.order then mapping
+ else submapping)
+ )
+
+(* Use the tree to quickly check if an address is mapped (returns false
+ * if it's a hole).
+ *)
+let is_mapped { mappings = mappings; tree = tree } addr =
+ (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *)
+ match mappings with
+ | [] -> false
+ | _ ->
+ let _, mapping = get_mapping addr tree in
+ mapping <> None
+
+(* Get a single byte. *)
+let get_byte { tree = tree } addr =
+ (* Get the mapping which applies to this address: *)
+ let _, mapping = get_mapping addr tree in
+ match mapping with
+ | Some { start = start; size = size; arr = arr } ->
+ let offset = Int64.to_int (addr -^ start) in
+ Char.code (Array1.get arr offset)
+ | None ->
+ invalid_arg "get_byte"
+
+(* Get a range of bytes, possibly across several intervals. *)
+let get_bytes { tree = tree } addr len =
+ let str = String.create len in
+
+ let rec loop addr pos len =
+ if len > 0 then (
+ let rightend, mapping = get_mapping addr tree in
+ match mapping with
+ | Some { start = start; size = size; arr = arr } ->
+ (* Offset within this mapping. *)
+ let offset = Int64.to_int (addr -^ start) in
+ (* Number of bytes to read before we either get to the end
+ * of our 'len' or we fall off the end of this interval.
+ *)
+ let n = min len (Int64.to_int (rightend -^ addr)) in
+ for i = 0 to n-1 do
+ String.unsafe_set str (pos + i) (Array1.get arr (offset + i))
+ done;
+ let len = len - n in
+ loop (addr +^ Int64.of_int n) (pos + n) len
+
+ | None ->
+ invalid_arg "get_bytes"
+ )
in
- loop mappings
+ loop addr 0 len;
+
+ str
+
+let get_int32 t addr =
+ let e = get_endian t in
+ let str = get_bytes t addr 4 in
+ let bs = Bitstring.bitstring_of_string str in
+ bitmatch bs with
+ | { addr : 32 : endian (e) } -> addr
+ | { _ } -> invalid_arg "get_int32"
+
+let get_int64 t addr =
+ let e = get_endian t in
+ let str = get_bytes t addr 8 in
+ let bs = Bitstring.bitstring_of_string str in
+ bitmatch bs with
+ | { addr : 64 : endian (e) } -> addr
+ | { _ } -> invalid_arg "get_int64"
+
+let get_C_int = get_int32
+
+let get_C_long t addr =
+ let ws = get_wordsize t in
+ match ws with
+ | W32 -> Int64.of_int32 (get_int32 t addr)
+ | W64 -> get_int64 t addr
+
+(* Take bytes until a condition is not met. This is efficient
+ * in that we stay within the same mapping as long as we can.
+ *
+ * If we hit a hole, raises Invalid_argument "dowhile".
+ *)
+let dowhile { tree = tree } addr cond =
+ let rec loop addr =
+ let rightend, mapping = get_mapping addr tree in
+ match mapping with
+ | Some { start = start; size = size; arr = arr } ->
+ (* Offset within this mapping. *)
+ let offset = Int64.to_int (addr -^ start) in
+ (* Number of bytes before we fall off the end of this interval. *)
+ let n = Int64.to_int (rightend -^ addr) in
+
+ let rec loop2 addr offset n =
+ if n > 0 then (
+ let c = Array1.get arr offset in
+ if cond addr c then
+ loop2 (addr +^ 1L) (offset + 1) (n - 1)
+ else
+ false (* stop now, finish outer loop too *)
+ )
+ else true (* fell off the end, so continue outer loop *)
+ in
+ if loop2 addr offset n then
+ loop (addr +^ Int64.of_int n)
+
+ | None ->
+ invalid_arg "dowhile"
+ in
+ loop addr
+
+let is_mapped_range ({ mappings = mappings } as t) addr size =
+ match mappings with
+ (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *)
+ | [] -> false
+ | _ ->
+ (* Quick and dirty. It's possible to make a much faster
+ * implementation of this which doesn't call the closure for every
+ * byte.
+ *)
+ let size = ref size in
+ try dowhile t addr (fun _ _ -> decr size; !size > 0); true
+ with Invalid_argument "dowhile" -> false
+
+(* Get a string, ending at ASCII NUL character. *)
+let get_string t addr =
+ let chars = ref [] in
+ try
+ dowhile t addr (
+ fun _ c ->
+ if c <> '\000' then (
+ chars := c :: !chars;
+ true
+ ) else false
+ );
+ let chars = List.rev !chars in
+ let len = List.length chars in
+ let str = String.create len in
+ let i = ref 0 in
+ List.iter (fun c -> String.unsafe_set str !i c; incr i) chars;
+ str
+ with
+ Invalid_argument _ -> invalid_arg "get_string"
+
+let is_string t addr =
+ try dowhile t addr (fun _ c -> c <> '\000'); true
+ with Invalid_argument _ -> false
+
+let is_C_identifier t addr =
+ let i = ref 0 in
+ let r = ref true in
+ try
+ dowhile t addr (
+ fun _ c ->
+ let b =
+ if !i = 0 then (
+ c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
+ ) else (
+ if c = '\000' then false
+ else (
+ if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
+ c >= '0' && c <= '9' then
+ true
+ else (
+ r := false;
+ false
+ )
+ )
+ ) in
+ incr i;
+ b
+ );
+ !r
+ with
+ Invalid_argument _ -> false
(* The following functions are actually written in C
* because memmem(3) is likely to be much faster than anything
int option = "virt_mem_mmap_find_in"
(* Generic find function. *)
-let _find t start align str =
- _find_map t (
- fun { start = mstart; size = msize; arr = arr } ->
- if mstart >= start then (
- (* Check this mapping from the beginning. *)
- match _find_in 0 align str arr with
- | Some offset -> Some (mstart +^ Int64.of_int offset)
- | None -> None
- )
- else if mstart < start && start <= mstart+^msize then (
- (* Check this mapping from somewhere in the middle. *)
- let offset = Int64.to_int (start -^ mstart) in
- match _find_in offset align str arr with
- | Some offset -> Some (mstart +^ Int64.of_int offset)
+let _find { tree = tree } start align str =
+ let rec loop addr =
+ let rightend, mapping = get_mapping addr tree in
+ match mapping with
+ | Some { start = start; size = size; arr = arr } ->
+ (* Offset within this mapping. *)
+ let offset = Int64.to_int (addr -^ start) in
+
+ (match _find_in offset align str arr with
| None -> None
- )
- else None
- )
+ | Some offset -> Some (start +^ Int64.of_int offset)
+ )
+
+ | None ->
+ (* Find functions all silently skip holes, so: *)
+ loop rightend
+ in
+ loop start
let find t ?(start=0L) str =
_find t start 1 str
let bits = bits_of_wordsize (get_wordsize t) in
let e = get_endian t in
let bs = BITSTRING { addr : bits : endian (e) } in
- Bitmatch.string_of_bitstring bs
+ Bitstring.string_of_bitstring bs
*)
-(* XXX bitmatch is missing 'construct_int64_le_unsigned' so we
+(* XXX bitstring is missing 'construct_int64_le_unsigned' so we
* have to force this to 32 bits for the moment.
*)
and string_of_addr t addr =
assert (bits = 32);
let e = get_endian t in
let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
- Bitmatch.string_of_bitstring bs
-
-and addr_of_string t str =
- let bits = bits_of_wordsize (get_wordsize t) in
- let e = get_endian t in
- let bs = Bitmatch.bitstring_of_string str in
- bitmatch bs with
- | { addr : bits : endian (e) } -> addr
- | { _ } -> invalid_arg "addr_of_string"
-
-let get_byte { mappings = mappings } addr =
- let rec loop = function
- | [] -> invalid_arg "get_byte"
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size ->
- let offset = Int64.to_int (addr -^ start) in
- Char.code (Array1.get arr offset)
- | _ :: ms -> loop ms
- in
- loop mappings
-
-(* Take bytes until a condition is not met. This is efficient in that
- * we stay within the same mapping as long as we can.
- *)
-let dowhile { mappings = mappings } addr cond =
- let rec get_next_mapping addr = function
- | [] -> invalid_arg "dowhile"
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size ->
- let offset = Int64.to_int (addr -^ start) in
- let len = Int64.to_int size - offset in
- arr, offset, len
- | _ :: ms -> get_next_mapping addr ms
- in
- let rec loop addr =
- let arr, offset, len = get_next_mapping addr mappings in
- let rec loop2 i =
- if i < len then (
- let c = Array1.get arr (offset+i) in
- if cond c then loop2 (i+1)
- ) else
- loop (addr +^ Int64.of_int len)
- in
- loop2 0
- in
- loop addr
-
-let get_bytes t addr len =
- let str = String.create len in
- let i = ref 0 in
- try
- dowhile t addr (
- fun c ->
- str.[!i] <- c;
- incr i;
- !i < len
- );
- str
- with
- Invalid_argument _ -> invalid_arg "get_bytes"
-
-let get_int32 t addr =
- let e = get_endian t in
- let str = get_bytes t addr 4 in
- let bs = Bitmatch.bitstring_of_string str in
- bitmatch bs with
- | { addr : 32 : endian (e) } -> addr
- | { _ } -> invalid_arg "follow_pointer"
-
-let get_int64 t addr =
- let e = get_endian t in
- let str = get_bytes t addr 8 in
- let bs = Bitmatch.bitstring_of_string str in
- bitmatch bs with
- | { addr : 64 : endian (e) } -> addr
- | { _ } -> invalid_arg "follow_pointer"
-
-let get_C_int = get_int32
-
-let get_C_long t addr =
- let ws = get_wordsize t in
- match ws with
- | W32 -> Int64.of_int32 (get_int32 t addr)
- | W64 -> get_int64 t addr
-
-let get_string t addr =
- let chars = ref [] in
- try
- dowhile t addr (
- fun c ->
- if c <> '\000' then (
- chars := c :: !chars;
- true
- ) else false
- );
- let chars = List.rev !chars in
- let len = List.length chars in
- let str = String.create len in
- let i = ref 0 in
- List.iter (fun c -> str.[!i] <- c; incr i) chars;
- str
- with
- Invalid_argument _ -> invalid_arg "get_string"
-
-let is_string t addr =
- try dowhile t addr (fun c -> c <> '\000'); true
- with Invalid_argument _ -> false
-
-let is_C_identifier t addr =
- let i = ref 0 in
- let r = ref true in
- try
- dowhile t addr (
- fun c ->
- let b =
- if !i = 0 then (
- c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
- ) else (
- if c = '\000' then false
- else (
- if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
- c >= '0' && c <= '9' then
- true
- else (
- r := false;
- false
- )
- )
- ) in
- incr i;
- b
- );
- !r
- with
- Invalid_argument _ -> false
-
-let is_mapped { mappings = mappings } addr =
- let rec loop = function
- | [] -> false
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size -> true
- | _ :: ms -> loop ms
- in
- loop mappings
+ Bitstring.string_of_bitstring bs
let follow_pointer t addr =
let ws = get_wordsize t in
let e = get_endian t in
let bits = bits_of_wordsize ws in
let str = get_bytes t addr (bytes_of_wordsize ws) in
- let bs = Bitmatch.bitstring_of_string str in
+ let bs = Bitstring.bitstring_of_string str in
bitmatch bs with
| { addr : bits : endian (e) } -> addr
| { _ } -> invalid_arg "follow_pointer"