X-Git-Url: http://git.annexia.org/?p=virt-mem.git;a=blobdiff_plain;f=lib%2Fvirt_mem_mmap.ml;h=c323c1a0096c3e7e97252c8a704a1248f77a2066;hp=7401e17218ea24bbf4bc30402c1825c204646017;hb=2e1de51e35bea53ebece1a6fd6d6970534f4cbe9;hpb=5ce06c3326a2672e82dc656b35eb7a3e6616539a diff --git a/lib/virt_mem_mmap.ml b/lib/virt_mem_mmap.ml index 7401e17..c323c1a 100644 --- a/lib/virt_mem_mmap.ml +++ b/lib/virt_mem_mmap.ml @@ -21,74 +21,473 @@ *) open Unix +open Printf open Bigarray open Virt_mem_utils -(* Simple implementation at the moment: Store a list of mappings, - * sorted by start address. We assume that mappings do not overlap. - * We can change the implementation later if we need to. In most cases - * there will only be a small number of mappings (probably 1). - *) -type ('a,'b) t = { - mappings : mapping list; - wordsize : wordsize option; - endian : Bitmatch.endian option; -} -and mapping = { +let debug = false + +(* An address. *) +type addr = int64 + +(* A range of addresses (start and start+size). *) +type interval = addr * addr + +(* A mapping. *) +type mapping = { start : addr; size : addr; (* Bigarray mmap(2)'d region with byte addressing: *) arr : (char,int8_unsigned_elt,c_layout) Array1.t; + (* The order that the mappings were added, 0 for the first mapping, + * 1 for the second mapping, etc. + *) + order : int; } -and addr = int64 +(* A memory map. *) +type ('ws,'e,'hm) t = { + (* List of mappings, kept in reverse order they were added (new + * mappings are added at the head of this list). + *) + mappings : mapping list; + + (* Segment tree for fast access to a mapping at a particular address. + * This is rebuilt each time a new mapping is added. + * NB! If mappings = [], ignore contents of this field. (This is + * enforced by the 'hm phantom type). + *) + tree : (interval * mapping option, interval * mapping option) binary_tree; + + (* Word size, endianness. + * Phantom types enforce that these are set before being used. + *) + wordsize : wordsize; + endian : Bitstring.endian; +} let create () = { mappings = []; - wordsize = None; - endian = None + tree = Leaf ((0L,0L),None); + wordsize = W32; + endian = Bitstring.LittleEndian; } -let set_wordsize t ws = { t with wordsize = Some ws } +let set_wordsize t ws = { t with wordsize = ws } -let set_endian t e = { t with endian = Some e } +let set_endian t e = { t with endian = e } -let get_wordsize t = Option.get t.wordsize +let get_wordsize t = t.wordsize -let get_endian t = Option.get t.endian +let get_endian t = t.endian -let sort_mappings mappings = - let cmp { start = s1 } { start = s2 } = compare s1 s2 in - List.sort cmp mappings +(* Build the segment tree from the list of mappings. This code + * is taken from virt-df. For an explanation of the process see: + * http://en.wikipedia.org/wiki/Segment_tree + * + * See also the 'get_mapping' function below which uses this tree + * to do fast lookups. + *) +let tree_of_mappings mappings = + (* Construct the list of distinct endpoints. *) + let eps = + List.map + (fun { start = start; size = size } -> [start; start +^ size]) + mappings in + let eps = sort_uniq (List.concat eps) in + + (* Construct the elementary intervals. *) + let elints = + let elints, lastpoint = + List.fold_left ( + fun (elints, prevpoint) point -> + ((point, point) :: (prevpoint, point) :: elints), point + ) ([], 0L) eps in + let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in + List.rev elints in + + if debug then ( + eprintf "elementary intervals (%d in total):\n" (List.length elints); + List.iter ( + fun (startpoint, endpoint) -> + eprintf " %Lx %Lx\n" startpoint endpoint + ) elints + ); + + (* Construct the binary tree of elementary intervals. *) + let tree = + (* Each elementary interval becomes a leaf. *) + let elints = List.map (fun elint -> Leaf elint) elints in + (* Recursively build this into a binary tree. *) + let rec make_layer = function + | [] -> [] + | ([_] as x) -> x + (* Turn pairs of leaves at the bottom level into nodes. *) + | (Leaf _ as a) :: (Leaf _ as b) :: xs -> + let xs = make_layer xs in + Node (a, (), b) :: xs + (* Turn pairs of nodes at higher levels into nodes. *) + | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs -> + let xs = make_layer xs in + Node (left, (), right) :: xs + | Leaf _ :: _ -> assert false (* never happens??? (I think) *) + in + let rec loop = function + | [] -> assert false + | [x] -> x + | xs -> loop (make_layer xs) + in + loop elints in -let add_file ({ mappings = mappings } as t) fd addr = - if addr &^ 7L <> 0L then - invalid_arg "add_file: mapping address must be aligned to 8 bytes"; + if debug then ( + let leaf_printer (startpoint, endpoint) = + sprintf "%Lx-%Lx" startpoint endpoint + in + let node_printer () = "" in + print_binary_tree leaf_printer node_printer tree + ); + + (* Insert the mappings into the tree one by one. *) + let tree = + (* For each node/leaf in the tree, add its interval and an + * empty list which will be used to store the mappings. + *) + let rec interval_tree = function + | Leaf elint -> Leaf (elint, None) + | Node (left, (), right) -> + let left = interval_tree left in + let right = interval_tree right in + let (leftstart, _) = interval_of_node left in + let (_, rightend) = interval_of_node right in + let interval = leftstart, rightend in + Node (left, (interval, None), right) + and interval_of_node = function + | Leaf (elint, _) -> elint + | Node (_, (interval, _), _) -> interval + in + + let tree = interval_tree tree in + (* This should always be true: *) + assert (interval_of_node tree = (0L, Int64.max_int(*XXX*))); + + (* "Contained in" operator. + * 'a <-< b' iff 'a' is a subinterval of 'b'. + * |<---- a ---->| + * |<----------- b ----------->| + *) + let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in + + (* "Intersects" operator. + * 'a /\ b' iff intervals 'a' and 'b' overlap, eg: + * |<---- a ---->| + * |<----------- b ----------->| + *) + let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in + + let rec insert_mapping tree mapping = + let { start = start; size = size } = mapping in + let seginterval = start, start +^ size in + + match tree with + (* Test if we should insert into this leaf or node: *) + | Leaf (interval, None) when interval <-< seginterval -> + Leaf (interval, Some mapping) + | Leaf (interval, Some oldmapping) when interval <-< seginterval -> + let mapping = + if oldmapping.order > mapping.order then oldmapping else mapping in + Leaf (interval, Some mapping) + + | Node (left, (interval, None), right) when interval <-< seginterval -> + Node (left, (interval, Some mapping), right) + + | Node (left, (interval, Some oldmapping), right) + when interval <-< seginterval -> + let mapping = + if oldmapping.order > mapping.order then oldmapping else mapping in + Node (left, (interval, Some mapping), right) + + | (Leaf _) as leaf -> leaf + + (* Else, should we insert into left or right subtrees? *) + | Node (left, i, right) -> + let left = + if seginterval /\ interval_of_node left then + insert_mapping left mapping + else + left in + let right = + if seginterval /\ interval_of_node right then + insert_mapping right mapping + else + right in + Node (left, i, right) + in + let tree = List.fold_left insert_mapping tree mappings in + tree in + + if debug then ( + let printer ((sp, ep), mapping) = + sprintf "[%Lx-%Lx] " sp ep ^ + match mapping with + | None -> "(none)" + | Some { start = start; size = size; order = order } -> + sprintf "%Lx..%Lx(%d)" start (start+^size-^1L) order + in + print_binary_tree printer printer tree + ); + + tree + +let add_mapping ({ mappings = mappings } as t) start size arr = + let order = List.length mappings in + let mapping = { start = start; size = size; arr = arr; order = order } in + let mappings = mapping :: mappings in + let tree = tree_of_mappings mappings in + { t with mappings = mappings; tree = tree } + +let add_file t fd addr = let size = (fstat fd).st_size in (* mmap(2) the file using Bigarray module. *) let arr = Array1.map_file fd char c_layout false size in - (* Create the mapping entry and keep the mappings sorted by start addr. *) - let mappings = - { start = addr; size = Int64.of_int size; arr = arr } :: mappings in - let mappings = sort_mappings mappings in - { t with mappings = mappings } + (* Create the mapping entry. *) + add_mapping t addr (Int64.of_int size) arr + +let add_string ({ mappings = mappings } as t) str addr = + let size = String.length str in + (* Copy the string data to a Bigarray. *) + let arr = Array1.create char c_layout size in + for i = 0 to size-1 do + Array1.set arr i (String.unsafe_get str i) + done; + (* Create the mapping entry. *) + add_mapping t addr (Int64.of_int size) arr let of_file fd addr = let t = create () in add_file t fd addr -(* Find in mappings and return first predicate match. *) -let _find_map { mappings = mappings } pred = - let rec loop = function - | [] -> None - | m :: ms -> - match pred m with - | Some n -> Some n - | None -> loop ms +let of_string str addr = + let t = create () in + add_string t str addr + +(* 'get_mapping' is the crucial, fast lookup function for address -> mapping. + * It searches the tree (hence fast) to work out the topmost mapping which + * applies to an address. + * + * Returns (rightend * mapping option) + * where 'mapping option' is the mapping (or None if it's a hole) + * and 'rightend' is the next address at which there is a different + * mapping/hole. In other words, this mapping result is good for + * addresses [addr .. rightend-1]. + *) +let rec get_mapping addr = function + | Leaf ((_, rightend), mapping) -> rightend, mapping + + | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left), + (_, None), + right) -> + let subrightend, submapping = + if addr < leftend then get_mapping addr left + else get_mapping addr right in + subrightend, submapping + + | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left), + (_, Some mapping), + right) -> + let subrightend, submapping = + if addr < leftend then get_mapping addr left + else get_mapping addr right in + (match submapping with + | None -> subrightend, Some mapping + | Some submapping -> + subrightend, + Some (if mapping.order > submapping.order then mapping + else submapping) + ) + +(* Use the tree to quickly check if an address is mapped (returns false + * if it's a hole). + *) +let is_mapped { mappings = mappings; tree = tree } addr = + (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *) + match mappings with + | [] -> false + | _ -> + let _, mapping = get_mapping addr tree in + mapping <> None + +(* Get a single byte. *) +let get_byte { tree = tree } addr = + (* Get the mapping which applies to this address: *) + let _, mapping = get_mapping addr tree in + match mapping with + | Some { start = start; size = size; arr = arr } -> + let offset = Int64.to_int (addr -^ start) in + Char.code (Array1.get arr offset) + | None -> + invalid_arg "get_byte" + +(* Get a range of bytes, possibly across several intervals. *) +let get_bytes { tree = tree } addr len = + let str = String.create len in + + let rec loop addr pos len = + if len > 0 then ( + let rightend, mapping = get_mapping addr tree in + match mapping with + | Some { start = start; size = size; arr = arr } -> + (* Offset within this mapping. *) + let offset = Int64.to_int (addr -^ start) in + (* Number of bytes to read before we either get to the end + * of our 'len' or we fall off the end of this interval. + *) + let n = min len (Int64.to_int (rightend -^ addr)) in + for i = 0 to n-1 do + String.unsafe_set str (pos + i) (Array1.get arr (offset + i)) + done; + let len = len - n in + loop (addr +^ Int64.of_int n) (pos + n) len + + | None -> + invalid_arg "get_bytes" + ) + in + loop addr 0 len; + + str + +let get_int32 t addr = + let e = get_endian t in + let str = get_bytes t addr 4 in + let bs = Bitstring.bitstring_of_string str in + bitmatch bs with + | { addr : 32 : endian (e) } -> addr + | { _ } -> invalid_arg "get_int32" + +let get_int64 t addr = + let e = get_endian t in + let str = get_bytes t addr 8 in + let bs = Bitstring.bitstring_of_string str in + bitmatch bs with + | { addr : 64 : endian (e) } -> addr + | { _ } -> invalid_arg "get_int64" + +let get_C_int = get_int32 + +let get_C_long t addr = + let ws = get_wordsize t in + match ws with + | W32 -> Int64.of_int32 (get_int32 t addr) + | W64 -> get_int64 t addr + +(* Take bytes until a condition is not met. This is efficient + * in that we stay within the same mapping as long as we can. + * + * If we hit a hole, raises Invalid_argument "dowhile". + *) +let dowhile { tree = tree } addr cond = + let rec loop addr = + let rightend, mapping = get_mapping addr tree in + match mapping with + | Some { start = start; size = size; arr = arr } -> + (* Offset within this mapping. *) + let offset = Int64.to_int (addr -^ start) in + (* Number of bytes before we fall off the end of this interval. *) + let n = Int64.to_int (rightend -^ addr) in + + let rec loop2 addr offset n = + if n > 0 then ( + let c = Array1.get arr offset in + if cond addr c then + loop2 (addr +^ 1L) (offset + 1) (n - 1) + else + false (* stop now, finish outer loop too *) + ) + else true (* fell off the end, so continue outer loop *) + in + if loop2 addr offset n then + loop (addr +^ Int64.of_int n) + + | None -> + invalid_arg "dowhile" in - loop mappings + loop addr + +let is_mapped_range ({ mappings = mappings } as t) addr size = + match mappings with + (* NB: No [`HasMapping] in the type so we have to check mappings <> []. *) + | [] -> false + | _ -> + (* Quick and dirty. It's possible to make a much faster + * implementation of this which doesn't call the closure for every + * byte. + *) + let size = ref size in + try dowhile t addr (fun _ _ -> decr size; !size > 0); true + with Invalid_argument "dowhile" -> false + +(* Get a string, ending at ASCII NUL character. *) +let get_string t addr = + let chars = ref [] in + try + dowhile t addr ( + fun _ c -> + if c <> '\000' then ( + chars := c :: !chars; + true + ) else false + ); + let chars = List.rev !chars in + let len = List.length chars in + let str = String.create len in + let i = ref 0 in + List.iter (fun c -> String.unsafe_set str !i c; incr i) chars; + str + with + Invalid_argument _ -> invalid_arg "get_string" + +let is_string t addr = + try dowhile t addr (fun _ c -> c <> '\000'); true + with Invalid_argument _ -> false + +let is_C_identifier t addr = + let i = ref 0 in + let r = ref true in + try + dowhile t addr ( + fun _ c -> + let b = + if !i = 0 then ( + c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' + ) else ( + if c = '\000' then false + else ( + if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || + c >= '0' && c <= '9' then + true + else ( + r := false; + false + ) + ) + ) in + incr i; + b + ); + !r + with + Invalid_argument _ -> false +(* The following functions are actually written in C + * because memmem(3) is likely to be much faster than anything + * we could write in OCaml. + * + * Also OCaml bigarrays are specifically designed to be accessed + * easily from C: + * http://caml.inria.fr/pub/docs/manual-ocaml/manual043.html + *) +(* (* Array+offset = string? *) let string_at arr offset str strlen = let j = ref offset in @@ -125,26 +524,30 @@ let _find_in start align str arr = loop () ) else Some start +*) +external _find_in : + int -> int -> string -> (char,int8_unsigned_elt,c_layout) Array1.t -> + int option = "virt_mem_mmap_find_in" (* Generic find function. *) -let _find t start align str = - _find_map t ( - fun { start = mstart; size = msize; arr = arr } -> - if mstart >= start then ( - (* Check this mapping from the beginning. *) - match _find_in 0 align str arr with - | Some offset -> Some (mstart +^ Int64.of_int offset) - | None -> None - ) - else if mstart < start && start <= mstart+^msize then ( - (* Check this mapping from somewhere in the middle. *) - let offset = Int64.to_int (start -^ mstart) in - match _find_in offset align str arr with - | Some offset -> Some (mstart +^ Int64.of_int offset) +let _find { tree = tree } start align str = + let rec loop addr = + let rightend, mapping = get_mapping addr tree in + match mapping with + | Some { start = start; size = size; arr = arr } -> + (* Offset within this mapping. *) + let offset = Int64.to_int (addr -^ start) in + + (match _find_in offset align str arr with | None -> None - ) - else None - ) + | Some offset -> Some (start +^ Int64.of_int offset) + ) + + | None -> + (* Find functions all silently skip holes, so: *) + loop rightend + in + loop start let find t ?(start=0L) str = _find t start 1 str @@ -181,9 +584,9 @@ and string_of_addr t addr = let bits = bits_of_wordsize (get_wordsize t) in let e = get_endian t in let bs = BITSTRING { addr : bits : endian (e) } in - Bitmatch.string_of_bitstring bs + Bitstring.string_of_bitstring bs *) -(* XXX bitmatch is missing 'construct_int64_le_unsigned' so we +(* XXX bitstring is missing 'construct_int64_le_unsigned' so we * have to force this to 32 bits for the moment. *) and string_of_addr t addr = @@ -191,124 +594,14 @@ and string_of_addr t addr = assert (bits = 32); let e = get_endian t in let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in - Bitmatch.string_of_bitstring bs - -and addr_of_string t str = - let bits = bits_of_wordsize (get_wordsize t) in - let e = get_endian t in - let bs = Bitmatch.bitstring_of_string str in - bitmatch bs with - | { addr : bits : endian (e) } -> addr - | { _ } -> invalid_arg "addr_of_string" - -let get_byte { mappings = mappings } addr = - let rec loop = function - | [] -> invalid_arg "get_byte" - | { start = start; size = size; arr = arr } :: _ - when start <= addr && addr < size -> - let offset = Int64.to_int (addr -^ start) in - Char.code (Array1.get arr offset) - | _ :: ms -> loop ms - in - loop mappings - -(* Take bytes until a condition is not met. This is efficient in that - * we stay within the same mapping as long as we can. - *) -let dowhile { mappings = mappings } addr cond = - let rec get_next_mapping addr = function - | [] -> invalid_arg "dowhile" - | { start = start; size = size; arr = arr } :: _ - when start <= addr && addr < start +^ size -> - let offset = Int64.to_int (addr -^ start) in - let len = Int64.to_int size - offset in - arr, offset, len - | _ :: ms -> get_next_mapping addr ms - in - let rec loop addr = - let arr, offset, len = get_next_mapping addr mappings in - let rec loop2 i = - if i < len then ( - let c = Array1.get arr (offset+i) in - if cond c then loop2 (i+1) - ) else - loop (addr +^ Int64.of_int len) - in - loop2 0 - in - loop addr - -let get_bytes t addr len = - let str = String.create len in - let i = ref 0 in - try - dowhile t addr ( - fun c -> - str.[!i] <- c; - incr i; - !i < len - ); - str - with - Invalid_argument _ -> invalid_arg "get_bytes" - -let get_string t addr = - let chars = ref [] in - try - dowhile t addr ( - fun c -> - if c <> '\000' then ( - chars := c :: !chars; - true - ) else false - ); - let chars = List.rev !chars in - let len = List.length chars in - let str = String.create len in - let i = ref 0 in - List.iter (fun c -> str.[!i] <- c; incr i) chars; - str - with - Invalid_argument _ -> invalid_arg "get_string" - -let is_string t addr = - try dowhile t addr (fun c -> c <> '\000'); true - with Invalid_argument _ -> false - -let is_C_identifier t addr = - let i = ref 0 in - let r = ref true in - try - dowhile t addr ( - fun c -> - let b = - if !i = 0 then ( - c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' - ) else ( - if c = '\000' then false - else ( - if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || - c >= '0' && c <= '9' then - true - else ( - r := false; - false - ) - ) - ) in - incr i; - b - ); - !r - with - Invalid_argument _ -> false + Bitstring.string_of_bitstring bs let follow_pointer t addr = let ws = get_wordsize t in let e = get_endian t in let bits = bits_of_wordsize ws in let str = get_bytes t addr (bytes_of_wordsize ws) in - let bs = Bitmatch.bitstring_of_string str in + let bs = Bitstring.bitstring_of_string str in bitmatch bs with | { addr : bits : endian (e) } -> addr | { _ } -> invalid_arg "follow_pointer" @@ -320,3 +613,8 @@ let succ_long t addr = let pred_long t addr = let ws = get_wordsize t in addr -^ Int64.of_int (bytes_of_wordsize ws) + +let align t addr = + let ws = get_wordsize t in + let mask = Int64.of_int (bytes_of_wordsize ws - 1) in + (addr +^ mask) &^ (Int64.lognot mask)