+(* Build the segment tree from the list of mappings. This code
+ * is taken from virt-df. For an explanation of the process see:
+ * http://en.wikipedia.org/wiki/Segment_tree
+ *)
+let tree_of_mappings mappings =
+ (* Construct the list of distinct endpoints. *)
+ let eps =
+ List.map
+ (fun { start = start; size = size } -> [start; start +^ size])
+ mappings in
+ let eps = sort_uniq (List.concat eps) in
+
+ (* Construct the elementary intervals. *)
+ let elints =
+ let elints, lastpoint =
+ List.fold_left (
+ fun (elints, prevpoint) point ->
+ ((point, point) :: (prevpoint, point) :: elints), point
+ ) ([], 0L) eps in
+ let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
+ List.rev elints in
+
+ if debug then (
+ eprintf "elementary intervals (%d in total):\n" (List.length elints);
+ List.iter (
+ fun (startpoint, endpoint) ->
+ eprintf " %Lx %Lx\n" startpoint endpoint
+ ) elints
+ );
+
+ (* Construct the binary tree of elementary intervals. *)
+ let tree =
+ (* Each elementary interval becomes a leaf. *)
+ let elints = List.map (fun elint -> Leaf elint) elints in
+ (* Recursively build this into a binary tree. *)
+ let rec make_layer = function
+ | [] -> []
+ | ([_] as x) -> x
+ (* Turn pairs of leaves at the bottom level into nodes. *)
+ | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
+ let xs = make_layer xs in
+ Node (a, (), b) :: xs
+ (* Turn pairs of nodes at higher levels into nodes. *)
+ | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
+ let xs = make_layer xs in
+ Node (left, (), right) :: xs
+ | Leaf _ :: _ -> assert false (* never happens??? (I think) *)
+ in
+ let rec loop = function
+ | [] -> assert false
+ | [x] -> x
+ | xs -> loop (make_layer xs)
+ in
+ loop elints in
+
+ if debug then (
+ let leaf_printer (startpoint, endpoint) =
+ sprintf "%Lx-%Lx" startpoint endpoint
+ in
+ let node_printer () = "" in
+ print_binary_tree leaf_printer node_printer tree
+ );
+
+ (* Insert the mappings into the tree one by one. *)
+ let tree =
+ (* For each node/leaf in the tree, add its interval and an
+ * empty list which will be used to store the mappings.
+ *)
+ let rec interval_tree = function
+ | Leaf elint -> Leaf (elint, [])
+ | Node (left, (), right) ->
+ let left = interval_tree left in
+ let right = interval_tree right in
+ let (leftstart, _) = interval_of_node left in
+ let (_, rightend) = interval_of_node right in
+ let interval = leftstart, rightend in
+ Node (left, (interval, []), right)
+ and interval_of_node = function
+ | Leaf (elint, _) -> elint
+ | Node (_, (interval, _), _) -> interval
+ in
+
+ let tree = interval_tree tree in
+ (* This should always be true: *)
+ assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
+
+ (* "Contained in" operator.
+ * 'a <-< b' iff 'a' is a subinterval of 'b'.
+ * |<---- a ---->|
+ * |<----------- b ----------->|
+ *)
+ let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
+
+ (* "Intersects" operator.
+ * 'a /\ b' iff intervals 'a' and 'b' overlap, eg:
+ * |<---- a ---->|
+ * |<----------- b ----------->|
+ *)
+ let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in
+
+ let rec insert_mapping tree mapping =
+ let { start = start; size = size } = mapping in
+ let seginterval = start, start +^ size in
+
+ match tree with
+ (* Test if we should insert into this leaf or node: *)
+ | Leaf (interval, mappings) when interval <-< seginterval ->
+ Leaf (interval, mapping :: mappings)
+ | Node (left, (interval, mappings), right)
+ when interval <-< seginterval ->
+ Node (left, (interval, mapping :: mappings), right)
+
+ | (Leaf _) as leaf -> leaf
+
+ (* Else, should we insert into left or right subtrees? *)
+ | Node (left, i, right) ->
+ let left =
+ if seginterval /\ interval_of_node left then
+ insert_mapping left mapping
+ else
+ left in
+ let right =
+ if seginterval /\ interval_of_node right then
+ insert_mapping right mapping
+ else
+ right in
+ Node (left, i, right)
+ in
+ let tree = List.fold_left insert_mapping tree mappings in
+ tree in
+
+ if debug then (
+ let printer ((sp, ep), mappings) =
+ sprintf "[%Lx-%Lx] " sp ep ^
+ String.concat ";"
+ (List.map (fun { start = start; size = size } ->
+ sprintf "%Lx+%Lx" start size)
+ mappings)
+ in
+ print_binary_tree printer printer tree
+ );