Fix corrupted IPv4/6 packets.
[ocaml-bitstring.git] / pa_bitmatch.ml
1 (* Bitmatch syntax extension.
2  * $Id: pa_bitmatch.ml,v 1.3 2008-04-01 10:05:14 rjones Exp $
3  *)
4
5 open Printf
6
7 open Camlp4.PreCast
8 open Syntax
9 open Ast
10
11 type m = Fields of f list               (* field ; field -> ... *)
12        | Bind of string option          (* _ -> ... *)
13 and f = {
14   ident : string;                       (* field name *)
15   flen : expr;                          (* length in bits, may be non-const *)
16   endian : endian;                      (* endianness *)
17   signed : bool;                        (* true if signed, false if unsigned *)
18   t : t;                                (* type *)
19 }
20 and endian = BigEndian | LittleEndian | NativeEndian
21 and t = Int | Bitstring
22
23 (* Generate a fresh, unique symbol each time called. *)
24 let gensym =
25   let i = ref 1000 in
26   fun name ->
27     incr i; let i = !i in
28     sprintf "__pabitmatch_%s_%d" name i
29
30 (* Deal with the qualifiers which appear for a field. *)
31 let output_field _loc name flen qs =
32   let endian, signed, t =
33     match qs with
34     | None -> (None, None, None)
35     | Some qs ->
36         List.fold_left (
37           fun (endian, signed, t) q ->
38             match q with
39             | "bigendian" ->
40                 if endian <> None then
41                   Loc.raise _loc (Failure "an endian flag has been set already")
42                 else (
43                   let endian = Some BigEndian in
44                   (endian, signed, t)
45                 )
46             | "littleendian" ->
47                 if endian <> None then
48                   Loc.raise _loc (Failure "an endian flag has been set already")
49                 else (
50                   let endian = Some LittleEndian in
51                   (endian, signed, t)
52                 )
53             | "nativeendian" ->
54                 if endian <> None then
55                   Loc.raise _loc (Failure "an endian flag has been set already")
56                 else (
57                   let endian = Some NativeEndian in
58                   (endian, signed, t)
59                 )
60             | "signed" ->
61                 if signed <> None then
62                   Loc.raise _loc (Failure "a signed flag has been set already")
63                 else (
64                   let signed = Some true in
65                   (endian, signed, t)
66                 )
67             | "unsigned" ->
68                 if signed <> None then
69                   Loc.raise _loc (Failure "a signed flag has been set already")
70                 else (
71                   let signed = Some false in
72                   (endian, signed, t)
73                 )
74             | "int" ->
75                 if t <> None then
76                   Loc.raise _loc (Failure "a type flag has been set already")
77                 else (
78                   let t = Some Int in
79                   (endian, signed, t)
80                 )
81             | "bitstring" ->
82                 if t <> None then
83                   Loc.raise _loc (Failure "a type flag has been set already")
84                 else (
85                   let t = Some Bitstring in
86                   (endian, signed, t)
87                 )
88             | s ->
89                 Loc.raise _loc (Failure (s ^ ": unknown qualifier"))
90         ) (None, None, None) qs in
91
92   (* If type is set to bitstring then endianness and signedness
93    * qualifiers are meaningless and must not be set.
94    *)
95   if t = Some Bitstring && (endian <> None || signed <> None) then
96     Loc.raise _loc (
97       Failure "bitstring type and endian or signed qualifiers cannot be mixed"
98     );
99
100   (* Default endianness, signedness, type. *)
101   let endian = match endian with None -> BigEndian | Some e -> e in
102   let signed = match signed with None -> false | Some s -> s in
103   let t = match t with None -> Int | Some t -> t in
104
105   {
106     ident = name;
107     flen = flen;
108     endian = endian;
109     signed = signed;
110     t = t;
111   }
112
113 (* Generate the code for a bitmatch statement.  '_loc' is the
114  * location, 'bs' is the bitstring parameter, 'cases' are
115  * the list of cases to test against.
116  *)
117 let output_bitmatch _loc bs cases =
118   let data = gensym "data" and off = gensym "off" and len = gensym "len" in
119   let result = gensym "result" in
120
121   (* This generates the field extraction code for each
122    * field a single case.  Each field must be wider than
123    * the minimum permitted for the type and there must be
124    * enough remaining data in the bitstring to satisfy it.
125    * As we go through the fields, symbols 'data', 'off' and 'len'
126    * track our position and remaining length in the bitstring.
127    *
128    * The whole thing is a lot of nested 'if' statements. Code
129    * is generated from the inner-most (last) field outwards.
130    *)
131   let rec output_field_extraction inner = function
132     | [] -> inner
133     | {ident=ident; flen=flen; endian=endian; signed=signed; t=t} :: fields ->
134         (* If length an integer constant?  If so, what is it?  This
135          * is very simple-minded and only detects simple constants.
136          *)
137         let flen_is_const =
138           match flen with
139           | <:expr< $int:i$ >> -> Some (int_of_string i)
140           | _ -> None in
141
142         let name_of_int_extract_const = function
143             (* XXX As an enhancement we should allow a 64-bit-only
144              * mode which lets us use 'int' up to 63 bits and won't
145              * compile on 32-bit platforms.
146              *)
147             (* XXX The meaning of signed/unsigned breaks down at
148              * 31, 32, 63 and 64 bits.
149              *)
150           | (1, _, _) -> "extract_bit"
151           | ((2|3|4|5|6|7|8), _, false) -> "extract_char_unsigned"
152           | ((2|3|4|5|6|7|8), _, true) -> "extract_char_signed"
153           | (i, BigEndian, false) when i <= 31 -> "extract_int_be_unsigned"
154           | (i, BigEndian, true) when i <= 31 -> "extract_int_be_signed"
155           | (i, LittleEndian, false) when i <= 31 -> "extract_int_le_unsigned"
156           | (i, LittleEndian, true) when i <= 31 -> "extract_int_le_signed"
157           | (i, NativeEndian, false) when i <= 31 -> "extract_int_ne_unsigned"
158           | (i, NativeEndian, true) when i <= 31 -> "extract_int_ne_signed"
159           | (32, BigEndian, false) -> "extract_int32_be_unsigned"
160           | (32, BigEndian, true) -> "extract_int32_be_signed"
161           | (32, LittleEndian, false) -> "extract_int32_le_unsigned"
162           | (32, LittleEndian, true) -> "extract_int32_le_signed"
163           | (32, NativeEndian, false) -> "extract_int32_ne_unsigned"
164           | (32, NativeEndian, true) -> "extract_int32_ne_signed"
165           | (_, BigEndian, false) -> "extract_int64_be_unsigned"
166           | (_, BigEndian, true) -> "extract_int64_be_signed"
167           | (_, LittleEndian, false) -> "extract_int64_le_unsigned"
168           | (_, LittleEndian, true) -> "extract_int64_le_signed"
169           | (_, NativeEndian, false) -> "extract_int64_ne_unsigned"
170           | (_, NativeEndian, true) -> "extract_int64_ne_signed"
171         in
172         let name_of_int_extract = function
173             (* XXX As an enhancement we should allow users to
174              * specify that a field length can fit into a char/int/int32
175              * (of course, this would have to be checked at runtime).
176              *)
177           | (BigEndian, false) -> "extract_int64_be_unsigned"
178           | (BigEndian, true) -> "extract_int64_be_signed"
179           | (LittleEndian, false) -> "extract_int64_le_unsigned"
180           | (LittleEndian, true) -> "extract_int64_le_signed"
181           | (NativeEndian, false) -> "extract_int64_ne_unsigned"
182           | (NativeEndian, true) -> "extract_int64_ne_signed"
183         in
184
185         let expr =
186           match t, flen_is_const with
187           (* Common case: int field, constant flen *)
188           | Int, Some i when i > 0 && i <= 64 ->
189               let extract_func = name_of_int_extract_const (i,endian,signed) in
190               <:expr<
191                 if $lid:len$ >= $flen$ then (
192                   let $lid:ident$, $lid:off$, $lid:len$ =
193                     Bitmatch.$lid:extract_func$ $lid:data$ $lid:off$ $lid:len$
194                       $flen$ in
195                   $inner$
196                 )
197               >>
198
199           | Int, Some _ ->
200               Loc.raise _loc (Failure "length of int field must be [1..64]")
201
202           (* Int field, non-const flen.  We have to test the range of
203            * the field at runtime.  If outside the range it's a no-match
204            * (not an error).
205            *)
206           | Int, None ->
207               let extract_func = name_of_int_extract (endian,signed) in
208               <:expr<
209                 if $flen$ >= 1 && $flen$ <= 64 && $flen$ >= $lid:len$ then (
210                   let $lid:ident$, $lid:off$, $lid:len$ =
211                     Bitmatch.$lid:extract_func$ $lid:data$ $lid:off$ $lid:len$
212                       $flen$ in
213                   $inner$
214                 )
215               >>
216
217           (* Bitstring, constant flen >= 0. *)
218           | Bitstring, Some i when i >= 0 ->
219               <:expr<
220                 if $lid:len$ >= $flen$ then (
221                   let $lid:ident$, $lid:off$, $lid:len$ =
222                     Bitmatch.extract_bitstring $lid:data$ $lid:off$ $lid:len$
223                       $flen$ in
224                   $inner$
225                 )
226               >>
227
228           (* Bitstring, constant flen = -1, means consume all the
229            * rest of the input.
230            *)
231           | Bitstring, Some i when i = -1 ->
232               <:expr<
233                 let $lid:ident$, $lid:off$, $lid:len$ =
234                   Bitmatch.extract_remainder $lid:data$ $lid:off$ $lid:len$ in
235                   $inner$
236               >>
237
238           | Bitstring, Some _ ->
239               Loc.raise _loc (Failure "length of bitstring must be >= 0 or the special value -1")
240
241           (* Bitstring field, non-const flen.  We check the flen is >= 0
242            * (-1 is not allowed here) at runtime.
243            *)
244           | Bitstring, None ->
245               <:expr<
246                 if $flen$ >= 0 && $lid:len$ >= $flen$ then (
247                   let $lid:ident$, $lid:off$, $lid:len$ =
248                     Bitmatch.extract_bitstring $lid:data$ $lid:off$ $lid:len$
249                       $flen$ in
250                   $inner$
251                 )
252               >>
253         in
254
255         output_field_extraction expr fields
256   in
257
258   (* Convert each case in the match. *)
259   let cases = List.map (
260     function
261     (* field : len ; field : len when .. -> ..*)
262     | (Fields fields, Some whenclause, code) ->
263         let inner =
264           <:expr<
265             if $whenclause$ then (
266               $lid:result$ := Some ($code$);
267               raise Exit
268             )
269           >> in
270         output_field_extraction inner (List.rev fields)
271
272     (* field : len ; field : len -> ... *)
273     | (Fields fields, None, code) ->
274         let inner =
275           <:expr<
276             $lid:result$ := Some ($code$);
277             raise Exit
278           >> in
279         output_field_extraction inner (List.rev fields)
280
281     (* _ as name when ... -> ... *)
282     | (Bind (Some name), Some whenclause, code) ->
283         <:expr<
284           let $lid:name$ = ($lid:data$, $lid:off$, $lid:len$) in
285           if $whenclause$ then (
286             $lid:result$ := Some ($code$);
287             raise Exit
288           )
289         >>
290
291     (* _ as name -> ... *)
292     | (Bind (Some name), None, code) ->
293         <:expr<
294           let $lid:name$ = ($lid:data$, $lid:off$, $lid:len$) in
295           $lid:result$ := Some ($code$);
296           raise Exit
297         >>
298
299     (* _ when ... -> ... *)
300     | (Bind None, Some whenclause, code) ->
301         <:expr<
302           if $whenclause$ then (
303             $lid:result$ := Some ($code$);
304             raise Exit
305           )
306         >>
307
308     (* _ -> ... *)
309     | (Bind None, None, code) ->
310         <:expr<
311           $lid:result$ := Some ($code$);
312           raise Exit
313         >>
314
315   ) cases in
316
317   (* Join them into a single expression.
318    *
319    * Don't do it with a normal fold_right because that leaves
320    * 'raise Exit; ()' at the end which causes a compiler warning.
321    * Hence a bit of complexity here.
322    *
323    * Note that the number of cases is always >= 1 so List.hd is safe.
324    *)
325   let cases = List.rev cases in
326   let cases =
327     List.fold_left (fun base case -> <:expr< $case$ ; $base$ >>)
328       (List.hd cases) (List.tl cases) in
329
330   (* The final code just wraps the list of cases in a
331    * try/with construct so that each case is tried in
332    * turn until one case matches (that case sets 'result'
333    * and raises 'Exit' to leave the whole statement).
334    * If result isn't set by the end then we will raise
335    * Match_failure with the location of the bitmatch
336    * statement in the original code.
337    *)
338   let loc_fname = Loc.file_name _loc in
339   let loc_line = string_of_int (Loc.start_line _loc) in
340   let loc_char = string_of_int (Loc.start_off _loc - Loc.start_bol _loc) in
341
342   <:expr<
343     let ($lid:data$, $lid:off$, $lid:len$) = $bs$ in
344     let $lid:result$ = ref None in
345     (try
346       $cases$
347     with Exit -> ());
348     match ! $lid:result$ with
349     | Some x -> x
350     | None -> raise (Match_failure ($str:loc_fname$,
351                                     $int:loc_line$, $int:loc_char$))
352   >>
353
354 EXTEND Gram
355   GLOBAL: expr;
356
357   qualifiers: [
358     [ LIST0 [ q = LIDENT -> q ] SEP "," ]
359   ];
360
361   field: [
362     [ name = LIDENT; ":"; len = expr LEVEL "top";
363       qs = OPT [ ":"; qs = qualifiers -> qs ] ->
364         output_field _loc name len qs
365     ]
366   ];
367
368   match_case: [
369     [ fields = LIST0 field SEP ";";
370       w = OPT [ "when"; e = expr -> e ]; "->";
371       code = expr ->
372         (Fields fields, w, code)
373     ]
374   | [ "_";
375       bind = OPT [ "as"; name = LIDENT -> name ];
376       w = OPT [ "when"; e = expr -> e ]; "->";
377       code = expr ->
378         (Bind bind, w, code)
379     ]
380   ];
381
382   expr: LEVEL ";" [
383     [ "bitmatch"; bs = expr; "with"; OPT "|";
384       cases = LIST1 match_case SEP "|" ->
385         output_bitmatch _loc bs cases
386     ]
387   ];
388
389 END