Experimental automated 'follower' code.
[virt-mem.git] / extract / codegen / code_generation.ml
index eaec4c8..7b0262b 100644 (file)
@@ -74,6 +74,12 @@ let concat_sig_items items =
   | x :: xs ->
       List.fold_left (fun xs x -> <:sig_item< $xs$ $x$ >>) x xs
 
+let concat_exprs exprs =
+  match exprs with
+  | [] -> assert false
+  | x :: xs ->
+      List.fold_left (fun xs x -> <:expr< $xs$ ; $x$ >>) x xs
+
 let concat_record_fields fields =
   match fields with
     | [] -> assert false
@@ -96,6 +102,13 @@ let build_tuple_from_exprs exprs =
       Ast.ExTup (_loc,
                 List.fold_left (fun xs x -> Ast.ExCom (_loc, x, xs)) x xs)
 
+let build_tuple_from_patts patts =
+  match patts with
+  | [] | [_] -> assert false
+  | x :: xs ->
+      Ast.PaTup (_loc,
+                List.fold_left (fun xs x -> Ast.PaCom (_loc, x, xs)) x xs)
+
 type code = Ast.str_item * Ast.sig_item
 
 let ocaml_type_of_field_type = function
@@ -112,7 +125,7 @@ let generate_types xs =
        fun { SC.sf_name = sf_name; sf_fields = fields } ->
          if fields <> [] then (
            let fields = List.map (
-             fun { PP.field_name = name; PP.field_type = t } ->
+             fun (name, t) ->
                let t = ocaml_type_of_field_type t in
                <:ctyp< $lid:sf_name^"_"^name$ : $t$ >>
            ) fields in
@@ -130,7 +143,7 @@ let generate_types xs =
        fun { SC.cf_name = cf_name; cf_fields = fields } ->
          if fields <> [] then (
            let fields = List.map (
-             fun { PP.field_name = name; PP.field_type = t } ->
+             fun (name, t) ->
                let t = ocaml_type_of_field_type t in
                <:ctyp< $lid:cf_name^"_"^name$ : $t$ >>
            ) fields in
@@ -145,24 +158,13 @@ let generate_types xs =
       let cflist = concat_str_items cflist in
 
       <:str_item<
-        type ('a, 'b) $lid:struct_name$ = {
-         $lid:struct_name^"_shape"$ : 'a;
-         $lid:struct_name^"_content"$ : 'b;
-       }
+        type ('a, 'b) $lid:struct_name$ = 'a * 'b ;;
        $sflist$
        $cflist$
       >>
   ) xs in
 
-  let sigs =
-    List.map (
-      fun (struct_name, _, _) ->
-       <:sig_item<
-          type ('a, 'b) $lid:struct_name$
-       >>
-    ) xs in
-
-  concat_str_items strs, concat_sig_items sigs
+  concat_str_items strs, <:sig_item< >>
 
 let generate_offsets xs =
   (* Only need to generate the offset_of_* functions for fields
@@ -232,11 +234,6 @@ let generate_offsets xs =
     ) fields in
 
   let strs = concat_str_items strs in
-  let strs =
-    <:str_item<
-      module StringMap = Map.Make (String) ;;
-      $strs$
-    >> in
 
   strs, <:sig_item< >>
 
@@ -255,12 +252,6 @@ let generate_parsers xs =
     ) xs in
 
   let strs = concat_str_items strs in
