X-Git-Url: http://git.annexia.org/?p=virt-mem.git;a=blobdiff_plain;f=lib%2Fvirt_mem_mmap.ml;h=b9013de7745fe259559ab6f2960db0d5128cd37e;hp=f6edee897073213698f1bf0a0ce51125f60d6f4b;hb=7cb0d37dafdefce84af1230444e6b8ce911d590e;hpb=46037cd89c23b0f94dc691006ee1d9cd0fec24f0 diff --git a/lib/virt_mem_mmap.ml b/lib/virt_mem_mmap.ml index f6edee8..b9013de 100644 --- a/lib/virt_mem_mmap.ml +++ b/lib/virt_mem_mmap.ml @@ -21,63 +21,320 @@ *) open Unix +open Printf open Bigarray open Virt_mem_utils -(* Simple implementation at the moment: Store a list of mappings, - * sorted by start address. We assume that mappings do not overlap. - * We can change the implementation later if we need to. In most cases - * there will only be a small number of mappings (probably 1). - *) -type ('a,'b) t = { - mappings : mapping list; - wordsize : wordsize option; - endian : Bitmatch.endian option; -} -and mapping = { +let debug = true + +(* An address. *) +type addr = int64 + +(* A range of addresses (start and start+size). *) +type interval = addr * addr + +(* A mapping. *) +type mapping = { start : addr; size : addr; (* Bigarray mmap(2)'d region with byte addressing: *) arr : (char,int8_unsigned_elt,c_layout) Array1.t; + (* The order that the mappings were added, 0 for the first mapping, + * 1 for the second mapping, etc. + *) + order : int; } -and addr = int64 +(* A memory map. *) +type ('ws,'e,'hm) t = { + (* List of mappings, kept in reverse order they were added (new + * mappings are added at the head of this list). + *) + mappings : mapping list; + + (* Segment tree for fast access to a mapping at a particular address. + * This is rebuilt each time a new mapping is added. + * NB! If mappings = [], ignore contents of this field. (This is + * enforced by the 'hm phantom type). + *) + tree : (interval * mapping list, interval * mapping list) binary_tree; + + (* Word size, endianness. + * Phantom types enforce that these are set before being used. + *) + wordsize : wordsize; + endian : Bitstring.endian; +} let create () = { mappings = []; - wordsize = None; - endian = None + tree = Leaf ((0L,0L),[]); + wordsize = W32; + endian = Bitstring.LittleEndian; } -let set_wordsize t ws = { t with wordsize = Some ws } +let set_wordsize t ws = { t with wordsize = ws } -let set_endian t e = { t with endian = Some e } +let set_endian t e = { t with endian = e } -let get_wordsize t = Option.get t.wordsize +let get_wordsize t = t.wordsize -let get_endian t = Option.get t.endian +let get_endian t = t.endian -let sort_mappings mappings = - let cmp { start = s1 } { start = s2 } = compare s1 s2 in - List.sort cmp mappings +(* Build the segment tree from the list of mappings. This code + * is taken from virt-df. For an explanation of the process see: + * http://en.wikipedia.org/wiki/Segment_tree + *) +let tree_of_mappings mappings = + (* Construct the list of distinct endpoints. *) + let eps = + List.map + (fun { start = start; size = size } -> [start; start +^ size]) + mappings in + let eps = sort_uniq (List.concat eps) in + + (* Construct the elementary intervals. *) + let elints = + let elints, lastpoint = + List.fold_left ( + fun (elints, prevpoint) point -> + ((point, point) :: (prevpoint, point) :: elints), point + ) ([], 0L) eps in + let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in + List.rev elints in + + if debug then ( + eprintf "elementary intervals (%d in total):\n" (List.length elints); + List.iter ( + fun (startpoint, endpoint) -> + eprintf " %Lx %Lx\n" startpoint endpoint + ) elints + ); + + (* Construct the binary tree of elementary intervals. *) + let tree = + (* Each elementary interval becomes a leaf. *) + let elints = List.map (fun elint -> Leaf elint) elints in + (* Recursively build this into a binary tree. *) + let rec make_layer = function + | [] -> [] + | ([_] as x) -> x + (* Turn pairs of leaves at the bottom level into nodes. *) + | (Leaf _ as a) :: (Leaf _ as b) :: xs -> + let xs = make_layer xs in + Node (a, (), b) :: xs + (* Turn pairs of nodes at higher levels into nodes. *) + | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs -> + let xs = make_layer xs in + Node (left, (), right) :: xs + | Leaf _ :: _ -> assert false (* never happens??? (I think) *) + in + let rec loop = function + | [] -> assert false + | [x] -> x + | xs -> loop (make_layer xs) + in + loop elints in + + if debug then ( + let leaf_printer (startpoint, endpoint) = + sprintf "%Lx-%Lx" startpoint endpoint + in + let node_printer () = "" in + print_binary_tree leaf_printer node_printer tree + ); + + (* Insert the mappings into the tree one by one. *) + let tree = + (* For each node/leaf in the tree, add its interval and an + * empty list which will be used to store the mappings. + *) + let rec interval_tree = function + | Leaf elint -> Leaf (elint, []) + | Node (left, (), right) -> + let left = interval_tree left in + let right = interval_tree right in + let (leftstart, _) = interval_of_node left in + let (_, rightend) = interval_of_node right in + let interval = leftstart, rightend in + Node (left, (interval, []), right) + and interval_of_node = function + | Leaf (elint, _) -> elint + | Node (_, (interval, _), _) -> interval + in + + let tree = interval_tree tree in + (* This should always be true: *) + assert (interval_of_node tree = (0L, Int64.max_int(*XXX*))); + + (* "Contained in" operator. + * 'a <-< b' iff 'a' is a subinterval of 'b'. + * |<---- a ---->| + * |<----------- b ----------->| + *) + let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in + + (* "Intersects" operator. + * 'a /\ b' iff intervals 'a' and 'b' overlap, eg: + * |<---- a ---->| + * |<----------- b ----------->| + *) + let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in + + let rec insert_mapping tree mapping = + let { start = start; size = size } = mapping in + let seginterval = start, start +^ size in + + match tree with + (* Test if we should insert into this leaf or node: *) + | Leaf (interval, mappings) when interval <-< seginterval -> + Leaf (interval, mapping :: mappings) + | Node (left, (interval, mappings), right) + when interval <-< seginterval -> + Node (left, (interval, mapping :: mappings), right) + + | (Leaf _) as leaf -> leaf + + (* Else, should we insert into left or right subtrees? *) + | Node (left, i, right) -> + let left = + if seginterval /\ interval_of_node left then + insert_mapping left mapping + else + left in + let right = + if seginterval /\ interval_of_node right then + insert_mapping right mapping + else + right in + Node (left, i, right) + in + let tree = List.fold_left insert_mapping tree mappings in + tree in + + if debug then ( + let printer ((sp, ep), mappings) = + sprintf "[%Lx-%Lx] " sp ep ^ + String.concat ";" + (List.map (fun { start = start; size = size } -> + sprintf "%Lx+%Lx" start size) + mappings) + in + print_binary_tree printer printer tree + ); -let add_file ({ mappings = mappings } as t) fd addr = - if addr &^ 7L <> 0L then - invalid_arg "add_file: mapping address must be aligned to 8 bytes"; + tree + +let add_mapping ({ mappings = mappings } as t) start size arr = + let order = List.length mappings in + let mapping = { start = start; size = size; arr = arr; order = order } in + let mappings = mapping :: mappings in + let tree = tree_of_mappings mappings in + { t with mappings = mappings; tree = tree } + +let add_file t fd addr = let size = (fstat fd).st_size in (* mmap(2) the file using Bigarray module. *) let arr = Array1.map_file fd char c_layout false size in - (* Create the mapping entry and keep the mappings sorted by start addr. *) - let mappings = - { start = addr; size = Int64.of_int size; arr = arr } :: mappings in - let mappings = sort_mappings mappings in - { t with mappings = mappings } + (* Create the mapping entry. *) + add_mapping t addr (Int64.of_int size) arr + +let add_string ({ mappings = mappings } as t) str addr = + let size = String.length str in + (* Copy the string data to a Bigarray. *) + let arr = Array1.create char c_layout size in + for i = 0 to size-1 do + Array1.set arr i (String.unsafe_get str i) + done; + (* Create the mapping entry. *) + add_mapping t addr (Int64.of_int size) arr let of_file fd addr = let t = create () in add_file t fd addr +let of_string str addr = + let t = create () in + add_string t str addr + +(* Look up an address and get the top-most mapping which contains it. + * This uses the segment tree, so it's fast. The top-most mapping is + * the one with the highest 'order' field. + * + * Warning: This 'hot' code was carefully optimized based on + * feedback from 'gprof'. Avoid fiddling with it. + *) +let rec get_mapping addr = function + | Leaf (_, []) -> None + | Leaf (_, [mapping]) -> Some mapping + | Leaf (_, mappings) -> Some (find_highest_order mappings) + + (* Try to avoid expensive search if node mappings is empty: *) + | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left), + (_, []), + right) -> + let submapping = + if addr < leftend then get_mapping addr left + else get_mapping addr right in + submapping + + (* ... or a singleton: *) + | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left), + (_, [mapping]), + right) -> + let submapping = + if addr < leftend then get_mapping addr left + else get_mapping addr right in + (match submapping with + | None -> Some mapping + | Some submapping -> + Some (if mapping.order > submapping.order then mapping + else submapping) + ) + + (* Normal recursive case: *) + | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left), + (_, mappings), + right) -> + let submapping = + if addr < leftend then get_mapping addr left + else get_mapping addr right in + (match submapping with + | None -> Some (find_highest_order mappings) + | Some submapping -> Some (find_highest_order (submapping :: mappings)) + ) + +and find_highest_order mappings = + List.fold_left ( + fun mapping1 mapping2 -> + if mapping1.order > mapping2.order then mapping1 else mapping2 + ) (List.hd mappings) (List.tl mappings) + +(* Get a single byte. *) +let get_byte { tree = tree } addr = + (* Get the mapping which applies to this address: *) + match get_mapping addr tree with + | Some { start = start; size = size; arr = arr } -> + let offset = Int64.to_int (addr -^ start) in + Char.code (Array1.get arr offset) + | None -> + invalid_arg "get_byte" +(* + let rec loop = function + | [] -> invalid_arg "get_byte" + | { start = start; size = size; arr = arr } :: _ + when start <= addr && addr < start +^ size -> + let offset = Int64.to_int (addr -^ start) in + Char.code (Array1.get arr offset) + | _ :: ms -> loop ms + in + loop mappings +*) + + +(* + (* Find in mappings and return first predicate match. *) let _find_map { mappings = mappings } pred = let rec loop = function @@ -89,6 +346,15 @@ let _find_map { mappings = mappings } pred = in loop mappings +(* The following functions are actually written in C + * because memmem(3) is likely to be much faster than anything + * we could write in OCaml. + * + * Also OCaml bigarrays are specifically designed to be accessed + * easily from C: + * http://caml.inria.fr/pub/docs/manual-ocaml/manual043.html + *) +(* (* Array+offset = string? *) let string_at arr offset str strlen = let j = ref offset in @@ -125,6 +391,10 @@ let _find_in start align str arr = loop () ) else Some start +*) +external _find_in : + int -> int -> string -> (char,int8_unsigned_elt,c_layout) Array1.t -> + int option = "virt_mem_mmap_find_in" (* Generic find function. *) let _find t start align str = @@ -181,9 +451,9 @@ and string_of_addr t addr = let bits = bits_of_wordsize (get_wordsize t) in let e = get_endian t in let bs = BITSTRING { addr : bits : endian (e) } in - Bitmatch.string_of_bitstring bs + Bitstring.string_of_bitstring bs *) -(* XXX bitmatch is missing 'construct_int64_le_unsigned' so we +(* XXX bitstring is missing 'construct_int64_le_unsigned' so we * have to force this to 32 bits for the moment. *) and string_of_addr t addr = @@ -191,27 +461,16 @@ and string_of_addr t addr = assert (bits = 32); let e = get_endian t in let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in - Bitmatch.string_of_bitstring bs + Bitstring.string_of_bitstring bs and addr_of_string t str = let bits = bits_of_wordsize (get_wordsize t) in let e = get_endian t in - let bs = Bitmatch.bitstring_of_string str in + let bs = Bitstring.bitstring_of_string str in bitmatch bs with | { addr : bits : endian (e) } -> addr | { _ } -> invalid_arg "addr_of_string" -let get_byte { mappings = mappings } addr = - let rec loop = function - | [] -> invalid_arg "get_byte" - | { start = start; size = size; arr = arr } :: _ - when start <= addr && addr < start +^ size -> - let offset = Int64.to_int (addr -^ start) in - Char.code (Array1.get arr offset) - | _ :: ms -> loop ms - in - loop mappings - (* Take bytes until a condition is not met. This is efficient in that * we stay within the same mapping as long as we can. *) @@ -252,6 +511,30 @@ let get_bytes t addr len = with Invalid_argument _ -> invalid_arg "get_bytes" +let get_int32 t addr = + let e = get_endian t in + let str = get_bytes t addr 4 in + let bs = Bitstring.bitstring_of_string str in + bitmatch bs with + | { addr : 32 : endian (e) } -> addr + | { _ } -> invalid_arg "follow_pointer" + +let get_int64 t addr = + let e = get_endian t in + let str = get_bytes t addr 8 in + let bs = Bitstring.bitstring_of_string str in + bitmatch bs with + | { addr : 64 : endian (e) } -> addr + | { _ } -> invalid_arg "follow_pointer" + +let get_C_int = get_int32 + +let get_C_long t addr = + let ws = get_wordsize t in + match ws with + | W32 -> Int64.of_int32 (get_int32 t addr) + | W64 -> get_int64 t addr + let get_string t addr = let chars = ref [] in try @@ -317,7 +600,7 @@ let follow_pointer t addr = let e = get_endian t in let bits = bits_of_wordsize ws in let str = get_bytes t addr (bytes_of_wordsize ws) in - let bs = Bitmatch.bitstring_of_string str in + let bs = Bitstring.bitstring_of_string str in bitmatch bs with | { addr : bits : endian (e) } -> addr | { _ } -> invalid_arg "follow_pointer" @@ -334,3 +617,12 @@ let align t addr = let ws = get_wordsize t in let mask = Int64.of_int (bytes_of_wordsize ws - 1) in (addr +^ mask) &^ (Int64.lognot mask) + +let map { mappings = mappings } f = + List.map (fun { start = start; size = size } -> f start size) mappings + +let iter t f = + ignore (map t (fun start size -> let () = f start size in ())) + +let nr_mappings { mappings = mappings } = List.length mappings +*)