Revised Virt_mem_mmap handling overlapping mappings efficiently.
[virt-mem.git] / lib / virt_mem_mmap.ml
1 (* Memory info command for virtual domains.
2    (C) Copyright 2008 Richard W.M. Jones, Red Hat Inc.
3    http://libvirt.org/
4
5    This program is free software; you can redistribute it and/or modify
6    it under the terms of the GNU General Public License as published by
7    the Free Software Foundation; either version 2 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU General Public License for more details.
14
15    You should have received a copy of the GNU General Public License
16    along with this program; if not, write to the Free Software
17    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
19    Functions for making a memory map of a virtual machine from
20    various sources.  The memory map will most certainly have holes.
21  *)
22
23 open Unix
24 open Printf
25 open Bigarray
26
27 open Virt_mem_utils
28
29 let debug = true
30
31 (* An address. *)
32 type addr = int64
33
34 (* A range of addresses (start and start+size). *)
35 type interval = addr * addr
36
37 (* A mapping. *)
38 type mapping = {
39   start : addr;
40   size : addr;
41   (* Bigarray mmap(2)'d region with byte addressing: *)
42   arr : (char,int8_unsigned_elt,c_layout) Array1.t;
43   (* The order that the mappings were added, 0 for the first mapping,
44    * 1 for the second mapping, etc.
45    *)
46   order : int;
47 }
48
49 (* A memory map. *)
50 type ('ws,'e,'hm) t = {
51   (* List of mappings, kept in reverse order they were added (new
52    * mappings are added at the head of this list).
53    *)
54   mappings : mapping list;
55
56   (* Segment tree for fast access to a mapping at a particular address.
57    * This is rebuilt each time a new mapping is added.
58    * NB! If mappings = [], ignore contents of this field.  (This is
59    * enforced by the 'hm phantom type).
60    *)
61   tree : (interval * mapping list, interval * mapping list) binary_tree;
62
63   (* Word size, endianness.
64    * Phantom types enforce that these are set before being used.
65    *)
66   wordsize : wordsize;
67   endian : Bitstring.endian;
68 }
69
70 let create () = {
71   mappings = [];
72   tree = Leaf ((0L,0L),[]);
73   wordsize = W32;
74   endian = Bitstring.LittleEndian;
75 }
76
77 let set_wordsize t ws = { t with wordsize = ws }
78
79 let set_endian t e = { t with endian = e }
80
81 let get_wordsize t = t.wordsize
82
83 let get_endian t = t.endian
84
85 (* Build the segment tree from the list of mappings.  This code
86  * is taken from virt-df.  For an explanation of the process see:
87  * http://en.wikipedia.org/wiki/Segment_tree
88  *)
89 let tree_of_mappings mappings =
90   (* Construct the list of distinct endpoints. *)
91   let eps =
92     List.map
93       (fun { start = start; size = size } -> [start; start +^ size])
94       mappings in
95   let eps = sort_uniq (List.concat eps) in
96
97   (* Construct the elementary intervals. *)
98   let elints =
99     let elints, lastpoint =
100       List.fold_left (
101         fun (elints, prevpoint) point ->
102           ((point, point) :: (prevpoint, point) :: elints), point
103       ) ([], 0L) eps in
104     let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
105     List.rev elints in
106
107   if debug then (
108     eprintf "elementary intervals (%d in total):\n" (List.length elints);
109     List.iter (
110       fun (startpoint, endpoint) ->
111         eprintf "  %Lx %Lx\n" startpoint endpoint
112     ) elints
113   );
114
115   (* Construct the binary tree of elementary intervals. *)
116   let tree =
117     (* Each elementary interval becomes a leaf. *)
118     let elints = List.map (fun elint -> Leaf elint) elints in
119     (* Recursively build this into a binary tree. *)
120     let rec make_layer = function
121       | [] -> []
122       | ([_] as x) -> x
123       (* Turn pairs of leaves at the bottom level into nodes. *)
124       | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
125           let xs = make_layer xs in
126           Node (a, (), b) :: xs
127       (* Turn pairs of nodes at higher levels into nodes. *)
128       | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
129           let xs = make_layer xs in
130           Node (left, (), right) :: xs
131       | Leaf _ :: _ -> assert false (* never happens??? (I think) *)
132     in
133     let rec loop = function
134       | [] -> assert false
135       | [x] -> x
136       | xs -> loop (make_layer xs)
137     in
138     loop elints in
139
140   if debug then (
141     let leaf_printer (startpoint, endpoint) =
142       sprintf "%Lx-%Lx" startpoint endpoint
143     in
144     let node_printer () = "" in
145     print_binary_tree leaf_printer node_printer tree
146   );
147
148   (* Insert the mappings into the tree one by one. *)
149   let tree =
150     (* For each node/leaf in the tree, add its interval and an
151      * empty list which will be used to store the mappings.
152      *)
153     let rec interval_tree = function
154       | Leaf elint -> Leaf (elint, [])
155       | Node (left, (), right) ->
156           let left = interval_tree left in
157           let right = interval_tree right in
158           let (leftstart, _) = interval_of_node left in
159           let (_, rightend) = interval_of_node right in
160           let interval = leftstart, rightend in
161           Node (left, (interval, []), right)
162     and interval_of_node = function
163       | Leaf (elint, _) -> elint
164       | Node (_, (interval, _), _) -> interval
165     in
166
167     let tree = interval_tree tree in
168     (* This should always be true: *)
169     assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
170
171     (* "Contained in" operator.
172      * 'a <-< b' iff 'a' is a subinterval of 'b'.
173      *      |<---- a ---->|
174      * |<----------- b ----------->|
175      *)
176     let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
177
178     (* "Intersects" operator.
179      * 'a /\ b' iff intervals 'a' and 'b' overlap, eg:
180      *      |<---- a ---->|
181      *                |<----------- b ----------->|
182      *)
183     let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in
184
185     let rec insert_mapping tree mapping =
186       let { start = start; size = size } = mapping in
187       let seginterval = start, start +^ size in
188
189       match tree with
190       (* Test if we should insert into this leaf or node: *)
191       | Leaf (interval, mappings) when interval <-< seginterval ->
192           Leaf (interval, mapping :: mappings)
193       | Node (left, (interval, mappings), right)
194           when interval <-< seginterval ->
195           Node (left, (interval, mapping :: mappings), right)
196
197       | (Leaf _) as leaf -> leaf
198
199       (* Else, should we insert into left or right subtrees? *)
200       | Node (left, i, right) ->
201           let left =
202             if seginterval /\ interval_of_node left then
203               insert_mapping left mapping
204             else
205               left in
206           let right =
207             if seginterval /\ interval_of_node right then
208               insert_mapping right mapping
209             else
210               right in
211           Node (left, i, right)
212     in
213     let tree = List.fold_left insert_mapping tree mappings in
214     tree in
215
216   if debug then (
217     let printer ((sp, ep), mappings) =
218       sprintf "[%Lx-%Lx] " sp ep ^
219         String.concat ";"
220         (List.map (fun { start = start; size = size } ->
221                      sprintf "%Lx+%Lx" start size)
222            mappings)
223     in
224     print_binary_tree printer printer tree
225   );
226
227   tree
228
229 let add_mapping ({ mappings = mappings } as t) start size arr =
230   let order = List.length mappings in
231   let mapping = { start = start; size = size; arr = arr; order = order } in
232   let mappings = mapping :: mappings in
233   let tree = tree_of_mappings mappings in
234   { t with mappings = mappings; tree = tree }
235
236 let add_file t fd addr =
237   let size = (fstat fd).st_size in
238   (* mmap(2) the file using Bigarray module. *)
239   let arr = Array1.map_file fd char c_layout false size in
240   (* Create the mapping entry. *)
241   add_mapping t addr (Int64.of_int size) arr
242
243 let add_string ({ mappings = mappings } as t) str addr =
244   let size = String.length str in
245   (* Copy the string data to a Bigarray. *)
246   let arr = Array1.create char c_layout size in
247   for i = 0 to size-1 do
248     Array1.set arr i (String.unsafe_get str i)
249   done;
250   (* Create the mapping entry. *)
251   add_mapping t addr (Int64.of_int size) arr
252
253 let of_file fd addr =
254   let t = create () in
255   add_file t fd addr
256
257 let of_string str addr =
258   let t = create () in
259   add_string t str addr
260
261 (* Look up an address and get the top-most mapping which contains it.
262  * This uses the segment tree, so it's fast.  The top-most mapping is
263  * the one with the highest 'order' field.
264  *
265  * Warning: This 'hot' code was carefully optimized based on
266  * feedback from 'gprof'.  Avoid fiddling with it.
267  *)
268 let rec get_mapping addr = function
269   | Leaf (_, []) -> None
270   | Leaf (_, [mapping]) -> Some mapping
271   | Leaf (_, mappings) -> Some (find_highest_order mappings)
272
273   (* Try to avoid expensive search if node mappings is empty: *)
274   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
275           (_, []),
276           right) ->
277       let submapping =
278         if addr < leftend then get_mapping addr left
279         else get_mapping addr right in
280       submapping
281
282   (* ... or a singleton: *)
283   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
284           (_, [mapping]),
285           right) ->
286       let submapping =
287         if addr < leftend then get_mapping addr left
288         else get_mapping addr right in
289       (match submapping with
290        | None -> Some mapping
291        | Some submapping ->
292            Some (if mapping.order > submapping.order then mapping
293                  else submapping)
294       )
295
296   (* Normal recursive case: *)
297   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
298           (_, mappings),
299           right) ->
300       let submapping =
301         if addr < leftend then get_mapping addr left
302         else get_mapping addr right in
303       (match submapping with
304        | None -> Some (find_highest_order mappings)
305        | Some submapping -> Some (find_highest_order (submapping :: mappings))
306       )
307
308 and find_highest_order mappings =
309   List.fold_left (
310     fun mapping1 mapping2 ->
311       if mapping1.order > mapping2.order then mapping1 else mapping2
312   ) (List.hd mappings) (List.tl mappings)
313
314 (* Get a single byte. *)
315 let get_byte { tree = tree } addr =
316   (* Get the mapping which applies to this address: *)
317   match get_mapping addr tree with
318   | Some { start = start; size = size; arr = arr } ->
319       let offset = Int64.to_int (addr -^ start) in
320       Char.code (Array1.get arr offset)
321   | None ->
322       invalid_arg "get_byte"
323 (*
324   let rec loop = function
325     | [] -> invalid_arg "get_byte"
326     | { start = start; size = size; arr = arr } :: _
327         when start <= addr && addr < start +^ size ->
328         let offset = Int64.to_int (addr -^ start) in
329         Char.code (Array1.get arr offset)
330     | _ :: ms -> loop ms
331   in
332   loop mappings
333 *)
334
335
336 (*
337
338 (* Find in mappings and return first predicate match. *)
339 let _find_map { mappings = mappings } pred =
340   let rec loop = function
341     | [] -> None
342     | m :: ms ->
343         match pred m with
344         | Some n -> Some n
345         | None -> loop ms
346   in
347   loop mappings
348
349 (* The following functions are actually written in C
350  * because memmem(3) is likely to be much faster than anything
351  * we could write in OCaml.
352  *
353  * Also OCaml bigarrays are specifically designed to be accessed
354  * easily from C:
355  *   http://caml.inria.fr/pub/docs/manual-ocaml/manual043.html
356  *)
357 (*
358 (* Array+offset = string? *)
359 let string_at arr offset str strlen =
360   let j = ref offset in
361   let rec loop i =
362     if i >= strlen then true
363     else
364       if Array1.get arr !j <> str.[i] then false
365       else (
366         incr j;
367         loop (i+1)
368       )
369   in
370   loop 0
371
372 (* Find in a single file mapping.
373  * [start] is relative to the mapping and we return an offset relative
374  * to the mapping.
375  *)
376 let _find_in start align str arr =
377   let strlen = String.length str in
378   if strlen > 0 then (
379     let j = ref start in
380     let e = Array1.dim arr - strlen in
381     let rec loop () =
382       if !j <= e then (
383         if string_at arr !j str strlen then Some !j
384         else (
385           j := !j + align;
386           loop ()
387         )
388       )
389       else None
390     in
391     loop ()
392   )
393   else Some start
394 *)
395 external _find_in :
396   int -> int -> string -> (char,int8_unsigned_elt,c_layout) Array1.t ->
397   int option = "virt_mem_mmap_find_in"
398
399 (* Generic find function. *)
400 let _find t start align str =
401   _find_map t (
402     fun { start = mstart; size = msize; arr = arr } ->
403       if mstart >= start then (
404         (* Check this mapping from the beginning. *)
405         match _find_in 0 align str arr with
406         | Some offset -> Some (mstart +^ Int64.of_int offset)
407         | None -> None
408       )
409       else if mstart < start && start <= mstart+^msize then (
410         (* Check this mapping from somewhere in the middle. *)
411         let offset = Int64.to_int (start -^ mstart) in
412         match _find_in offset align str arr with
413         | Some offset -> Some (mstart +^ Int64.of_int offset)
414         | None -> None
415       )
416       else None
417   )
418
419 let find t ?(start=0L) str =
420   _find t start 1 str
421
422 let find_align t ?(start=0L) str =
423   let align = bytes_of_wordsize (get_wordsize t) in
424   _find t start align str
425
426 let rec _find_all t start align str =
427   match _find t start align str with
428   | None -> []
429   | Some offset ->
430       offset :: _find_all t (offset +^ Int64.of_int align) align str
431
432 let find_all t ?(start=0L) str =
433   _find_all t start 1 str
434
435 let find_all_align t ?(start=0L) str =
436   let align = bytes_of_wordsize (get_wordsize t) in
437   _find_all t start align str
438
439 (* NB: Phantom types in the interface ensure that these pointer functions
440  * can only be called once endianness and wordsize have both been set.
441  *)
442
443 let rec find_pointer t ?start addr =
444   find_align t ?start (string_of_addr t addr)
445
446 and find_pointer_all t ?start addr =
447   find_all_align t ?start (string_of_addr t addr)
448
449 (*
450 and string_of_addr t addr =
451   let bits = bits_of_wordsize (get_wordsize t) in
452   let e = get_endian t in
453   let bs = BITSTRING { addr : bits : endian (e) } in
454   Bitstring.string_of_bitstring bs
455 *)
456 (* XXX bitstring is missing 'construct_int64_le_unsigned' so we
457  * have to force this to 32 bits for the moment.
458  *)
459 and string_of_addr t addr =
460   let bits = bits_of_wordsize (get_wordsize t) in
461   assert (bits = 32);
462   let e = get_endian t in
463   let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
464   Bitstring.string_of_bitstring bs
465
466 and addr_of_string t str =
467   let bits = bits_of_wordsize (get_wordsize t) in
468   let e = get_endian t in
469   let bs = Bitstring.bitstring_of_string str in
470   bitmatch bs with
471   | { addr : bits : endian (e) } -> addr
472   | { _ } -> invalid_arg "addr_of_string"
473
474 (* Take bytes until a condition is not met.  This is efficient in that
475  * we stay within the same mapping as long as we can.
476  *)
477 let dowhile { mappings = mappings } addr cond =
478   let rec get_next_mapping addr = function
479     | [] -> invalid_arg "dowhile"
480     | { start = start; size = size; arr = arr } :: _
481         when start <= addr && addr < start +^ size ->
482         let offset = Int64.to_int (addr -^ start) in
483         let len = Int64.to_int size - offset in
484         arr, offset, len
485     | _ :: ms -> get_next_mapping addr ms
486   in
487   let rec loop addr =
488     let arr, offset, len = get_next_mapping addr mappings in
489     let rec loop2 i =
490       if i < len then (
491         let c = Array1.get arr (offset+i) in
492         if cond c then loop2 (i+1)
493       ) else
494         loop (addr +^ Int64.of_int len)
495     in
496     loop2 0
497   in
498   loop addr
499
500 let get_bytes t addr len =
501   let str = String.create len in
502   let i = ref 0 in
503   try
504     dowhile t addr (
505       fun c ->
506         str.[!i] <- c;
507         incr i;
508         !i < len
509     );
510     str
511   with
512     Invalid_argument _ -> invalid_arg "get_bytes"
513
514 let get_int32 t addr =
515   let e = get_endian t in
516   let str = get_bytes t addr 4 in
517   let bs = Bitstring.bitstring_of_string str in
518   bitmatch bs with
519   | { addr : 32 : endian (e) } -> addr
520   | { _ } -> invalid_arg "follow_pointer"
521
522 let get_int64 t addr =
523   let e = get_endian t in
524   let str = get_bytes t addr 8 in
525   let bs = Bitstring.bitstring_of_string str in
526   bitmatch bs with
527   | { addr : 64 : endian (e) } -> addr
528   | { _ } -> invalid_arg "follow_pointer"
529
530 let get_C_int = get_int32
531
532 let get_C_long t addr =
533   let ws = get_wordsize t in
534   match ws with
535   | W32 -> Int64.of_int32 (get_int32 t addr)
536   | W64 -> get_int64 t addr
537
538 let get_string t addr =
539   let chars = ref [] in
540   try
541     dowhile t addr (
542       fun c ->
543         if c <> '\000' then (
544           chars := c :: !chars;
545           true
546         ) else false
547     );
548     let chars = List.rev !chars in
549     let len = List.length chars in
550     let str = String.create len in
551     let i = ref 0 in
552     List.iter (fun c -> str.[!i] <- c; incr i) chars;
553     str
554   with
555     Invalid_argument _ -> invalid_arg "get_string"
556
557 let is_string t addr =
558   try dowhile t addr (fun c -> c <> '\000'); true
559   with Invalid_argument _ -> false
560
561 let is_C_identifier t addr =
562   let i = ref 0 in
563   let r = ref true in
564   try
565     dowhile t addr (
566       fun c ->
567         let b =
568           if !i = 0 then (
569             c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
570           ) else (
571             if c = '\000' then false
572             else (
573               if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
574                 c >= '0' && c <= '9' then
575                   true
576               else (
577                 r := false;
578                 false
579               )
580             )
581           ) in
582         incr i;
583         b
584     );
585     !r
586   with
587     Invalid_argument _ -> false
588
589 let is_mapped { mappings = mappings } addr =
590   let rec loop = function
591     | [] -> false
592     | { start = start; size = size; arr = arr } :: _
593         when start <= addr && addr < start +^ size -> true
594     | _ :: ms -> loop ms
595   in
596   loop mappings
597
598 let follow_pointer t addr =
599   let ws = get_wordsize t in
600   let e = get_endian t in
601   let bits = bits_of_wordsize ws in
602   let str = get_bytes t addr (bytes_of_wordsize ws) in
603   let bs = Bitstring.bitstring_of_string str in
604   bitmatch bs with
605   | { addr : bits : endian (e) } -> addr
606   | { _ } -> invalid_arg "follow_pointer"
607
608 let succ_long t addr =
609   let ws = get_wordsize t in
610   addr +^ Int64.of_int (bytes_of_wordsize ws)
611
612 let pred_long t addr =
613   let ws = get_wordsize t in
614   addr -^ Int64.of_int (bytes_of_wordsize ws)
615
616 let align t addr =
617   let ws = get_wordsize t in
618   let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
619   (addr +^ mask) &^ (Int64.lognot mask)
620
621 let map { mappings = mappings } f =
622   List.map (fun { start = start; size = size } -> f start size) mappings
623
624 let iter t f =
625   ignore (map t (fun start size -> let () = f start size in ()))
626
627 let nr_mappings { mappings = mappings } = List.length mappings
628 *)