-  let strs =
-    <:str_item<
-      let match_err = "failed to match kernel structure" ;;
-      let zero = 0 ;;
-      $strs$
-    >> in
 
   (* The shared parser functions.
    * 
@@ -318,14 +309,22 @@ let generate_parsers xs =
 
          let shape_assignments =
            List.map (
-             fun { PP.field_name = field_name;
-                   field_type = field_type;
-                   field_offset = offset } ->
+             fun (field_name, field_type) ->
+
+               (* Go and look up the field offset in the correct kernel. *)
+               let { PP.field_offset = offset } =
+                 List.find (fun { PP.field_name = name } -> field_name = name)
+                   structure.PP.struct_fields in
 
+               (* Generate assignment code, if necessary we can adjust
+                * the list_head.
+                *)
                match field_type with
                | PP.FListHeadPointer None ->
-                   sprintf "%s_%s = Int64.sub %s %dL"
-                     sf.SC.sf_name field_name field_name offset
+                   sprintf "%s_%s = (if %s <> 0L then Int64.sub %s %dL else %s)"
+                     sf.SC.sf_name field_name
+                     field_name
+                     field_name offset field_name
 
                | PP.FListHeadPointer (Some (other_struct_name,
                                             other_field_name)) ->
@@ -334,13 +333,15 @@ let generate_parsers xs =
                     * offset_of_<struct>_<field> to find it.
                     *)
                    sprintf "%s_%s = (
-                      let offset = offset_of_%s_%s kernel_version in
-                      let offset = Int64.of_int offset in
-                      Int64.sub %s offset
+                      if %s <> 0L then (
+                        let offset = offset_of_%s_%s kernel_version in
+                        let offset = Int64.of_int offset in
+                        Int64.sub %s offset
+                      ) else %s
                     )"
-                     sf.SC.sf_name field_name
+                     sf.SC.sf_name field_name field_name
                      other_struct_name other_field_name
-                     field_name
+                     field_name field_name
                | _ ->
                    sprintf "%s_%s = %s" sf.SC.sf_name field_name field_name
            ) sf.SC.sf_fields in
@@ -352,7 +353,7 @@ let generate_parsers xs =
 
          let content_assignments =
            List.map (
-             fun { PP.field_name = field_name } ->
+             fun (field_name, _) ->
                sprintf "%s_%s = %s" cf.SC.cf_name field_name field_name
            ) cf.SC.cf_fields in
 
@@ -365,15 +366,14 @@ let generate_parsers xs =
            sprintf "
   bitmatch bits with
   | { %s } ->
-      let shape =
+      let s =
       %s in
-      let content =
+      let c =
       %s in
-      { %s_shape = shape; %s_content = content }
+      (s, c)
   | { _ } ->
       raise (Virt_mem_types.ParseError (%S, %S, match_err))"
              patterns shape_assignments content_assignments
-             struct_name struct_name
              struct_name pa_name in
 
          Hashtbl.add subs pa_name code
@@ -382,17 +382,268 @@ let generate_parsers xs =
 
   (strs, <:sig_item< >>), subs
 
