diff --git a/aeron-rs/src/command/flyweight.rs b/aeron-rs/src/command/flyweight.rs index 07cf6d8..e012a3e 100644 --- a/aeron-rs/src/command/flyweight.rs +++ b/aeron-rs/src/command/flyweight.rs @@ -14,7 +14,7 @@ where _phantom: PhantomData, } -/// Marker struct. +/// Marker struct for an uninitialized `Flyweight` object // We can't put this `new` method in the fully generic implementation because // Rust gets confused as to what type `S` should be. pub struct Unchecked; @@ -44,23 +44,19 @@ where A: AtomicBuffer, S: Sized, { - pub(crate) fn get_struct(&self) -> &S { + pub(in crate::command) fn get_struct(&self) -> &S { // UNWRAP: Bounds check performed during initialization self.buffer.overlay::(self.base_offset).unwrap() } - pub(crate) fn get_struct_mut(&mut self) -> &mut S { + pub(in crate::command) fn get_struct_mut(&mut self) -> &mut S { // UNWRAP: Bounds check performed during initialization self.buffer.overlay_mut::(self.base_offset).unwrap() } - pub(crate) fn bytes_at(&self, offset: IndexT) -> &[u8] { + pub(in crate::command) fn bytes_at(&self, offset: IndexT) -> Result<&[u8]> { let offset = (self.base_offset + offset) as usize; - // FIXME: Unwrap is unjustified here. - // C++ uses pointer arithmetic with no bounds checking, so I'm more comfortable - // with the Rust version at least panicking. Is the idea that we're safe because - // this is a crate-local (protected in C++) method? - self.buffer.bounds_check(offset as IndexT, 0).unwrap(); - &self.buffer[offset..] + self.buffer.bounds_check(offset as IndexT, 0)?; + Ok(&self.buffer[offset..]) } } diff --git a/aeron-rs/src/command/terminate_driver.rs b/aeron-rs/src/command/terminate_driver.rs index 2782e93..b1a37fd 100644 --- a/aeron-rs/src/command/terminate_driver.rs +++ b/aeron-rs/src/command/terminate_driver.rs @@ -2,7 +2,7 @@ use crate::command::correlated_message::CorrelatedMessageDefn; use crate::command::flyweight::Flyweight; use crate::concurrent::AtomicBuffer; -use crate::util::IndexT; +use crate::util::{IndexT, Result}; use std::mem::size_of; /// Raw command to terminate a driver. The `token_length` describes the length @@ -41,58 +41,48 @@ where self } - /// Get the current length of the payload associated with this termination request. + /// Return the token buffer length pub fn token_length(&self) -> i32 { self.get_struct().token_length } - /// Set the payload length of this termination request. - /// - /// NOTE: While there are no safety issues, improperly setting this value can cause panics. - /// The `token_length` value is automatically set during calls to `put_token_buffer()`, - /// so this method is not likely to be frequently used. - pub fn put_token_length(&mut self, value: i32) -> &mut Self { - self.get_struct_mut().token_length = value; - self - } - /// Return the current token payload associated with this termination request. pub fn token_buffer(&self) -> &[u8] { - // QUESTION: Should I be slicing the buffer to `token_length`? - // C++ doesn't do anything, so I'm going to assume not. - &self.bytes_at(size_of::() as IndexT) + // UNWRAP: Size check performed during initialization + &self + .bytes_at(size_of::() as IndexT) + .unwrap()[..self.get_struct().token_length as usize] } /// Append a payload to the termination request. - pub fn put_token_buffer(&mut self, token_buffer: &[u8]) -> &mut Self { + pub fn put_token_buffer(&mut self, token_buffer: &[u8]) -> Result<&mut Self> { let token_length = token_buffer.len() as i32; - self.get_struct_mut().token_length = token_length; - if token_length > 0 { - // FIXME: Unwrap is unjustified here - // Currently just assume that people are going to be nice about the token buffer - // and not oversize it. C++ relies on throwing an exception if bounds are violated. - self.buffer - .put_slice( - size_of::() as IndexT, - &token_buffer, - 0, - token_length, - ) - .unwrap() + self.buffer.put_slice( + size_of::() as IndexT, + &token_buffer, + 0, + token_length, + )? } - self + self.get_struct_mut().token_length = token_length; + Ok(self) } /// Get the total byte length of this termination command pub fn length(&self) -> IndexT { - size_of::() as IndexT + self.token_length() + size_of::() as IndexT + self.get_struct().token_length } } #[cfg(test)] mod tests { + use crate::command::correlated_message::CorrelatedMessageDefn; + use crate::command::flyweight::Flyweight; use crate::command::terminate_driver::TerminateDriverDefn; + use crate::concurrent::AtomicBuffer; + use crate::util::IndexT; + use std::mem::size_of; #[test] @@ -102,4 +92,34 @@ mod tests { size_of::() ) } + + #[test] + #[should_panic] + fn panic_on_invalid_length() { + // QUESTION: Should this failure condition be included in the docs? + let token_len = 1; + + // Can trigger panic if `token_length` contains a bad value during initialization + let mut bytes = &mut [0u8; size_of::()][..]; + // `token_length` stored immediately following the correlated message, this is + // how to calculate the offset + let token_length_offset = size_of::(); + + // When running inside a `should_panic` test, a failed test is one that returns at all + let put_result = bytes.put_i32(token_length_offset as IndexT, token_len); + if put_result.is_err() { + return; + } + + let flyweight = Flyweight::new::(bytes, 0); + if flyweight.is_err() { + return; + } + + let flyweight = flyweight.unwrap(); + if flyweight.token_length() != token_len { + return; + } + flyweight.token_buffer(); + } } diff --git a/aeron-rs/src/driver_proxy.rs b/aeron-rs/src/driver_proxy.rs index 3d3fdbf..b3d5e98 100644 --- a/aeron-rs/src/driver_proxy.rs +++ b/aeron-rs/src/driver_proxy.rs @@ -5,6 +5,7 @@ use crate::concurrent::ringbuffer::ManyToOneRingBuffer; use crate::concurrent::AtomicBuffer; use crate::control_protocol::ClientCommand; use crate::util::{AeronError, IndexT, Result}; +use std::mem::size_of; /// High-level interface for issuing commands to a media driver pub struct DriverProxy @@ -15,6 +16,8 @@ where client_id: i64, } +const COMMAND_BUFFER_SIZE: usize = 512; + impl DriverProxy where A: AtomicBuffer, @@ -43,12 +46,19 @@ where /// that will be available to the driver. pub fn terminate_driver(&mut self, token_buffer: Option<&[u8]>) -> Result<()> { let client_id = self.client_id; + if token_buffer.is_some() + && token_buffer.unwrap().len() + > (COMMAND_BUFFER_SIZE - size_of::()) + { + return Err(AeronError::InsufficientCapacity); + } self.write_command_to_driver(|buffer: &mut [u8], length: &mut IndexT| { - // UNWRAP: Buffer from `write_command` guaranteed to be long enough for `TerminateDriverDefn` + // UNWRAP: `TerminateDriverDefn` guaranteed to be smaller than `COMMAND_BUFFER_SIZE` let mut request = Flyweight::new::(buffer, 0).unwrap(); request.put_client_id(client_id).put_correlation_id(-1); - token_buffer.map(|b| request.put_token_buffer(b)); + // UNWRAP: Bounds check performed prior to attempting the write + token_buffer.map(|b| request.put_token_buffer(b).unwrap()); *length = request.length(); ClientCommand::TerminateDriver @@ -61,7 +71,7 @@ where { // QUESTION: Can Rust align structs on stack? // C++ does some fancy shenanigans I assume help the CPU cache? - let mut buffer = &mut [0u8; 512][..]; + let mut buffer = &mut [0u8; COMMAND_BUFFER_SIZE][..]; let mut length = buffer.len() as IndexT; let msg_type_id = filler(&mut buffer, &mut length); diff --git a/aeron-rs/src/lib.rs b/aeron-rs/src/lib.rs index 63da9a0..7b59652 100644 --- a/aeron-rs/src/lib.rs +++ b/aeron-rs/src/lib.rs @@ -9,8 +9,8 @@ pub mod command; pub mod concurrent; pub mod context; pub mod control_protocol; -pub mod media_driver; pub mod driver_proxy; +pub mod media_driver; pub mod util; const fn sematic_version_compose(major: u8, minor: u8, patch: u8) -> i32 {