From f151d86604fddc22b2d72389c633172d94b2260a Mon Sep 17 00:00:00 2001 From: Bradlee Speice Date: Fri, 6 Sep 2019 20:56:57 -0400 Subject: [PATCH] Use traits for reader/writer Should make it easier to write analysis code --- src/capnp_runner.rs | 88 ++++++++++++++++++++++----------------- src/flatbuffers_runner.rs | 35 +++++----------- src/main.rs | 35 ++++++++++++++-- src/marketdata_sbe.rs | 3 ++ src/sbe_runner.rs | 25 ++++------- 5 files changed, 104 insertions(+), 82 deletions(-) diff --git a/src/capnp_runner.rs b/src/capnp_runner.rs index 167671f..da6992c 100644 --- a/src/capnp_runner.rs +++ b/src/capnp_runner.rs @@ -10,22 +10,11 @@ use capnp::serialize_packed::{read_message as read_message_packed, write_message use nom::bytes::complete::take_until; use nom::IResult; -use crate::{StreamVec, Summarizer}; +use crate::{RunnerDeserialize, RunnerSerialize, StreamVec, Summarizer}; use crate::iex::{IexMessage, IexPayload}; use crate::marketdata_capnp::{multi_message, Side}; use crate::marketdata_capnp::message; -fn __take_until<'a>(tag: &'static str, input: &'a [u8]) -> IResult<&'a [u8], &'a [u8]> { - take_until(tag)(input) -} - -fn parse_symbol(sym: &[u8; 8]) -> &str { - // TODO: Use the `jetscii` library for all that SIMD goodness - // IEX guarantees ASCII, so we're fine using an unsafe conversion - let (_, sym_bytes) = __take_until(" ", &sym[..]).unwrap(); - unsafe { from_utf8_unchecked(sym_bytes) } -} - pub struct CapnpWriter<'a> { // We have to be very careful with how messages are built, as running // `init_root` and rebuilding will still accumulate garbage if using @@ -33,10 +22,11 @@ pub struct CapnpWriter<'a> { // https://github.com/capnproto/capnproto-rust/issues/111 words: Vec, scratch: ScratchSpace<'a>, + packed: bool, } impl<'a> CapnpWriter<'a> { - pub fn new() -> CapnpWriter<'a> { + pub fn new(packed: bool) -> CapnpWriter<'a> { // Cap'n'Proto words are 8 bytes, MTU is 1500 bytes, theoretically need only 188 words. // In practice, let's just make sure everything fits. let mut words = capnp::Word::allocate_zeroed_vec(1024); @@ -48,6 +38,7 @@ impl<'a> CapnpWriter<'a> { CapnpWriter { words, scratch, + packed } } @@ -60,8 +51,10 @@ impl<'a> CapnpWriter<'a> { std::mem::transmute(&mut self.scratch) })) } +} - pub fn serialize(&mut self, payload: &IexPayload, mut output: &mut Vec, packed: bool) { +impl<'a> RunnerSerialize for CapnpWriter<'a> { + fn serialize(&mut self, payload: &IexPayload, mut output: &mut Vec) { // First, count the messages we actually care about. let num_msgs = payload.messages.iter().map(|m| { match m { @@ -76,14 +69,14 @@ impl<'a> CapnpWriter<'a> { // And actually serialize the IEX payload to CapNProto format + // This is the safe builder used for testing + //let mut builder = capnp::message::Builder::new_default(); + //let mut multimsg = builder.init_root::(); + // This is the unsafe (but faster) version let mut builder = self.builder(); let mut multimsg = builder.init_root::(); - // And the safe version used for testing - //let mut builder = capnp::message::Builder::new_default(); - //let mut multimsg = builder.init_root::(); - multimsg.set_seq_no(payload.first_seq_no); let mut messages = multimsg.init_messages(num_msgs as u32); @@ -95,7 +88,7 @@ impl<'a> CapnpWriter<'a> { current_msg_no += 1; message.set_ts(tr.timestamp); - let sym = parse_symbol(&tr.symbol); + let sym = crate::parse_symbol(&tr.symbol); message.reborrow().init_symbol(sym.len() as u32); message.set_symbol(sym); @@ -108,7 +101,7 @@ impl<'a> CapnpWriter<'a> { current_msg_no += 1; message.set_ts(plu.timestamp); - let sym = parse_symbol(&plu.symbol); + let sym = crate::parse_symbol(&plu.symbol); message.reborrow().init_symbol(sym.len() as u32); message.set_symbol(sym); @@ -122,28 +115,33 @@ impl<'a> CapnpWriter<'a> { } } - let write_fn = if packed { write_message_packed } else { write_message }; + let write_fn = if self.packed { write_message_packed } else { write_message }; write_fn(&mut output, &builder).unwrap(); } } pub struct CapnpReader { - read_opts: ReaderOptions + read_opts: ReaderOptions, + packed: bool } impl CapnpReader { - pub fn new() -> CapnpReader { + pub fn new(packed: bool) -> CapnpReader { CapnpReader { - read_opts: ReaderOptions::new() + read_opts: ReaderOptions::new(), + packed } } +} - pub fn deserialize_packed(&self, buf: &mut StreamVec, stats: &mut Summarizer) -> Result<(), Error> { +impl CapnpReader { + fn deserialize_packed<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { // Because `capnp::serialize_packed::PackedRead` is hidden from us, packed reads // *have* to both allocate new segments every read, and copy the buffer into // those same segments, no ability to re-use allocated memory. - let reader = read_message_packed(buf, self.read_opts)?; + let reader = read_message_packed(buf, self.read_opts) + .map_err(|_| ())?; let multimsg = reader.get_root::().unwrap(); for msg in multimsg.get_messages().unwrap().iter() { @@ -166,10 +164,10 @@ impl CapnpReader { Ok(()) } - pub fn deserialize_unpacked(&self, buf: &mut StreamVec, stats: &mut Summarizer) -> Result<(), Error> { - let mut data = buf.fill_buf()?; + fn deserialize_unpacked(&self, buf: &mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { + let mut data = buf.fill_buf().map_err(|_| ())?; if data.len() == 0 { - return Err(capnp::Error::failed(String::new())); + return Err(()); } let orig_data = data; @@ -179,7 +177,7 @@ impl CapnpReader { Read into `OwnedSegments`, which means we copy the entire message into a new Vec. Note that the `data` pointer is modified underneath us, can figure out the message length by checking the difference between where we started and what `data` is afterward. - This is a trick you learn only by checking the fuzzing test cases. + This is a trick you learn only by looking at the fuzzing test cases. let reader = capnp::serialize::read_message(&mut data, reader_opts)?; let bytes_consumed = orig_data.len() - data.len(); @@ -190,11 +188,12 @@ impl CapnpReader { but still forces a Vec allocation for `offsets`. Also requires us to copy code from Cap'n'Proto because `SliceSegments` has private fields, and `read_segment_table` is private. And all this because `read_segment_from_words` has a length check - that triggers an error if our buffer is too large. - There is no other documentation on how to calculate `bytes_consumed` in this case + that triggers an error if our buffer is too large. What the hell? + There is no documentation on how to calculate `bytes_consumed` when parsing by hand that I could find, you just have to guess and check until you figure this one out. */ - let (num_words, offsets) = read_segment_table(&mut data, reader_opts)?; + let (num_words, offsets) = read_segment_table(&mut data, reader_opts) + .map_err(|_| ())?; let words = unsafe { capnp::Word::bytes_to_words(data) }; let reader = capnp::message::Reader::new( SliceSegments { @@ -207,18 +206,19 @@ impl CapnpReader { let msg_bytes = num_words * size_of::(); let bytes_consumed = segment_table_bytes + msg_bytes; - let multimsg = reader.get_root::()?; - for msg in multimsg.get_messages()?.iter() { - let sym = msg.get_symbol()?; + let multimsg = reader.get_root::() + .map_err(|_| ())?; + for msg in multimsg.get_messages().map_err(|_| ())?.iter() { + let sym = msg.get_symbol().map_err(|_| ())?; - match msg.which()? { + match msg.which().map_err(|_| ())? { message::Trade(trade) => { let trade = trade.unwrap(); stats.append_trade_volume(sym, trade.get_size().into()); }, message::Quote(quote) => { let quote = quote.unwrap(); - let is_buy = match quote.get_side()? { + let is_buy = match quote.get_side().unwrap() { Side::Buy => true, _ => false }; @@ -232,6 +232,18 @@ impl CapnpReader { } } +impl RunnerDeserialize for CapnpReader { + fn deserialize<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { + // While this is an extra branch per call, we're going to assume that the overhead + // is essentially nil in practice + if self.packed { + self.deserialize_packed(buf, stats) + } else { + self.deserialize_unpacked(buf, stats) + } + } +} + pub struct SliceSegments<'a> { words: &'a [capnp::Word], diff --git a/src/flatbuffers_runner.rs b/src/flatbuffers_runner.rs index d7de476..74d3561 100644 --- a/src/flatbuffers_runner.rs +++ b/src/flatbuffers_runner.rs @@ -7,21 +7,10 @@ use capnp::data::new_builder; use flatbuffers::buffer_has_identifier; use nom::{bytes::complete::take_until, IResult}; -use crate::{StreamVec, Summarizer}; +use crate::{RunnerDeserialize, RunnerSerialize, StreamVec, Summarizer}; use crate::iex::{IexMessage, IexPayload}; use crate::marketdata_generated::md_shootout; -fn __take_until<'a>(tag: &'static str, input: &'a [u8]) -> IResult<&'a [u8], &'a [u8]> { - take_until(tag)(input) -} - -fn parse_symbol(sym: &[u8; 8]) -> &str { - // TODO: Use the `jetscii` library for all that SIMD goodness - // IEX guarantees ASCII, so we're fine using an unsafe conversion - let (_, sym_bytes) = __take_until(" ", &sym[..]).unwrap(); - unsafe { from_utf8_unchecked(sym_bytes) } -} - pub struct FlatbuffersWriter<'a> { builder: flatbuffers::FlatBufferBuilder<'a>, message_buffer: Vec>>, @@ -34,8 +23,10 @@ impl<'a> FlatbuffersWriter<'a> { message_buffer: Vec::new(), } } +} - pub fn serialize(&mut self, payload: &IexPayload, output: &mut Vec) { +impl<'a> RunnerSerialize for FlatbuffersWriter<'a> { + fn serialize(&mut self, payload: &IexPayload, output: &mut Vec) { // Because FlatBuffers can't handle nested vectors (specifically, we can't track // both the variable-length vector of messages, and the variable-length strings @@ -46,8 +37,7 @@ impl<'a> FlatbuffersWriter<'a> { let msg_args = match iex_msg { IexMessage::TradeReport(tr) => { // The `Args` objects used are wrappers over an underlying `Builder`. - // We trust release builds to optimize out the wrapper, but would be - // interesting to know whether that's actually the case. + // We trust release builds to optimize out the wrapper. let trade = md_shootout::Trade::create( &mut self.builder, &md_shootout::TradeArgs { @@ -56,18 +46,11 @@ impl<'a> FlatbuffersWriter<'a> { }, ); - /* - let mut trade_builder = md_shootout::TradeBuilder::new(self.builder); - trade_builder.add_price(tr.price); - trade_builder.add_size_(tr.size); - let trade = trade_builder.finish(); - */ - let sym_str = self.builder.create_string(parse_symbol(&tr.symbol)); + let sym_str = self.builder.create_string(crate::parse_symbol(&tr.symbol)); Some(md_shootout::MessageArgs { ts_nanos: tr.timestamp, symbol: Some(sym_str), body_type: md_shootout::MessageBody::Trade, - // Why the hell do I need the `as_union_value` function to convert to UnionWIPOffset??? body: Some(trade.as_union_value()), }) } @@ -82,7 +65,7 @@ impl<'a> FlatbuffersWriter<'a> { }, ); - let sym_str = self.builder.create_string(parse_symbol(&plu.symbol)); + let sym_str = self.builder.create_string(crate::parse_symbol(&plu.symbol)); Some(md_shootout::MessageArgs { ts_nanos: plu.timestamp, symbol: Some(sym_str), @@ -124,8 +107,10 @@ impl FlatbuffersReader { pub fn new() -> FlatbuffersReader { FlatbuffersReader {} } +} - pub fn deserialize<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { +impl RunnerDeserialize for FlatbuffersReader { + fn deserialize<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { // Flatbuffers has kinda ad-hoc support for streaming: https://github.com/google/flatbuffers/issues/3898 // Essentially, you can write an optional `u32` value to the front of each message // (`finish_size_prefixed` above) to figure out how long that message actually is. diff --git a/src/main.rs b/src/main.rs index c2e0d6a..838fd2c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,13 @@ use std::hash::Hasher; use std::io::{BufRead, Read}; use std::io::Error; use std::path::Path; -use std::time::SystemTime; +use std::str::from_utf8_unchecked; +use std::time::{Instant, SystemTime}; use clap::{App, Arg}; +use nom::{bytes::complete::take_until, IResult}; -use crate::iex::IexParser; +use crate::iex::{IexParser, IexPayload}; // Cap'n'Proto and Flatbuffers typically ask that you generate code on the fly to match // the schemas. For purposes of auto-complete and easy browsing in the repository, @@ -68,10 +70,14 @@ fn main() { while let Ok(_) = capnp_reader.deserialize_packed(&mut read_buf, &mut summarizer) { parsed_msgs += 1; } + */ let mut fb_writer = flatbuffers_runner::FlatbuffersWriter::new(); for iex_payload in parser { + let now = Instant::now(); fb_writer.serialize(&iex_payload, &mut output_buf); + let serialize_nanos = Instant::now().duration_since(now).as_nanos(); + dbg!(serialize_nanos); } let mut read_buf = StreamVec::new(output_buf); @@ -81,13 +87,15 @@ fn main() { while let Ok(_) = fb_reader.deserialize(&mut read_buf, &mut summarizer) { parsed_msgs += 1; } - */ /* let mut capnp_writer = capnp_runner::CapnpWriter::new(); for iex_payload in parser { //let iex_payload = parser.next().unwrap(); + let now = Instant::now(); capnp_writer.serialize(&iex_payload, &mut output_buf, false); + let serialize_nanos = Instant::now().duration_since(now).as_nanos(); + dbg!(serialize_nanos); } let capnp_reader = capnp_runner::CapnpReader::new(); @@ -98,6 +106,7 @@ fn main() { } */ + /* let mut sbe_writer = sbe_runner::SBEWriter::new(); for iex_payload in parser { //let iex_payload = parser.next().unwrap(); @@ -110,6 +119,7 @@ fn main() { while let Ok(_) = sbe_reader.deserialize(&mut read_buf, &mut summarizer) { parsed_msgs += 1; } + */ dbg!(parsed_msgs); dbg!(summarizer); @@ -197,3 +207,22 @@ impl BufRead for StreamVec { self.pos += amt; } } + +trait RunnerSerialize { + fn serialize(&mut self, payload: &IexPayload, output: &mut Vec); +} + +trait RunnerDeserialize { + fn deserialize<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()>; +} + +fn __take_until<'a>(tag: &'static str, input: &'a [u8]) -> IResult<&'a [u8], &'a [u8]> { + take_until(tag)(input) +} + +fn parse_symbol(sym: &[u8; 8]) -> &str { + // TODO: Use the `jetscii` library for all that SIMD goodness + // IEX guarantees ASCII, so we're fine using an unsafe conversion + let (_, sym_bytes) = __take_until(" ", &sym[..]).unwrap(); + unsafe { from_utf8_unchecked(sym_bytes) } +} diff --git a/src/marketdata_sbe.rs b/src/marketdata_sbe.rs index 4fac90e..e6c98b7 100644 --- a/src/marketdata_sbe.rs +++ b/src/marketdata_sbe.rs @@ -404,6 +404,7 @@ pub struct MultiMessageFieldsDecoder<'d> { scratch: ScratchDecoderData<'d>, } impl<'d> MultiMessageFieldsDecoder<'d> { + pub fn wrap(scratch: ScratchDecoderData<'d>) -> MultiMessageFieldsDecoder<'d> { MultiMessageFieldsDecoder { scratch: scratch } } @@ -418,6 +419,7 @@ pub struct MultiMessageMessageHeaderDecoder<'d> { scratch: ScratchDecoderData<'d>, } impl<'d> MultiMessageMessageHeaderDecoder<'d> { + pub fn wrap(scratch: ScratchDecoderData<'d>) -> MultiMessageMessageHeaderDecoder<'d> { MultiMessageMessageHeaderDecoder { scratch: scratch } } @@ -518,6 +520,7 @@ pub struct MultiMessageFieldsEncoder<'d> { scratch: ScratchEncoderData<'d>, } impl<'d> MultiMessageFieldsEncoder<'d> { + pub fn wrap(scratch: ScratchEncoderData<'d>) -> MultiMessageFieldsEncoder<'d> { MultiMessageFieldsEncoder { scratch: scratch } } diff --git a/src/sbe_runner.rs b/src/sbe_runner.rs index 9765a81..890f763 100644 --- a/src/sbe_runner.rs +++ b/src/sbe_runner.rs @@ -4,21 +4,10 @@ use std::str::from_utf8_unchecked; use nom::bytes::complete::take_until; use nom::IResult; -use crate::{marketdata_sbe, StreamVec, Summarizer}; +use crate::{marketdata_sbe, RunnerDeserialize, RunnerSerialize, StreamVec, Summarizer}; use crate::iex::{IexMessage, IexPayload}; use crate::marketdata_sbe::{Either, MultiMessageFields, MultiMessageMessageHeader, MultiMessageMessagesMember, MultiMessageMessagesMemberEncoder, MultiMessageMessagesSymbolEncoder, Side, start_decoding_multi_message, start_encoding_multi_message}; -fn __take_until<'a>(tag: &'static str, input: &'a [u8]) -> IResult<&'a [u8], &'a [u8]> { - take_until(tag)(input) -} - -fn parse_symbol(sym: &[u8; 8]) -> &str { - // TODO: Use the `jetscii` library for all that SIMD goodness - // IEX guarantees ASCII, so we're fine using an unsafe conversion - let (_, sym_bytes) = __take_until(" ", &sym[..]).unwrap(); - unsafe { from_utf8_unchecked(sym_bytes) } -} - pub struct SBEWriter { /// Buffer to construct messages before copying. While SBE benefits /// from easily being able to create messages directly in output buffer, @@ -38,8 +27,10 @@ impl SBEWriter { default_header: MultiMessageMessageHeader::default(), } } +} - pub fn serialize(&mut self, payload: &IexPayload, output: &mut Vec) { +impl RunnerSerialize for SBEWriter { + fn serialize(&mut self, payload: &IexPayload, output: &mut Vec) { let (fields, encoder) = start_encoding_multi_message(&mut self.scratch_buffer[..]) .header_copy(&self.default_header.message_header).unwrap() .multi_message_fields().unwrap(); @@ -59,7 +50,7 @@ impl SBEWriter { ..Default::default() }; let sym_enc: MultiMessageMessagesSymbolEncoder = enc.next_messages_member(&fields).unwrap(); - sym_enc.symbol(parse_symbol(&tr.symbol).as_bytes()).unwrap() + sym_enc.symbol(crate::parse_symbol(&tr.symbol).as_bytes()).unwrap() } IexMessage::PriceLevelUpdate(plu) => { let fields = MultiMessageMessagesMember { @@ -74,7 +65,7 @@ impl SBEWriter { ..Default::default() }; let sym_enc: MultiMessageMessagesSymbolEncoder = enc.next_messages_member(&fields).unwrap(); - sym_enc.symbol(parse_symbol(&plu.symbol).as_bytes()).unwrap() + sym_enc.symbol(crate::parse_symbol(&plu.symbol).as_bytes()).unwrap() } _ => enc } @@ -93,8 +84,10 @@ impl SBEReader { pub fn new() -> SBEReader { SBEReader {} } +} - pub fn deserialize<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { +impl RunnerDeserialize for SBEReader { + fn deserialize<'a>(&self, buf: &'a mut StreamVec, stats: &mut Summarizer) -> Result<(), ()> { let data = buf.fill_buf().unwrap(); if data.len() == 0 { return Err(());