Bitmatch syntax extension, working on bits and bitstrings.
[ocaml-bitstring.git] / pa_bitmatch.ml
1 (* Bitmatch syntax extension.
2  * $Id: pa_bitmatch.ml,v 1.1 2008-03-31 22:52:17 rjones Exp $
3  *)
4
5 open Printf
6
7 open Camlp4.PreCast
8 open Syntax
9 open Ast
10
11 type m = Fields of f list               (* field ; field -> ... *)
12        | Bind of string option          (* _ -> ... *)
13 and f = {
14   ident : string;                       (* field name *)
15   flen : expr;                          (* length in bits, may be non-const *)
16   endian : endian;                      (* endianness *)
17   signed : bool;                        (* true if signed, false if unsigned *)
18   t : t;                                (* type *)
19 }
20 and endian = BigEndian | LittleEndian | NativeEndian
21 and t = Int | Bitstring
22
23 (* Generate a fresh, unique symbol each time called. *)
24 let gensym =
25   let i = ref 1000 in
26   fun name ->
27     incr i; let i = !i in
28     sprintf "__pabitmatch_%s_%d" name i
29
30 (* Deal with the qualifiers which appear for a field. *)
31 let output_field _loc name flen qs =
32   let endian, signed, t =
33     match qs with
34     | None -> (None, None, None)
35     | Some qs ->
36         List.fold_left (
37           fun (endian, signed, t) q ->
38             match q with
39             | "bigendian" ->
40                 if endian <> None then
41                   Loc.raise _loc (Failure "an endian flag has been set already")
42                 else (
43                   let endian = Some BigEndian in
44                   (endian, signed, t)
45                 )
46             | "littleendian" ->
47                 if endian <> None then
48                   Loc.raise _loc (Failure "an endian flag has been set already")
49                 else (
50                   let endian = Some LittleEndian in
51                   (endian, signed, t)
52                 )
53             | "nativeendian" ->
54                 if endian <> None then
55                   Loc.raise _loc (Failure "an endian flag has been set already")
56                 else (
57                   let endian = Some NativeEndian in
58                   (endian, signed, t)
59                 )
60             | "signed" ->
61                 if signed <> None then
62                   Loc.raise _loc (Failure "a signed flag has been set already")
63                 else (
64                   let signed = Some true in
65                   (endian, signed, t)
66                 )
67             | "unsigned" ->
68                 if signed <> None then
69                   Loc.raise _loc (Failure "a signed flag has been set already")
70                 else (
71                   let signed = Some false in
72                   (endian, signed, t)
73                 )
74             | "int" ->
75                 if t <> None then
76                   Loc.raise _loc (Failure "a type flag has been set already")
77                 else (
78                   let t = Some Int in
79                   (endian, signed, t)
80                 )
81             | "bitstring" ->
82                 if t <> None then
83                   Loc.raise _loc (Failure "a type flag has been set already")
84                 else (
85                   let t = Some Bitstring in
86                   (endian, signed, t)
87                 )
88             | s ->
89                 Loc.raise _loc (Failure (s ^ ": unknown qualifier"))
90         ) (None, None, None) qs in
91
92   (* If type is set to bitstring then endianness and signedness
93    * qualifiers are meaningless and must not be set.
94    *)
95   if t = Some Bitstring && (endian <> None || signed <> None) then
96     Loc.raise _loc (
97       Failure "bitstring type and endian or signed qualifiers cannot be mixed"
98     );
99
100   (* Default endianness, signedness, type. *)
101   let endian = match endian with None -> BigEndian | Some e -> e in
102   let signed = match signed with None -> false | Some s -> s in
103   let t = match t with None -> Int | Some t -> t in
104
105   {
106     ident = name;
107     flen = flen;
108     endian = endian;
109     signed = signed;
110     t = t;
111   }
112
113 (* Generate the code for a bitmatch statement.  '_loc' is the
114  * location, 'bs' is the bitstring parameter, 'cases' are
115  * the list of cases to test against.
116  *)
117 let output_bitmatch _loc bs cases =
118   let data = gensym "data" and off = gensym "off" and len = gensym "len" in
119   let result = gensym "result" in
120
121   (* This generates the field extraction code for each
122    * field a single case.  Each field must be wider than
123    * the minimum permitted for the type and there must be
124    * enough remaining data in the bitstring to satisfy it.
125    * As we go through the fields, symbols 'data', 'off' and 'len'
126    * track our position and remaining length in the bitstring.
127    *
128    * The whole thing is a lot of nested 'if' statements. Code
129    * is generated from the inner-most (last) field outwards.
130    *)
131   let rec output_field_extraction inner = function
132     | [] -> inner
133     | {ident=ident; flen=flen; endian=endian; signed=signed; t=t} :: fields ->
134         (* If length an integer constant?  If so, what is it?  This
135          * is very simple-minded and only detects simple constants.
136          *)
137         let flen_is_const =
138           match flen with
139           | <:expr< $int:i$ >> -> Some (int_of_string i)
140           | _ -> None in
141
142         let name_of_int_extract_const = function
143           | (1, _, _) -> "extract_bit"
144           | ((2|3|4|5|6|7), _, false) -> "extract_char_unsigned"
145           | ((2|3|4|5|6|7), _, true) -> "extract_char_signed"
146           | (i, BigEndian, false) when i <= 31 -> "extract_int_be_unsigned"
147           | (i, BigEndian, true) when i <= 31 -> "extract_int_be_signed"
148           | (i, LittleEndian, false) when i <= 31 -> "extract_int_le_unsigned"
149           | (i, LittleEndian, true) when i <= 31 -> "extract_int_le_signed"
150           | (i, NativeEndian, false) when i <= 31 -> "extract_int_ne_unsigned"
151           | (i, NativeEndian, true) when i <= 31 -> "extract_int_ne_signed"
152           | (32, BigEndian, false) -> "extract_int32_be_unsigned"
153           | (32, BigEndian, true) -> "extract_int32_be_signed"
154           | (32, LittleEndian, false) -> "extract_int32_le_unsigned"
155           | (32, LittleEndian, true) -> "extract_int32_le_signed"
156           | (32, NativeEndian, false) -> "extract_int32_ne_unsigned"
157           | (32, NativeEndian, true) -> "extract_int32_ne_signed"
158           | (_, BigEndian, false) -> "extract_int64_be_unsigned"
159           | (_, BigEndian, true) -> "extract_int64_be_signed"
160           | (_, LittleEndian, false) -> "extract_int64_le_unsigned"
161           | (_, LittleEndian, true) -> "extract_int64_le_signed"
162           | (_, NativeEndian, false) -> "extract_int64_ne_unsigned"
163           | (_, NativeEndian, true) -> "extract_int64_ne_signed"
164         in
165         let name_of_int_extract = function
166             (* XXX As an enhancement we should allow users to
167              * specify that a field length can fit into a char/int/int32
168              * (of course, this would have to be checked at runtime).
169              *)
170           | (BigEndian, false) -> "extract_int64_be_unsigned"
171           | (BigEndian, true) -> "extract_int64_be_signed"
172           | (LittleEndian, false) -> "extract_int64_le_unsigned"
173           | (LittleEndian, true) -> "extract_int64_le_signed"
174           | (NativeEndian, false) -> "extract_int64_ne_unsigned"
175           | (NativeEndian, true) -> "extract_int64_ne_signed"
176         in
177
178         let expr =
179           match t, flen_is_const with
180           (* Common case: int field, constant flen *)
181           | Int, Some i when i > 0 && i <= 64 ->
182               let extract_func = name_of_int_extract_const (i,endian,signed) in
183               <:expr<
184                 if $lid:len$ >= $flen$ then (
185                   let $lid:ident$, $lid:off$, $lid:len$ =
186                     Bitmatch.$lid:extract_func$ $lid:data$ $lid:off$ $lid:len$
187                       $flen$ in
188                   $inner$
189                 )
190               >>
191
192           | Int, Some _ ->
193               Loc.raise _loc (Failure "length of int field must be [1..64]")
194
195           (* Int field, non-const flen.  We have to test the range of
196            * the field at runtime.  If outside the range it's a no-match
197            * (not an error).
198            *)
199           | Int, None ->
200               let extract_func = name_of_int_extract (endian,signed) in
201               <:expr<
202                 if $flen$ >= 1 && $flen$ <= 64 && $flen$ >= $lid:len$ then (
203                   let $lid:ident$, $lid:off$, $lid:len$ =
204                     Bitmatch.$lid:extract_func$ $lid:data$ $lid:off$ $lid:len$
205                       $flen$ in
206                   $inner$
207                 )
208               >>
209
210           (* Bitstring, constant flen >= 0. *)
211           | Bitstring, Some i when i >= 0 ->
212               <:expr<
213                 if $lid:len$ >= $flen$ then (
214                   let $lid:ident$, $lid:off$, $lid:len$ =
215                     Bitmatch.extract_bitstring $lid:data$ $lid:off$ $lid:len$
216                       $flen$ in
217                   $inner$
218                 )
219               >>
220
221           (* Bitstring, constant flen = -1, means consume all the
222            * rest of the input.
223            *)
224           | Bitstring, Some i when i = -1 ->
225               <:expr<
226                 let $lid:ident$, $lid:off$, $lid:len$ =
227                   Bitmatch.extract_remainder $lid:data$ $lid:off$ $lid:len$ in
228                   $inner$
229               >>
230
231           | Bitstring, Some _ ->
232               Loc.raise _loc (Failure "length of bitstring must be >= 0 or the special value -1")
233
234           (* Bitstring field, non-const flen.  We check the flen is >= 0
235            * (-1 is not allowed here) at runtime.
236            *)
237           | Bitstring, None ->
238               <:expr<
239                 if $flen$ >= 0 && $lid:len$ >= $flen$ then (
240                   let $lid:ident$, $lid:off$, $lid:len$ =
241                     Bitmatch.extract_bitstring $lid:data$ $lid:off$ $lid:len$
242                       $flen$ in
243                   $inner$
244                 )
245               >>
246         in
247
248         output_field_extraction expr fields
249   in
250
251   (* Convert each case in the match. *)
252   let cases = List.map (
253     function
254     (* field : len ; field : len when .. -> ..*)
255     | (Fields fields, Some whenclause, code) ->
256         let inner =
257           <:expr<
258             if $whenclause$ then (
259               $lid:result$ := Some ($code$);
260               raise Exit
261             )
262           >> in
263         output_field_extraction inner (List.rev fields)
264
265     (* field : len ; field : len -> ... *)
266     | (Fields fields, None, code) ->
267         let inner =
268           <:expr<
269             $lid:result$ := Some ($code$);
270             raise Exit
271           >> in
272         output_field_extraction inner (List.rev fields)
273
274     (* _ as name when ... -> ... *)
275     | (Bind (Some name), Some whenclause, code) ->
276         <:expr<
277           let $lid:name$ = ($lid:data$, $lid:off$, $lid:len$) in
278           if $whenclause$ then (
279             $lid:result$ := Some ($code$);
280             raise Exit
281           )
282         >>
283
284     (* _ as name -> ... *)
285     | (Bind (Some name), None, code) ->
286         <:expr<
287           let $lid:name$ = ($lid:data$, $lid:off$, $lid:len$) in
288           $lid:result$ := Some ($code$);
289           raise Exit
290         >>
291
292     (* _ when ... -> ... *)
293     | (Bind None, Some whenclause, code) ->
294         <:expr<
295           if $whenclause$ then (
296             $lid:result$ := Some ($code$);
297             raise Exit
298           )
299         >>
300
301     (* _ -> ... *)
302     | (Bind None, None, code) ->
303         <:expr<
304           $lid:result$ := Some ($code$);
305           raise Exit
306         >>
307
308   ) cases in
309
310   let cases =
311     List.fold_right (fun case base -> <:expr< $case$ ; $base$ >>)
312       cases <:expr< () >> in
313
314   (* The final code just wraps the list of cases in a
315    * try/with construct so that each case is tried in
316    * turn until one case matches (that case sets 'result'
317    * and raises 'Exit' to leave the whole statement).
318    * If result isn't set by the end then we will raise
319    * Match_failure with the location of the bitmatch
320    * statement in the original code.
321    *)
322   let loc_fname = Loc.file_name _loc in
323   let loc_line = string_of_int (Loc.start_line _loc) in
324   let loc_char = string_of_int (Loc.start_off _loc - Loc.start_bol _loc) in
325
326   <:expr<
327     let ($lid:data$, $lid:off$, $lid:len$) = $bs$ in
328     let $lid:result$ = ref None in
329     (try
330       $cases$
331     with Exit -> ());
332     match ! $lid:result$ with
333     | Some x -> x
334     | None -> raise (Match_failure ($str:loc_fname$,
335                                     $int:loc_line$, $int:loc_char$))
336   >>
337
338 EXTEND Gram
339   GLOBAL: expr;
340
341   qualifiers: [
342     [ LIST0 [ q = LIDENT -> q ] SEP "," ]
343   ];
344
345   field: [
346     [ name = LIDENT; ":"; len = expr LEVEL "top";
347       qs = OPT [ ":"; qs = qualifiers -> qs ] ->
348         output_field _loc name len qs
349     ]
350   ];
351
352   match_case: [
353     [ fields = LIST0 field SEP ";";
354       w = OPT [ "when"; e = expr -> e ]; "->";
355       code = expr ->
356         (Fields fields, w, code)
357     ]
358   | [ "_";
359       bind = OPT [ "as"; name = LIDENT -> name ];
360       w = OPT [ "when"; e = expr -> e ]; "->";
361       code = expr ->
362         (Bind bind, w, code)
363     ]
364   ];
365
366   expr: LEVEL ";" [
367     [ "bitmatch"; bs = expr; "with"; OPT "|";
368       cases = LIST1 match_case SEP "|" ->
369         output_bitmatch _loc bs cases
370     ]
371   ];
372
373 END