Add signed int extract and construction functions, and test.
[ocaml-bitstring.git] / bitstring.ml
index 7ed7a58..bc29a41 100644 (file)
@@ -38,6 +38,8 @@ exception Construct_failure of string * string * int * int
  *)
 type bitstring = string * int * int
 
+type t = bitstring
+
 (* Functions to create and load bitstrings. *)
 let empty_bitstring = "", 0, 0
 
@@ -128,17 +130,17 @@ let bitstring_length (_, _, len) = len
 
 let subbitstring (data, off, len) off' len' =
   let off = off + off' in
-  if len < off' + len' then invalid_arg "subbitstring";
+  if off' < 0 || len' < 0 || off' > len - len' then invalid_arg "subbitstring";
   (data, off, len')
 
 let dropbits n (data, off, len) =
   let off = off + n in
   let len = len - n in
-  if len < 0 then invalid_arg "dropbits";
+  if len < 0 || n < 0 then invalid_arg "dropbits";
   (data, off, len)
 
 let takebits n (data, off, len) =
-  if len < n then invalid_arg "takebits";
+  if len < n || n < 0 then invalid_arg "takebits";
   (data, off, n)
 
 (*----------------------------------------------------------------------*)
@@ -159,7 +161,8 @@ module I = struct
 
   (* Create a mask 0-31 bits wide. *)
   let mask bits =
-    if bits < 30 then
+    if bits < 30 || 
+      (bits < 32 && Sys.word_size = 64) then
       (one <<< bits) - 1
     else if bits = 30 then
       max_int
@@ -196,6 +199,19 @@ module I = struct
     let mask = lnot (mask bits) in
     (v land mask) = zero
 
+  let range_signed v bits =
+    if 
+      v >= zero 
+    then
+      range_unsigned v bits
+    else
+      if
+       bits = 31 && Sys.word_size = 32
+      then
+       v >= min_int            
+      else
+       pred (minus_one <<< pred bits) < v
+
   (* Call function g on the top bits, then f on each full byte
    * (big endian - so start at top).
    *)
@@ -397,6 +413,16 @@ let _get_byte32 data byteoff strlen =
 let _get_byte64 data byteoff strlen =
   if strlen > byteoff then Int64.of_int (Char.code data.[byteoff]) else 0L
 
+(* Extend signed [2..31] bits int to 31 bits int or 63 bits int for 64
+   bits platform*)
+let extend_sign len v =
+  let b = pred Sys.word_size - len in
+    (v lsl b) asr b
+
+let extract_and_extend_sign f data off len flen =
+  let w = f data off len flen in
+    extend_sign len w
+
 (* Extract [2..8] bits.  Because the result fits into a single
  * byte we don't have to worry about endianness, only signedness.
  *)
@@ -427,6 +453,9 @@ let extract_char_unsigned data off len flen =
     word (*, off+flen, len-flen*)
   )
 
+let extract_char_signed =
+  extract_and_extend_sign extract_char_unsigned
+
 (* Extract [9..31] bits.  We have to consider endianness and signedness. *)
 let extract_int_be_unsigned data off len flen =
   let byteoff = off lsr 3 in
@@ -470,21 +499,33 @@ let extract_int_be_unsigned data off len flen =
     ) in
   word (*, off+flen, len-flen*)
 
+let extract_int_be_signed =
+  extract_and_extend_sign extract_int_be_unsigned
+
 let extract_int_le_unsigned data off len flen =
   let v = extract_int_be_unsigned data off len flen in
   let v = I.byteswap v flen in
   v
 
+let extract_int_le_signed =
+  extract_and_extend_sign extract_int_le_unsigned
+
 let extract_int_ne_unsigned =
   if nativeendian = BigEndian
   then extract_int_be_unsigned
   else extract_int_le_unsigned
 
+let extract_int_ne_signed = 
+  extract_and_extend_sign extract_int_ne_unsigned
+
 let extract_int_ee_unsigned = function
   | BigEndian -> extract_int_be_unsigned
   | LittleEndian -> extract_int_le_unsigned
   | NativeEndian -> extract_int_ne_unsigned
 
+let extract_int_ee_signed e =
+  extract_and_extend_sign (extract_int_ee_unsigned e)
+
 let _make_int32_be c0 c1 c2 c3 =
   Int32.logor
     (Int32.logor
@@ -851,7 +892,8 @@ module Buffer = struct
           *)
          let slenbytes = slen lsr 3 in
          if slenbytes > 0 then Buffer.add_substring buf str 0 slenbytes;
