*)
open Unix
+open Printf
open Bigarray
open Virt_mem_utils
-(* Simple implementation at the moment: Store a list of mappings,
- * sorted by start address. We assume that mappings do not overlap.
- * We can change the implementation later if we need to. In most cases
- * there will only be a small number of mappings (probably 1).
- *)
-type ('a,'b) t = {
- mappings : mapping list;
- wordsize : wordsize option;
- endian : Bitmatch.endian option;
-}
-and mapping = {
+(* Set to true to print the intermediate data built by
+ * tree_of_mappings (elementary intervals and the trees).  This must
+ * stay off in normal use: the tree is rebuilt, and would be dumped,
+ * every time a mapping is added.
+ *)
+let debug = false
+
+(* An address. *)
+type addr = int64
+
+(* A range of addresses (start and start+size). *)
+type interval = addr * addr
+
+(* A mapping. *)
+type mapping = {
start : addr;
size : addr;
(* Bigarray mmap(2)'d region with byte addressing: *)
arr : (char,int8_unsigned_elt,c_layout) Array1.t;
+ (* The order that the mappings were added, 0 for the first mapping,
+ * 1 for the second mapping, etc.
+ *)
+ order : int;
}
-and addr = int64
+(* A memory map. *)
+type ('ws,'e,'hm) t = {
+ (* List of mappings, kept in reverse order they were added (new
+ * mappings are added at the head of this list).
+ *)
+ mappings : mapping list;
+
+ (* Segment tree for fast access to a mapping at a particular address.
+ * This is rebuilt each time a new mapping is added.
+ * NB! If mappings = [], ignore contents of this field. (This is
+ * enforced by the 'hm phantom type).
+ *)
+ tree : (interval * mapping list, interval * mapping list) binary_tree;
+
+ (* Word size, endianness.
+ * Phantom types enforce that these are set before being used.
+ *)
+ wordsize : wordsize;
+ endian : Bitstring.endian;
+}
+(* Create an empty memory map (no mappings).
+ *
+ * The tree, wordsize and endian fields still need some value here, so
+ * they are filled with placeholders.  NOTE(review): per the comments on
+ * the type, the phantom type parameters are meant to stop callers
+ * reading these before they are set -- confirm against the interface.
+ *)
let create () = {
mappings = [];
- wordsize = None;
- endian = None
+ (* Placeholder: ignored while mappings = []. *)
+ tree = Leaf ((0L,0L),[]);
+ (* Placeholders until set_wordsize / set_endian are called. *)
+ wordsize = W32;
+ endian = Bitstring.LittleEndian;
}
-let set_wordsize t ws = { t with wordsize = Some ws }
+let set_wordsize t ws = { t with wordsize = ws }
+
+let set_endian t e = { t with endian = e }
+
+let get_wordsize t = t.wordsize
+
+let get_endian t = t.endian
+
+(* Build the segment tree from the list of mappings.  This code
+ * is taken from virt-df.  For an explanation of the process see:
+ * http://en.wikipedia.org/wiki/Segment_tree
+ *
+ * Steps:
+ *  1. Collect the distinct endpoints (start and start+size) of
+ *     every mapping.
+ *  2. Turn them into elementary intervals covering
+ *     0 .. Int64.max_int: the gap leading up to each endpoint, a
+ *     zero-length interval at each endpoint, and a final interval
+ *     after the last endpoint.
+ *  3. Pair adjacent intervals up, layer by layer, into a binary tree.
+ *  4. Label every node with the interval its subtree spans, then
+ *     attach each mapping to the highest nodes/leaves whose interval
+ *     lies wholly inside that mapping.
+ *
+ * The result is the tree stored in the 'tree' field of the memory map.
+ * When 'debug' is true the intermediate results are printed.
+ *)
+let tree_of_mappings mappings =
+  (* Construct the list of distinct endpoints.
+   * (sort_uniq comes from elsewhere; presumably it sorts ascending
+   * and drops duplicates -- the assert further down relies on that.)
+   *)
+  let eps =
+    List.map
+      (fun { start = start; size = size } -> [start; start +^ size])
+      mappings in
+  let eps = sort_uniq (List.concat eps) in
+
+  (* Construct the elementary intervals. *)
+  let elints =
+    let elints, lastpoint =
+      List.fold_left (
+        fun (elints, prevpoint) point ->
+          ((point, point) :: (prevpoint, point) :: elints), point
+      ) ([], 0L) eps in
+    let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
+    (* The fold built the list back to front. *)
+    List.rev elints in
+
+  if debug then (
+    eprintf "elementary intervals (%d in total):\n" (List.length elints);
+    List.iter (
+      fun (startpoint, endpoint) ->
+        eprintf "  %Lx %Lx\n" startpoint endpoint
+    ) elints
+  );
+
+  (* Construct the binary tree of elementary intervals. *)
+  let tree =
+    (* Each elementary interval becomes a leaf. *)
+    let elints = List.map (fun elint -> Leaf elint) elints in
+    (* Recursively build this into a binary tree. *)
+    let rec make_layer = function
+      | [] -> []
+      | ([_] as x) -> x
+      (* Turn pairs of leaves at the bottom level into nodes. *)
+      | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
+          let xs = make_layer xs in
+          Node (a, (), b) :: xs
+      (* Turn pairs of nodes at higher levels into nodes. *)
+      | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
+          let xs = make_layer xs in
+          Node (left, (), right) :: xs
+      (* Unreachable: above the bottom level a Leaf only survives as
+       * the odd element left over at the very end of a layer, so it
+       * is never followed by a Node.
+       *)
+      | Leaf _ :: _ -> assert false
+    in
+    (* Keep halving the layer until a single root remains. *)
+    let rec loop = function
+      | [] -> assert false
+      | [x] -> x
+      | xs -> loop (make_layer xs)
+    in
+    loop elints in
-let set_endian t e = { t with endian = Some e }
+  if debug then (
+    let leaf_printer (startpoint, endpoint) =
+      sprintf "%Lx-%Lx" startpoint endpoint
+    in
+    let node_printer () = "" in
+    print_binary_tree leaf_printer node_printer tree
+  );
+
+  (* Insert the mappings into the tree one by one. *)
+  let tree =
+    (* For each node/leaf in the tree, add its interval and an
+     * empty list which will be used to store the mappings.
+     *)
+    let rec interval_tree = function
+      | Leaf elint -> Leaf (elint, [])
+      | Node (left, (), right) ->
+          let left = interval_tree left in
+          let right = interval_tree right in
+          let (leftstart, _) = interval_of_node left in
+          let (_, rightend) = interval_of_node right in
+          let interval = leftstart, rightend in
+          Node (left, (interval, []), right)
+    and interval_of_node = function
+      | Leaf (elint, _) -> elint
+      | Node (_, (interval, _), _) -> interval
+    in
-let get_wordsize t = Option.get t.wordsize
+    let tree = interval_tree tree in
+    (* This should always be true: *)
+    assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
+
+    (* "Contained in" operator.
+     * 'a <-< b' iff 'a' is a subinterval of 'b'.
+     *      |<---- a ---->|
+     * |<----------- b ----------->|
+     *)
+    let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
+
+    (* "Intersects" operator.
+     * 'a /\ b' iff intervals 'a' and 'b' touch or overlap, eg:
+     *      |<---- a ---->|
+     *             |<----------- b ----------->|
+     *
+     * It is only used to decide whether a subtree is worth visiting.
+     * Endpoints are compared inclusively so that the zero-length
+     * leaves lying exactly on a mapping boundary are still visited.
+     * Two identical single-point intervals are deliberately treated
+     * as not intersecting, so a zero-sized mapping is not attached
+     * to the point leaf at its own address.
+     *)
+    let ( /\ ) (a1, a2) (b1, b2) =
+      a2 >= b1 && b2 >= a1 && (a2 > b1 || b2 > a1) in
+
+    let rec insert_mapping tree mapping =
+      let { start = start; size = size } = mapping in
+      let seginterval = start, start +^ size in
+
+      match tree with
+      (* Test if we should insert into this leaf or node: *)
+      | Leaf (interval, mappings) when interval <-< seginterval ->
+          Leaf (interval, mapping :: mappings)
+      | Node (left, (interval, mappings), right)
+          when interval <-< seginterval ->
+          (* Stored here; no need to descend any further. *)
+          Node (left, (interval, mapping :: mappings), right)
+
+      | (Leaf _) as leaf -> leaf
+
+      (* Else, should we insert into left or right subtrees? *)
+      | Node (left, i, right) ->
+          let left =
+            if seginterval /\ interval_of_node left then
+              insert_mapping left mapping
+            else
+              left in
+          let right =
+            if seginterval /\ interval_of_node right then
+              insert_mapping right mapping
+            else
+              right in
+          Node (left, i, right)
+    in
+    let tree = List.fold_left insert_mapping tree mappings in
+    tree in
+
+  if debug then (
+    let printer ((sp, ep), mappings) =
+      sprintf "[%Lx-%Lx] " sp ep ^
+        String.concat ";"
+          (List.map (fun { start = start; size = size } ->
+                       sprintf "%Lx+%Lx" start size)
+             mappings)
+    in
+    print_binary_tree printer printer tree
+  );
-let get_endian t = Option.get t.endian
+  tree
-let sort_mappings mappings =
- let cmp { start = s1 } { start = s2 } = compare s1 s2 in
- List.sort cmp mappings
+(* Add a mapping of 'size' bytes at address 'start', backed by 'arr'.
+ * The new mapping is numbered after the existing ones, put at the
+ * head of the list, and the segment tree is rebuilt from scratch.
+ *)
+let add_mapping ({ mappings = existing } as t) start size arr =
+  let entry =
+    { start = start; size = size; arr = arr;
+      order = List.length existing } in
+  let updated = entry :: existing in
+  { t with mappings = updated; tree = tree_of_mappings updated }
-let add_file ({ mappings = mappings } as t) fd addr =
- if addr &^ 7L <> 0L then
- invalid_arg "add_file: mapping address must be aligned to 8 bytes";
+let add_file t fd addr =
let size = (fstat fd).st_size in
(* mmap(2) the file using Bigarray module. *)
let arr = Array1.map_file fd char c_layout false size in
- (* Create the mapping entry and keep the mappings sorted by start addr. *)
- let mappings =
- { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
- let mappings = sort_mappings mappings in
- { t with mappings = mappings }
-
-let of_file fd addr =
- let t = create () in
- add_file t fd addr
+ (* Create the mapping entry. *)
+ add_mapping t addr (Int64.of_int size) arr
+(* Add a mapping at 'addr' whose contents are a copy of the string
+ * 'str'.  Returns the updated memory map.
+ *)
-let add_string ({ mappings = mappings } as t) str addr =
+let add_string t str addr =
- if addr &^ 7L <> 0L then
- invalid_arg "add_file: mapping address must be aligned to 8 bytes";
let size = String.length str in
(* Copy the string data to a Bigarray. *)
let arr = Array1.create char c_layout size in
for i = 0 to size-1 do
Array1.set arr i (String.unsafe_get str i)
done;
- (* Create the mapping entry and keep the mappings sorted by start addr. *)
- let mappings =
- { start = addr; size = Int64.of_int size; arr = arr } :: mappings in
- let mappings = sort_mappings mappings in
- { t with mappings = mappings }
+ (* Create the mapping entry. *)
+ add_mapping t addr (Int64.of_int size) arr
+
+let of_file fd addr =
+ let t = create () in
+ add_file t fd addr
let of_string str addr =
let t = create () in
add_string t str addr
+(* Look up an address and get the top-most mapping which contains it.
+ * This uses the segment tree, so it's fast. The top-most mapping is
+ * the one with the highest 'order' field.
+ *
+ * The walk goes from the root down to one leaf.  At each node it
+ * goes left if 'addr' is below the end of the left child's interval,
+ * otherwise right.  The result is the highest-order mapping among
+ * all the mapping lists met on that path, or None if every one of
+ * them is empty.
+ *
+ * Note that 'addr' is never compared against a mapping's own start
+ * or size here: correctness depends entirely on how the tree was
+ * built.
+ *
+ * Warning: This 'hot' code was carefully optimized based on
+ * feedback from 'gprof'. Avoid fiddling with it.
+ *)
+let rec get_mapping addr = function
+ | Leaf (_, []) -> None
+ | Leaf (_, [mapping]) -> Some mapping
+ | Leaf (_, mappings) -> Some (find_highest_order mappings)
+
+ (* Try to avoid expensive search if node mappings is empty: *)
+ | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+ (_, []),
+ right) ->
+ let submapping =
+ if addr < leftend then get_mapping addr left
+ else get_mapping addr right in
+ submapping
+
+ (* ... or a singleton: *)
+ | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+ (_, [mapping]),
+ right) ->
+ let submapping =
+ if addr < leftend then get_mapping addr left
+ else get_mapping addr right in
+ (match submapping with
+ | None -> Some mapping
+ | Some submapping ->
+ (* On equal order the mapping found lower down wins. *)
+ Some (if mapping.order > submapping.order then mapping
+ else submapping)
+ )
+
+ (* Normal recursive case: *)
+ | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
+ (_, mappings),
+ right) ->
+ let submapping =
+ if addr < leftend then get_mapping addr left
+ else get_mapping addr right in
+ (match submapping with
+ | None -> Some (find_highest_order mappings)
+ | Some submapping -> Some (find_highest_order (submapping :: mappings))
+ )
+
+(* Return the mapping with the greatest 'order' in a list.  On equal
+ * order the later element wins.  The list must not be empty
+ * (List.hd raises otherwise); every call above passes at least one
+ * element.
+ *)
+and find_highest_order mappings =
+ List.fold_left (
+ fun mapping1 mapping2 ->
+ if mapping1.order > mapping2.order then mapping1 else mapping2
+ ) (List.hd mappings) (List.tl mappings)
+
+(* Read the byte stored at 'addr' and return it as an int.
+ * Raises Invalid_argument "get_byte" if the segment tree has no
+ * mapping for that address.
+ *)
+let get_byte { tree = tree } addr =
+  match get_mapping addr tree with
+  | None ->
+      invalid_arg "get_byte"
+  | Some m ->
+      (* Index into the backing array relative to the mapping start. *)
+      let offset = Int64.to_int (addr -^ m.start) in
+      Char.code (Array1.get m.arr offset)
+(*
+ let rec loop = function
+ | [] -> invalid_arg "get_byte"
+ | { start = start; size = size; arr = arr } :: _
+ when start <= addr && addr < start +^ size ->
+ let offset = Int64.to_int (addr -^ start) in
+ Char.code (Array1.get arr offset)
+ | _ :: ms -> loop ms
+ in
+ loop mappings
+*)
+
+
+(*
+
(* Find in mappings and return first predicate match. *)
let _find_map { mappings = mappings } pred =
let rec loop = function
let bits = bits_of_wordsize (get_wordsize t) in
let e = get_endian t in
let bs = BITSTRING { addr : bits : endian (e) } in
- Bitmatch.string_of_bitstring bs
+ Bitstring.string_of_bitstring bs
*)
-(* XXX bitmatch is missing 'construct_int64_le_unsigned' so we
+(* XXX bitstring is missing 'construct_int64_le_unsigned' so we
* have to force this to 32 bits for the moment.
*)
and string_of_addr t addr =
assert (bits = 32);
let e = get_endian t in
let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
- Bitmatch.string_of_bitstring bs
+ Bitstring.string_of_bitstring bs
and addr_of_string t str =
let bits = bits_of_wordsize (get_wordsize t) in
let e = get_endian t in
- let bs = Bitmatch.bitstring_of_string str in
+ let bs = Bitstring.bitstring_of_string str in
bitmatch bs with
| { addr : bits : endian (e) } -> addr
| { _ } -> invalid_arg "addr_of_string"
-let get_byte { mappings = mappings } addr =
- let rec loop = function
- | [] -> invalid_arg "get_byte"
- | { start = start; size = size; arr = arr } :: _
- when start <= addr && addr < start +^ size ->
- let offset = Int64.to_int (addr -^ start) in
- Char.code (Array1.get arr offset)
- | _ :: ms -> loop ms
- in
- loop mappings
-
(* Take bytes until a condition is not met. This is efficient in that
* we stay within the same mapping as long as we can.
*)
let get_int32 t addr =
let e = get_endian t in
let str = get_bytes t addr 4 in
- let bs = Bitmatch.bitstring_of_string str in
+ let bs = Bitstring.bitstring_of_string str in
bitmatch bs with
| { addr : 32 : endian (e) } -> addr
| { _ } -> invalid_arg "follow_pointer"
let get_int64 t addr =
let e = get_endian t in
let str = get_bytes t addr 8 in
- let bs = Bitmatch.bitstring_of_string str in
+ let bs = Bitstring.bitstring_of_string str in
bitmatch bs with
| { addr : 64 : endian (e) } -> addr
| { _ } -> invalid_arg "follow_pointer"
let e = get_endian t in
let bits = bits_of_wordsize ws in
let str = get_bytes t addr (bytes_of_wordsize ws) in
- let bs = Bitmatch.bitstring_of_string str in
+ let bs = Bitstring.bitstring_of_string str in
bitmatch bs with
| { addr : bits : endian (e) } -> addr
| { _ } -> invalid_arg "follow_pointer"
let ws = get_wordsize t in
let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
(addr +^ mask) &^ (Int64.lognot mask)
+
+let map { mappings = mappings } f =
+ List.map (fun { start = start; size = size } -> f start size) mappings
+
+let iter t f =
+ ignore (map t (fun start size -> let () = f start size in ()))
+
+let nr_mappings { mappings = mappings } = List.length mappings
+*)