bf048f793565131f397843d6478096570552ca63
[virt-mem.git] / lib / virt_mem_mmap.ml
1 (* Memory info command for virtual domains.
2    (C) Copyright 2008 Richard W.M. Jones, Red Hat Inc.
3    http://libvirt.org/
4
5    This program is free software; you can redistribute it and/or modify
6    it under the terms of the GNU General Public License as published by
7    the Free Software Foundation; either version 2 of the License, or
8    (at your option) any later version.
9
10    This program is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU General Public License for more details.
14
15    You should have received a copy of the GNU General Public License
16    along with this program; if not, write to the Free Software
17    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
19    Functions for making a memory map of a virtual machine from
20    various sources.  The memory map will most certainly have holes.
21  *)
22
23 open Unix
24 open Printf
25 open Bigarray
26
27 open Virt_mem_utils
28
29 let debug = false
30
31 (* An address. *)
32 type addr = int64
33
34 (* A range of addresses (start and start+size). *)
35 type interval = addr * addr
36
37 (* A mapping. *)
38 type mapping = {
39   start : addr;
40   size : addr;
41   (* Bigarray mmap(2)'d region with byte addressing: *)
42   arr : (char,int8_unsigned_elt,c_layout) Array1.t;
43   (* The order that the mappings were added, 0 for the first mapping,
44    * 1 for the second mapping, etc.
45    *)
46   order : int;
47 }
48
49 (* A memory map. *)
50 type ('ws,'e,'hm) t = {
51   (* List of mappings, kept in reverse order they were added (new
52    * mappings are added at the head of this list).
53    *)
54   mappings : mapping list;
55
56   (* Segment tree for fast access to a mapping at a particular address.
57    * This is rebuilt each time a new mapping is added.
58    * NB! If mappings = [], ignore contents of this field.  (This is
59    * enforced by the 'hm phantom type).
60    *)
61   tree : (interval * mapping option, interval * mapping option) binary_tree;
62
63   (* Word size, endianness.
64    * Phantom types enforce that these are set before being used.
65    *)
66   wordsize : wordsize;
67   endian : Bitstring.endian;
68 }
69
70 let create () = {
71   mappings = [];
72   tree = Leaf ((0L,0L),None);
73   wordsize = W32;
74   endian = Bitstring.LittleEndian;
75 }
76
77 let set_wordsize t ws = { t with wordsize = ws }
78
79 let set_endian t e = { t with endian = e }
80
81 let get_wordsize t = t.wordsize
82
83 let get_endian t = t.endian
84
85 (* Build the segment tree from the list of mappings.  This code
86  * is taken from virt-df.  For an explanation of the process see:
87  * http://en.wikipedia.org/wiki/Segment_tree
88  *
89  * See also the 'get_mapping' function below which uses this tree
90  * to do fast lookups.
91  *)
92 let tree_of_mappings mappings =
93   (* Construct the list of distinct endpoints. *)
94   let eps =
95     List.map
96       (fun { start = start; size = size } -> [start; start +^ size])
97       mappings in
98   let eps = sort_uniq (List.concat eps) in
99
100   (* Construct the elementary intervals. *)
101   let elints =
102     let elints, lastpoint =
103       List.fold_left (
104         fun (elints, prevpoint) point ->
105           ((point, point) :: (prevpoint, point) :: elints), point
106       ) ([], 0L) eps in
107     let elints = (lastpoint, Int64.max_int(*XXX*)) :: elints in
108     List.rev elints in
109
110   if debug then (
111     eprintf "elementary intervals (%d in total):\n" (List.length elints);
112     List.iter (
113       fun (startpoint, endpoint) ->
114         eprintf "  %Lx %Lx\n" startpoint endpoint
115     ) elints
116   );
117
118   (* Construct the binary tree of elementary intervals. *)
119   let tree =
120     (* Each elementary interval becomes a leaf. *)
121     let elints = List.map (fun elint -> Leaf elint) elints in
122     (* Recursively build this into a binary tree. *)
123     let rec make_layer = function
124       | [] -> []
125       | ([_] as x) -> x
126       (* Turn pairs of leaves at the bottom level into nodes. *)
127       | (Leaf _ as a) :: (Leaf _ as b) :: xs ->
128           let xs = make_layer xs in
129           Node (a, (), b) :: xs
130       (* Turn pairs of nodes at higher levels into nodes. *)
131       | (Node _ as left) :: ((Node _|Leaf _) as right) :: xs ->
132           let xs = make_layer xs in
133           Node (left, (), right) :: xs
134       | Leaf _ :: _ -> assert false (* never happens??? (I think) *)
135     in
136     let rec loop = function
137       | [] -> assert false
138       | [x] -> x
139       | xs -> loop (make_layer xs)
140     in
141     loop elints in
142
143   if debug then (
144     let leaf_printer (startpoint, endpoint) =
145       sprintf "%Lx-%Lx" startpoint endpoint
146     in
147     let node_printer () = "" in
148     print_binary_tree leaf_printer node_printer tree
149   );
150
151   (* Insert the mappings into the tree one by one. *)
152   let tree =
153     (* For each node/leaf in the tree, add its interval and an
154      * empty list which will be used to store the mappings.
155      *)
156     let rec interval_tree = function
157       | Leaf elint -> Leaf (elint, None)
158       | Node (left, (), right) ->
159           let left = interval_tree left in
160           let right = interval_tree right in
161           let (leftstart, _) = interval_of_node left in
162           let (_, rightend) = interval_of_node right in
163           let interval = leftstart, rightend in
164           Node (left, (interval, None), right)
165     and interval_of_node = function
166       | Leaf (elint, _) -> elint
167       | Node (_, (interval, _), _) -> interval
168     in
169
170     let tree = interval_tree tree in
171     (* This should always be true: *)
172     assert (interval_of_node tree = (0L, Int64.max_int(*XXX*)));
173
174     (* "Contained in" operator.
175      * 'a <-< b' iff 'a' is a subinterval of 'b'.
176      *      |<---- a ---->|
177      * |<----------- b ----------->|
178      *)
179     let (<-<) (a1, a2) (b1, b2) = b1 <= a1 && a2 <= b2 in
180
181     (* "Intersects" operator.
182      * 'a /\ b' iff intervals 'a' and 'b' overlap, eg:
183      *      |<---- a ---->|
184      *                |<----------- b ----------->|
185      *)
186     let ( /\ ) (a1, a2) (b1, b2) = a2 > b1 || b2 > a1 in
187
188     let rec insert_mapping tree mapping =
189       let { start = start; size = size } = mapping in
190       let seginterval = start, start +^ size in
191
192       match tree with
193       (* Test if we should insert into this leaf or node: *)
194       | Leaf (interval, None) when interval <-< seginterval ->
195           Leaf (interval, Some mapping)
196       | Leaf (interval, Some oldmapping) when interval <-< seginterval ->
197           let mapping =
198             if oldmapping.order > mapping.order then oldmapping else mapping in
199           Leaf (interval, Some mapping)
200
201       | Node (left, (interval, None), right) when interval <-< seginterval ->
202           Node (left, (interval, Some mapping), right)
203
204       | Node (left, (interval, Some oldmapping), right)
205           when interval <-< seginterval ->
206           let mapping =
207             if oldmapping.order > mapping.order then oldmapping else mapping in
208           Node (left, (interval, Some mapping), right)
209
210       | (Leaf _) as leaf -> leaf
211
212       (* Else, should we insert into left or right subtrees? *)
213       | Node (left, i, right) ->
214           let left =
215             if seginterval /\ interval_of_node left then
216               insert_mapping left mapping
217             else
218               left in
219           let right =
220             if seginterval /\ interval_of_node right then
221               insert_mapping right mapping
222             else
223               right in
224           Node (left, i, right)
225     in
226     let tree = List.fold_left insert_mapping tree mappings in
227     tree in
228
229   if debug then (
230     let printer ((sp, ep), mapping) =
231       sprintf "[%Lx-%Lx] " sp ep ^
232         match mapping with
233         | None -> "(none)"
234         | Some { start = start; size = size; order = order } ->
235             sprintf "%Lx..%Lx(%d)" start (start+^size-^1L) order
236     in
237     print_binary_tree printer printer tree
238   );
239
240   tree
241
242 let add_mapping ({ mappings = mappings } as t) start size arr =
243   let order = List.length mappings in
244   let mapping = { start = start; size = size; arr = arr; order = order } in
245   let mappings = mapping :: mappings in
246   let tree = tree_of_mappings mappings in
247   { t with mappings = mappings; tree = tree }
248
249 let add_file t fd addr =
250   let size = (fstat fd).st_size in
251   (* mmap(2) the file using Bigarray module. *)
252   let arr = Array1.map_file fd char c_layout false size in
253   (* Create the mapping entry. *)
254   add_mapping t addr (Int64.of_int size) arr
255
256 let add_string ({ mappings = mappings } as t) str addr =
257   let size = String.length str in
258   (* Copy the string data to a Bigarray. *)
259   let arr = Array1.create char c_layout size in
260   for i = 0 to size-1 do
261     Array1.set arr i (String.unsafe_get str i)
262   done;
263   (* Create the mapping entry. *)
264   add_mapping t addr (Int64.of_int size) arr
265
266 let of_file fd addr =
267   let t = create () in
268   add_file t fd addr
269
270 let of_string str addr =
271   let t = create () in
272   add_string t str addr
273
274 (* 'get_mapping' is the crucial, fast lookup function for address -> mapping.
275  * It searches the tree (hence fast) to work out the topmost mapping which
276  * applies to an address.
277  *
278  * Returns (rightend * mapping option)
279  * where 'mapping option' is the mapping (or None if it's a hole)
280  *   and 'rightend' is the next address at which there is a different
281  *       mapping/hole.  In other words, this mapping result is good for
282  *       addresses [addr .. rightend-1].
283  *)
284 let rec get_mapping addr = function
285   | Leaf ((_, rightend), mapping) -> rightend, mapping
286
287   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
288           (_, None),
289           right) ->
290       let subrightend, submapping =
291         if addr < leftend then get_mapping addr left
292         else get_mapping addr right in
293       subrightend, submapping
294
295   | Node ((Leaf ((_, leftend), _) | Node (_, ((_, leftend), _), _) as left),
296           (_, Some mapping),
297           right) ->
298       let subrightend, submapping =
299         if addr < leftend then get_mapping addr left
300         else get_mapping addr right in
301       (match submapping with
302        | None -> subrightend, Some mapping
303        | Some submapping ->
304            subrightend,
305            Some (if mapping.order > submapping.order then mapping
306                  else submapping)
307       )
308
309 (* Use the tree to quickly check if an address is mapped (returns false
310  * if it's a hole).
311  *)
312 let is_mapped { tree = tree } addr =
313   let _, mapping = get_mapping addr tree in
314   mapping <> None
315
316 (* Get a single byte. *)
317 let get_byte { tree = tree } addr =
318   (* Get the mapping which applies to this address: *)
319   let _, mapping = get_mapping addr tree in
320   match mapping with
321   | Some { start = start; size = size; arr = arr } ->
322       let offset = Int64.to_int (addr -^ start) in
323       Char.code (Array1.get arr offset)
324   | None ->
325       invalid_arg "get_byte"
326
327 (* Get a range of bytes, possibly across several intervals. *)
328 let get_bytes { tree = tree } addr len =
329   let str = String.create len in
330
331   let rec loop addr pos len =
332     if len > 0 then (
333       let rightend, mapping = get_mapping addr tree in
334       match mapping with
335       | Some { start = start; size = size; arr = arr } ->
336           (* Offset within this mapping. *)
337           let offset = Int64.to_int (addr -^ start) in
338           (* Number of bytes to read before we either get to the end
339            * of our 'len' or we fall off the end of this interval.
340            *)
341           let n = min len (Int64.to_int (rightend -^ addr)) in
342           for i = 0 to n-1 do
343             String.unsafe_set str (pos + i) (Array1.get arr (offset + i))
344           done;
345           let len = len - n in
346           loop (addr +^ Int64.of_int n) (pos + n) len
347
348       | None ->
349           invalid_arg "get_bytes"
350     )
351   in
352   loop addr 0 len;
353
354   str
355
356 let get_int32 t addr =
357   let e = get_endian t in
358   let str = get_bytes t addr 4 in
359   let bs = Bitstring.bitstring_of_string str in
360   bitmatch bs with
361   | { addr : 32 : endian (e) } -> addr
362   | { _ } -> invalid_arg "get_int32"
363
364 let get_int64 t addr =
365   let e = get_endian t in
366   let str = get_bytes t addr 8 in
367   let bs = Bitstring.bitstring_of_string str in
368   bitmatch bs with
369   | { addr : 64 : endian (e) } -> addr
370   | { _ } -> invalid_arg "get_int64"
371
372 let get_C_int = get_int32
373
374 let get_C_long t addr =
375   let ws = get_wordsize t in
376   match ws with
377   | W32 -> Int64.of_int32 (get_int32 t addr)
378   | W64 -> get_int64 t addr
379
380 (* Take bytes until a condition is not met.  This is efficient
381  * in that we stay within the same mapping as long as we can.
382  *
383  * If we hit a hole, raises Invalid_argument "dowhile".
384  *)
385 let dowhile { tree = tree } addr cond =
386   let rec loop addr =
387     let rightend, mapping = get_mapping addr tree in
388     match mapping with
389     | Some { start = start; size = size; arr = arr } ->
390         (* Offset within this mapping. *)
391         let offset = Int64.to_int (addr -^ start) in
392         (* Number of bytes before we fall off the end of this interval. *)
393         let n = Int64.to_int (rightend -^ addr) in
394
395         let rec loop2 addr offset n =
396           if n > 0 then (
397             let c = Array1.get arr offset in
398             if cond addr c then
399               loop2 (addr +^ 1L) (offset + 1) (n - 1)
400             else
401               false (* stop now, finish outer loop too *)
402           )
403           else true (* fell off the end, so continue outer loop *)
404         in
405         if loop2 addr offset n then
406           loop (addr +^ Int64.of_int n)
407
408     | None ->
409         invalid_arg "dowhile"
410   in
411   loop addr
412
413 (* Get a string, ending at ASCII NUL character. *)
414 let get_string t addr =
415   let chars = ref [] in
416   try
417     dowhile t addr (
418       fun _ c ->
419         if c <> '\000' then (
420           chars := c :: !chars;
421           true
422         ) else false
423     );
424     let chars = List.rev !chars in
425     let len = List.length chars in
426     let str = String.create len in
427     let i = ref 0 in
428     List.iter (fun c -> String.unsafe_set str !i c; incr i) chars;
429     str
430   with
431     Invalid_argument _ -> invalid_arg "get_string"
432
433 let is_string t addr =
434   try dowhile t addr (fun _ c -> c <> '\000'); true
435   with Invalid_argument _ -> false
436
437 let is_C_identifier t addr =
438   let i = ref 0 in
439   let r = ref true in
440   try
441     dowhile t addr (
442       fun _ c ->
443         let b =
444           if !i = 0 then (
445             c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z'
446           ) else (
447             if c = '\000' then false
448             else (
449               if c = '_' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' ||
450                 c >= '0' && c <= '9' then
451                   true
452               else (
453                 r := false;
454                 false
455               )
456             )
457           ) in
458         incr i;
459         b
460     );
461     !r
462   with
463     Invalid_argument _ -> false
464
465 (* The following functions are actually written in C
466  * because memmem(3) is likely to be much faster than anything
467  * we could write in OCaml.
468  *
469  * Also OCaml bigarrays are specifically designed to be accessed
470  * easily from C:
471  *   http://caml.inria.fr/pub/docs/manual-ocaml/manual043.html
472  *)
473 (*
474 (* Array+offset = string? *)
475 let string_at arr offset str strlen =
476   let j = ref offset in
477   let rec loop i =
478     if i >= strlen then true
479     else
480       if Array1.get arr !j <> str.[i] then false
481       else (
482         incr j;
483         loop (i+1)
484       )
485   in
486   loop 0
487
488 (* Find in a single file mapping.
489  * [start] is relative to the mapping and we return an offset relative
490  * to the mapping.
491  *)
492 let _find_in start align str arr =
493   let strlen = String.length str in
494   if strlen > 0 then (
495     let j = ref start in
496     let e = Array1.dim arr - strlen in
497     let rec loop () =
498       if !j <= e then (
499         if string_at arr !j str strlen then Some !j
500         else (
501           j := !j + align;
502           loop ()
503         )
504       )
505       else None
506     in
507     loop ()
508   )
509   else Some start
510 *)
511 external _find_in :
512   int -> int -> string -> (char,int8_unsigned_elt,c_layout) Array1.t ->
513   int option = "virt_mem_mmap_find_in"
514
515 (* Generic find function. *)
516 let _find { tree = tree } start align str =
517   let rec loop addr =
518     let rightend, mapping = get_mapping addr tree in
519     match mapping with
520     | Some { start = start; size = size; arr = arr } ->
521         (* Offset within this mapping. *)
522         let offset = Int64.to_int (addr -^ start) in
523
524         (match _find_in offset align str arr with
525         | None -> None
526         | Some offset -> Some (start +^ Int64.of_int offset)
527         )
528
529     | None ->
530         (* Find functions all silently skip holes, so: *)
531         loop rightend
532   in
533   loop start
534
535 let find t ?(start=0L) str =
536   _find t start 1 str
537
538 let find_align t ?(start=0L) str =
539   let align = bytes_of_wordsize (get_wordsize t) in
540   _find t start align str
541
542 let rec _find_all t start align str =
543   match _find t start align str with
544   | None -> []
545   | Some offset ->
546       offset :: _find_all t (offset +^ Int64.of_int align) align str
547
548 let find_all t ?(start=0L) str =
549   _find_all t start 1 str
550
551 let find_all_align t ?(start=0L) str =
552   let align = bytes_of_wordsize (get_wordsize t) in
553   _find_all t start align str
554
555 (* NB: Phantom types in the interface ensure that these pointer functions
556  * can only be called once endianness and wordsize have both been set.
557  *)
558
559 let rec find_pointer t ?start addr =
560   find_align t ?start (string_of_addr t addr)
561
562 and find_pointer_all t ?start addr =
563   find_all_align t ?start (string_of_addr t addr)
564
565 (*
566 and string_of_addr t addr =
567   let bits = bits_of_wordsize (get_wordsize t) in
568   let e = get_endian t in
569   let bs = BITSTRING { addr : bits : endian (e) } in
570   Bitstring.string_of_bitstring bs
571 *)
572 (* XXX bitstring is missing 'construct_int64_le_unsigned' so we
573  * have to force this to 32 bits for the moment.
574  *)
575 and string_of_addr t addr =
576   let bits = bits_of_wordsize (get_wordsize t) in
577   assert (bits = 32);
578   let e = get_endian t in
579   let bs = BITSTRING { Int64.to_int32 addr : 32 : endian (e) } in
580   Bitstring.string_of_bitstring bs
581
582 let follow_pointer t addr =
583   let ws = get_wordsize t in
584   let e = get_endian t in
585   let bits = bits_of_wordsize ws in
586   let str = get_bytes t addr (bytes_of_wordsize ws) in
587   let bs = Bitstring.bitstring_of_string str in
588   bitmatch bs with
589   | { addr : bits : endian (e) } -> addr
590   | { _ } -> invalid_arg "follow_pointer"
591
592 let succ_long t addr =
593   let ws = get_wordsize t in
594   addr +^ Int64.of_int (bytes_of_wordsize ws)
595
596 let pred_long t addr =
597   let ws = get_wordsize t in
598   addr -^ Int64.of_int (bytes_of_wordsize ws)
599
600 let align t addr =
601   let ws = get_wordsize t in
602   let mask = Int64.of_int (bytes_of_wordsize ws - 1) in
603   (addr +^ mask) &^ (Int64.lognot mask)