diff --git a/src/lib.rs b/src/lib.rs index 21f8e1a..c5294be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,8 +131,8 @@ pub trait KStream { #[derive(Default)] struct BytesReaderState { pos: usize, - bits: u8, - bits_left: u8, + bits: u64, + bits_left: i64, } pub struct BytesReader<'a> { state: RefCell, @@ -239,8 +239,38 @@ impl<'a> KStream for BytesReader<'a> { unimplemented!() } + // TODO: Clean up the casting nightmare fn read_bits_int(&self, n: u32) -> KResult { - unimplemented!() + let bits_needed = n as i64 - self.state.borrow().bits_left; + if bits_needed > 0 { + // 1 bit => 1 byte + // 8 bits => 1 byte + // 9 bits => 2 bytes + let bytes_needed = ((bits_needed - 1) / 8) + 1; + // Need to be careful here, because `read_bytes` will borrow our state as mutable, + // which panics if we're currently holding a borrow + // TODO: Return error for bytes_needed > 8 + let buf = self.read_bytes(bytes_needed as usize)?; + let mut inner = self.state.borrow_mut(); + for b in buf { + inner.bits <<= 8; + inner.bits |= *b as u64; + inner.bits_left += 8; + } + } + + let mut inner = self.state.borrow_mut(); + let mut mask = (1u64 << n) - 1; + let shift_bits = inner.bits_left - n as i64; + mask <<= shift_bits; + + let result: u64 = (inner.bits & mask) >> shift_bits; + + inner.bits_left -= n as i64; + mask = (1u64 << inner.bits_left) - 1; + inner.bits &= mask; + + Ok(result) } fn read_bytes(&self, len: usize) -> KResult<&[u8]> { @@ -295,4 +325,39 @@ mod tests { ); assert_eq!(reader.read_bytes(1).unwrap(), &[8]); } + + #[test] + fn read_bits_single() { + let b = vec![0x80]; + let mut reader = BytesReader::new(&b[..]); + + assert_eq!(reader.read_bits_int(1).unwrap(), 1); + } + + #[test] + fn read_bits_multiple() { + // 0xA0 + let b = vec![0b10100000]; + let mut reader = BytesReader::new(&b[..]); + + assert_eq!(reader.read_bits_int(1).unwrap(), 1); + assert_eq!(reader.read_bits_int(1).unwrap(), 0); + assert_eq!(reader.read_bits_int(1).unwrap(), 1); + } + + #[test] + fn read_bits_large() { + let b = vec![0b10100000]; + let mut reader = BytesReader::new(&b[..]); + + assert_eq!(reader.read_bits_int(3).unwrap(), 5); + } + + #[test] + fn read_bits_span() { + let b = vec![0x01, 0x80]; + let mut reader = BytesReader::new(&b[..]); + + assert_eq!(reader.read_bits_int(9).unwrap(), 3); + } }