-         let last = Char.code str.[slenbytes] in (* last char *)
+         let lastidx = min slenbytes (String.length str - 1) in
+         let last = Char.code str.[lastidx] in (* last char *)
          let mask = 0xff lsl (8 - (slen land 7)) in
          t.last <- last land mask
        );
@@ -893,30 +935,53 @@ let construct_char_unsigned buf v flen exn =
   else
     Buffer._add_bits buf v flen
 
-(* Construct a field of up to 31 bits. *)
-let construct_int_be_unsigned buf v flen exn =
-  (* Check value is within range. *)
-  if not (I.range_unsigned v flen) then raise exn;
-  (* Add the bytes. *)
-  I.map_bytes_be (Buffer._add_bits buf) (Buffer.add_byte buf) v flen
+let construct_char_signed buf v flen exn =
+  let max_val = 1 lsl flen 
+  and min_val = - (1 lsl pred flen) in
+    if v < min_val || v >= max_val then
+       raise exn;
+    if flen = 8 then
+      Buffer.add_byte buf (if v >= 0 then v else 256 + v)
+    else 
+      Buffer._add_bits buf v flen
 
 (* Construct a field of up to 31 bits. *)
-let construct_int_le_unsigned buf v flen exn =
-  (* Check value is within range. *)
-  if not (I.range_unsigned v flen) then raise exn;
-  (* Add the bytes. *)
-  I.map_bytes_le (Buffer._add_bits buf) (Buffer.add_byte buf) v flen
+let construct_int check_func map_func buf v flen exn =
+  if not (check_func v flen) then raise exn;
+  map_func (Buffer._add_bits buf) (Buffer.add_byte buf) v flen
+
+let construct_int_be_unsigned =
+  construct_int I.range_unsigned I.map_bytes_be
+
+let construct_int_be_signed =
+  construct_int I.range_signed I.map_bytes_be
+
+let construct_int_le_unsigned =
+  construct_int I.range_unsigned I.map_bytes_le
+
+let construct_int_le_signed =
+  construct_int I.range_signed I.map_bytes_le
 
 let construct_int_ne_unsigned =
   if nativeendian = BigEndian
   then construct_int_be_unsigned
   else construct_int_le_unsigned
 
+let construct_int_ne_signed =
+  if nativeendian = BigEndian
+  then construct_int_be_signed
+  else construct_int_le_signed
+
 let construct_int_ee_unsigned = function
   | BigEndian -> construct_int_be_unsigned
   | LittleEndian -> construct_int_le_unsigned
   | NativeEndian -> construct_int_ne_unsigned
 
+let construct_int_ee_signed = function
+  | BigEndian -> construct_int_be_signed
+  | LittleEndian -> construct_int_le_signed
+  | NativeEndian -> construct_int_ne_signed
+
 (* Construct a field of exactly 32 bits. *)
 let construct_int32_be_unsigned buf v flen _ =
   Buffer.add_byte buf
@@ -965,13 +1030,11 @@ let construct_int64_le_unsigned buf v flen exn =
 let construct_int64_ne_unsigned =
   if nativeendian = BigEndian
   then construct_int64_be_unsigned
-  else (*construct_int64_le_unsigned*)
-    fun _ _ _ _ -> failwith "construct_int64_le_unsigned"
+  else construct_int64_le_unsigned
 
 let construct_int64_ee_unsigned = function
   | BigEndian -> construct_int64_be_unsigned