-let output_interf ~output_file types offsets parsers =
-  let sigs = concat_sig_items [ types; offsets; parsers ] in
+(* Helper functions to store things in a fixed-length tuple very efficiently.
+ * Note that the tuple length must be >= 2.
+ *)
+type tuple = string list
+
+let tuple_create fields : tuple = fields
+
+(* Generates 'let _, _, resultpatt, _ = tupleexpr in body'. *)
+let tuple_generate_extract fields field resultpatt tupleexpr body =
+  let patts = List.map (
+    fun name -> if name = field then resultpatt else <:patt< _ >>
+  ) fields in
+  let result = build_tuple_from_patts patts in
+  <:expr< let $result$ = $tupleexpr$ in $body$ >>
+
+(* Generates '(fieldexpr1, fieldexpr2, ...)'. *)
+let tuple_generate_construct fieldexprs =
+  build_tuple_from_exprs fieldexprs
+
+type follower_t =
+  | Missing of string | Follower of string | KernelVersion of string
+
+let generate_followers xs =
+  (* Tuple of follower functions, just a list of struct_names. *)
+  let follower_tuple = tuple_create (List.map fst xs) in
+
+  (* A shape-follow function for every structure/shape. *)
+  let strs = List.map (
+    fun (struct_name, (_, sflist, _, _)) ->
+      List.map (
+       fun { SC.sf_name = sf_name; sf_fields = fields } ->
+         let body = List.fold_right (
+           fun (name, typ) body ->
+             let follower_name =
+               match typ with
+               | PP.FListHeadPointer None -> struct_name
+               | PP.FListHeadPointer (Some (struct_name, _)) -> struct_name
+               | PP.FStructPointer struct_name -> struct_name
+               | _ -> assert false in
+             tuple_generate_extract follower_tuple follower_name
+               <:patt< f >> <:expr< followers >>
+               <:expr<
+                 let map =
+                   f load followers map shape.$lid:sf_name^"_"^name$ in $body$
+               >>
+         ) fields <:expr< map >> in
+
+         <:str_item<
+           let $lid:sf_name^"_follower"$ load followers map shape =
+             $body$
+         >>
+      ) sflist
+  ) xs in
+  let strs = List.concat strs in
+
+  (* A follower function for every kernel version / structure.  When this
+   * function is called starting at some known root, it will load every
+   * reachable kernel structure.
+   *)
+  let strs =
+    let common =
+      (* Share as much common code as possible to minimize generated
+       * code size and benefit i-cache.
+       *)
+      <:str_item<
+       let kv_follower kernel_version struct_name total_size
+           parserfn followerfn
+           load followers map addr =
+         if addr <> 0L && not (AddrMap.mem addr map) then (
+           let map = AddrMap.add addr (struct_name, total_size) map in
+           let bits = load struct_name addr total_size in
+           let shape, _ = parserfn kernel_version bits in
+           followerfn load followers map shape
+         )
+         else map
+      >> in
+
+    let fs =
+      List.map (
+       fun (struct_name, (kernels, _, sfhash, pahash)) ->
+         List.map (
+           fun ({ PP.kernel_version = version; kv_i = kv_i },
+                { PP.struct_total_size = total_size }) ->
+             let { SC.pa_name = pa_name } = Hashtbl.find pahash version in
+             let { SC.sf_name = sf_name } = Hashtbl.find sfhash version in
+
+             let fname = sprintf "%s_kv%d_follower" struct_name kv_i in
+
+             <:str_item<
+               let $lid:fname$ =
+                 kv_follower
+                   $str:version$ $str:struct_name$ $`int:total_size$
+                   $lid:pa_name$ $lid:sf_name^"_follower"$
+             >>
+         ) kernels
+      ) xs in
+
+    let strs = strs @ [ common ] @ List.concat fs in
+    strs in
+
+  (* A map from kernel versions to follower functions.
+   *
+   * For each struct, we have a list of kernel versions which contain
+   * that struct.  Some kernels are missing a particular struct, so
+   * that is turned into a ParseError exception.
+   *)
+  let strs =
+    let nr_kernels =
+      List.fold_left max 0
+       (List.map (fun (_, (kernels, _, _, _)) -> List.length kernels) xs) in
+    let nr_structs = List.length xs in
+    let array = Array.make_matrix nr_kernels (nr_structs+1) (Missing "") in
+    List.iteri (
+      fun si (struct_name, _) ->
+       for i = 0 to nr_kernels - 1 do
+         array.(i).(si+1) <- Missing struct_name
+       done
+    ) xs;
+    List.iteri (
+      fun si (struct_name, (kernels, _, _, _)) ->
+       List.iter (
+         fun ({ PP.kernel_version = version; kv_i = kv_i }, _) ->
+           array.(kv_i).(0) <- KernelVersion version;
+           array.(kv_i).(si+1) <-
+             Follower (sprintf "%s_kv%d_follower" struct_name kv_i)
+       ) kernels
+    ) xs;
+
+    let array = Array.map (
+      fun row ->
+       match Array.to_list row with
+       | [] | (Missing _|Follower _) :: _ -> assert false
+       | KernelVersion kernel_version :: followers -> kernel_version, followers
+    ) array in
+
+    let map = List.fold_left (
+      fun map (kernel_version, followers) ->
+       let followers = List.map (
+         function
+         | Follower fname ->
+             <:expr< $lid:fname$ >>
+
+         (* no follower for this kernel/struct combination *)
+         | Missing struct_name ->
+             <:expr<
+               fun _ _ _ _ ->
+                 raise (
+                   Virt_mem_types.ParseError (
+                     $str:struct_name$, "follower_map", struct_missing_err
+                   )
+                 )
+             >>
+         | KernelVersion _ -> assert false
+       ) followers in
+       let followers = tuple_generate_construct followers in
+
+       <:expr< StringMap.add $str:kernel_version$ $followers$ $map$ >>
+    ) <:expr< StringMap.empty >> (Array.to_list array) in
+
+    let str =
+      <:str_item<
+       let follower_map = $map$
+      >> in
+    strs @ [ str ] in
+
+  (* Finally a publicly exposed follower function. *)
+  let strs =
+    let fs =
+      List.map (
+       fun (struct_name, (kernels, _, _, _)) ->
+         let fname = sprintf "%s_follower" struct_name in
+
+         let body =
+           tuple_generate_extract follower_tuple struct_name
+             <:patt< f >> <:expr< followers >>
+             <:expr<
+               f load followers AddrMap.empty addr
+             >> in
+
+         <:str_item<
+           let $lid:fname$ kernel_version load addr =
+             let followers =
+               try StringMap.find kernel_version follower_map
+               with Not_found ->
+                 unknown_kernel_version kernel_version $str:struct_name$ in
+             $body$
+         >>
+      ) xs in
+
+    strs @ fs in
+
+  let sigs =
+    List.map (
+      fun (struct_name, _) ->
+       <:sig_item<
+          val $lid:struct_name^"_follower"$ :
+           kernel_version ->
+           (string -> Virt_mem_mmap.addr -> int -> Bitstring.bitstring) ->
+           Virt_mem_mmap.addr ->
+           (string * int) AddrMap.t
+         >>
+    ) xs in
+
+  concat_str_items strs, concat_sig_items sigs
+
+let output_interf ~output_file types offsets parsers followers =
+  (* Some standard code that appears at the top of the interface file. *)
+  let prologue =
+    <:sig_item<
+      module AddrMap : sig
+       type key = Virt_mem_mmap.addr
+       type 'a t = 'a Map.Make(Int64).t
+       val empty : 'a t
+       val is_empty : 'a t -> bool
+       val add : key -> 'a -> 'a t -> 'a t
+       val find : key -> 'a t -> 'a
+       val remove : key -> 'a t -> 'a t
+       val mem : key -> 'a t -> bool
+       val iter : (key -> 'a -> unit) -> 'a t -> unit
+       val map : ('a -> 'b) -> 'a t -> 'b t
+       val mapi : (key -> 'a -> 'b) -> 'a t -> 'b t
+       val fold : (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
+       val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int
+       val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
+      end ;;
+      type kernel_version = string ;;
+    >> in
+
+  let sigs =
+    concat_sig_items [ prologue; types; offsets; parsers; followers ] in
   Printers.OCaml.print_interf ~output_file sigs
 
 (* Finally generate the output files. *)
 let re_subst = Pcre.regexp "^(.*)\"(\\w+_parser_\\d+)\"(.*)$"
 
-let output_implem ~output_file types offsets parsers parser_subs =
-  let new_output_file = output_file ^ ".new" in
+let output_implem ~output_file types offsets parsers parser_subs followers =
+  (* Some standard code that appears at the top of the implementation file. *)
+  let prologue =
+    <:str_item<
+      module StringMap = Map.Make (String) ;;
+      module AddrMap = Map.Make (Int64) ;;
+      type kernel_version = string ;;
+
+      let match_err = "failed to match kernel structure" ;;
+      let struct_missing_err = "struct does not exist in this kernel version" ;;
+
+      let unknown_kernel_version version struct_name =
+       invalid_arg (Printf.sprintf "%s: unknown kernel version or
+struct %s is not supported in this kernel.
+Try a newer version of virt-mem, or if the guest is not from a
+supported Linux distribution, see this page about adding support:
+  http://et.redhat.com/~rjones/virt-mem/faq.html\n"
+                      version struct_name) ;;
 
-  let strs = concat_str_items [ types; offsets; parsers ] in
+      let zero = 0 ;;
+    >> in
+
+  let strs =
+    concat_str_items [ prologue; types; offsets; parsers; followers ] in
+
+  (* Write the new implementation to .ml.new file. *)
+  let new_output_file = output_file ^ ".new" in
   Printers.OCaml.print_implem ~output_file:new_output_file strs;
 
   (* Substitute the parser bodies in the output file. *)