diff --git a/src/lib.rs b/src/lib.rs index fd21c9e..409cd3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ pub enum KError<'a> { Encoding { expected: &'static str }, MissingRoot, MissingParent, + ReadBitsTooLarge { requested: usize }, UnexpectedContents { actual: &'a [u8] }, UnknownVariant(i64), } @@ -76,7 +77,7 @@ pub trait KStream { fn read_f8le(&self) -> KResult; fn align_to_byte(&self) -> KResult<()>; - fn read_bits_int(&self, n: u32) -> KResult; + fn read_bits_int(&self, n: usize) -> KResult; fn read_bytes(&self, len: usize) -> KResult<&[u8]>; fn read_bytes_full(&self) -> KResult<&[u8]>; @@ -243,9 +244,13 @@ impl<'a> KStream for BytesReader<'a> { Ok(()) } - // TODO: Clean up the casting nightmare - fn read_bits_int(&self, n: u32) -> KResult { - let bits_needed = n as i64 - self.state.borrow().bits_left; + fn read_bits_int(&self, n: usize) -> KResult { + if n > 64 { + return Err(KError::ReadBitsTooLarge { requested: n }); + } + + let n = n as i64; + let bits_needed = n - self.state.borrow().bits_left; if bits_needed > 0 { // 1 bit => 1 byte // 8 bits => 1 byte @@ -253,7 +258,6 @@ impl<'a> KStream for BytesReader<'a> { let bytes_needed = ((bits_needed - 1) / 8) + 1; // Need to be careful here, because `read_bytes` will borrow our state as mutable, // which panics if we're currently holding a borrow - // TODO: Return error for bytes_needed > 8 let buf = self.read_bytes(bytes_needed as usize)?; let mut inner = self.state.borrow_mut(); for b in buf { @@ -265,12 +269,12 @@ impl<'a> KStream for BytesReader<'a> { let mut inner = self.state.borrow_mut(); let mut mask = (1u64 << n) - 1; - let shift_bits = inner.bits_left - n as i64; + let shift_bits = inner.bits_left - n; mask <<= shift_bits; let result: u64 = (inner.bits & mask) >> shift_bits; - inner.bits_left -= n as i64; + inner.bits_left -= n; mask = (1u64 << inner.bits_left) - 1; inner.bits &= mask; @@ -364,4 +368,12 @@ mod tests { assert_eq!(reader.read_bits_int(9).unwrap(), 3); } + + #[test] + fn read_bits_too_large() { + let b: Vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + let mut reader = BytesReader::new(&b[..]); + + assert_eq!(reader.read_bits_int(65).unwrap_err(), KError::ReadBitsTooLarge { requested: 65 }) + } }