-  | LittleEndian -> (*construct_int64_le_unsigned*)
-      (fun _ _ _ _ -> failwith "construct_int64_le_unsigned")
+  | LittleEndian -> construct_int64_le_unsigned
   | NativeEndian -> construct_int64_ne_unsigned
 
 (* Construct from a string of bytes, exact multiple of 8 bits
@@ -992,7 +1055,7 @@ let construct_bitstring buf (data, off, len) =
     if blen = 0 then (off, len)
     else (
       let b = extract_bit data off len 1
-      and off = off + 1 and len = len + 1 in
+      and off = off + 1 and len = len - 1 in
       Buffer.add_bit buf b;
       loop off len (blen-1)
     )
@@ -1009,9 +1072,14 @@ let construct_bitstring buf (data, off, len) =
 
   Buffer.add_bits buf data len
 
+(* Concatenate bitstrings. *)
+let concat bs =
+  let buf = Buffer.create () in
+  List.iter (construct_bitstring buf) bs;
+  Buffer.contents buf
+
 (*----------------------------------------------------------------------*)
 (* Extract a string from a bitstring. *)
-
 let string_of_bitstring (data, off, len) =
   if off land 7 = 0 && len land 7 = 0 then
     (* Easy case: everything is byte-aligned. *)
@@ -1060,6 +1128,110 @@ let bitstring_to_file bits filename =
     raise exn
 
 (*----------------------------------------------------------------------*)
+(* Comparison. *)
+let compare ((data1, off1, len1) as bs1) ((data2, off2, len2) as bs2) =
+  (* In the fully-aligned case, this is reduced to string comparison ... *)
+  if off1 land 7 = 0 && len1 land 7 = 0 && off2 land 7 = 0 && len2 land 7 = 0
+  then (
+    (* ... but we have to do that by hand because the bits may
+     * not extend to the full length of the underlying string.
+     *)
+    let off1 = off1 lsr 3 and off2 = off2 lsr 3
+    and len1 = len1 lsr 3 and len2 = len2 lsr 3 in
+    let rec loop i =
+      if i < len1 && i < len2 then (
+       let c1 = String.unsafe_get data1 (off1 + i)
+       and c2 = String.unsafe_get data2 (off2 + i) in
+       let r = compare c1 c2 in
+       if r <> 0 then r
+       else loop (i+1)
+      )
+      else len1 - len2
+    in
+    loop 0
+  )
+  else (
+    (* Slow/unaligned. *)
+    let str1 = string_of_bitstring bs1
+    and str2 = string_of_bitstring bs2 in
+    let r = String.compare str1 str2 in
+    if r <> 0 then r else len1 - len2
+  )
+
+let equals ((_, _, len1) as bs1) ((_, _, len2) as bs2) =
+  if len1 <> len2 then false
+  else if bs1 = bs2 then true
+  else 0 = compare bs1 bs2
+
+let is_zeroes_bitstring ((data, off, len) as bits) =
+  if off land 7 = 0 && len land 7 = 0 then (
+    let off = off lsr 3 and len = len lsr 3 in
+    let rec loop i =
+      if i < len then (
+        if String.unsafe_get data (off + i) <> '\000' then false
+        else loop (i+1)
+      ) else true
+    in
+    loop 0
+  )
+  else (
+    (* Slow/unaligned case. *)
+    let len = bitstring_length bits in
+    let zeroes = zeroes_bitstring len in
+    0 = compare bits zeroes
+  )
+
+let is_ones_bitstring ((data, off, len) as bits) =
+  if off land 7 = 0 && len land 7 = 0 then (
+    let off = off lsr 3 and len = len lsr 3 in
+    let rec loop i =
+      if i < len then (
+        if String.unsafe_get data (off + i) <> '\xff' then false
+        else loop (i+1)
+      ) else true
+    in
+    loop 0
+  )
+  else (
+    (* Slow/unaligned case. *)
+    let len = bitstring_length bits in
+    let ones = ones_bitstring len in
+    0 = compare bits ones
+  )
+
+(*----------------------------------------------------------------------*)
+(* Bit get/set functions. *)
+
+let index_out_of_bounds () = invalid_arg "index out of bounds"
+
+let put (data, off, len) n v =
+  if n < 0 || n >= len then index_out_of_bounds ()
+  else (
+    let i = off+n in
+    let si = i lsr 3 and mask = 0x80 lsr (i land 7) in
+    let c = Char.code data.[si] in
+    let c = if v <> 0 then c lor mask else c land (lnot mask) in
+    data.[si] <- Char.unsafe_chr c
+  )
+
+let set bits n = put bits n 1
+
+let clear bits n = put bits n 0
+
+let get (data, off, len) n =
+  if n < 0 || n >= len then index_out_of_bounds ()
+  else (
+    let i = off+n in
+    let si = i lsr 3 and mask = 0x80 lsr (i land 7) in
+    let c = Char.code data.[si] in
+    c land mask
+  )
+
+let is_set bits n = get bits n <> 0
+
+let is_clear bits n = get bits n = 0
+
+(*----------------------------------------------------------------------*)
 (* Display functions. *)
 
 let isprint c =
@@ -1102,3 +1274,9 @@ let hexdump_bitstring chan (data, off, len) =
     fprintf chan " |%s|\n%!" linechars
   ) else
     fprintf chan "\n%!"
+
+(*----------------------------------------------------------------------*)
+(* Alias of functions shadowed by Core. *)
+
+let char_code = Char.code
+let int32_of_int = Int32.of_int