From d1ff31b084c500aa4e76dbfe8d78f460ecfd0ecd Mon Sep 17 00:00:00 2001 From: Lee Smet Date: Wed, 15 Nov 2023 15:46:50 +0100 Subject: [PATCH] Clsoe #62: Make topic part of init packet Signed-off-by: Lee Smet --- docs/api.yaml | 19 ++++- src/api.rs | 99 ++++++++++++++++++------ src/main.rs | 179 +++++++++++++++++++------------------------- src/message.rs | 89 ++++++++++++++++------ src/message/init.rs | 23 +++++- 5 files changed, 259 insertions(+), 150 deletions(-) diff --git a/docs/api.yaml b/docs/api.yaml index 6e847c8..a2f038c 100644 --- a/docs/api.yaml +++ b/docs/api.yaml @@ -56,12 +56,13 @@ paths: a message if present, or return immediately if there isn't example: 60 - in: query - name: filter + name: topic required: false schema: type: string + format: byte minLength: 0 - maxLength: 255 + maxLength: 340 description: | Optional filter for loading messages. If set, the system checks if the message has the given string at the start. This way a topic can be encoded. @@ -217,6 +218,13 @@ components: minLength: 64 maxLength: 64 example: 02468ace13579bdf02468ace13579bdf02468ace13579bdf02468ace13579bdf + topic: + description: An optional message topic + type: string + format: byte + minLength: 0 + maxLength: 340 + example: hpV+ payload: description: The message payload, encoded in standard alphabet base64 type: string @@ -229,6 +237,13 @@ components: properties: dst: $ref: '#/components/schemas/MessageDestination' + topic: + description: An optional message topic + type: string + format: byte + minLength: 0 + maxLength: 340 + example: hpV+ payload: description: The message to send, base64 encoded type: string diff --git a/src/api.rs b/src/api.rs index be7d8a8..b64b955 100644 --- a/src/api.rs +++ b/src/api.rs @@ -39,7 +39,11 @@ struct HttpServerState { #[serde(rename_all = "camelCase")] pub struct MessageSendInfo { pub dst: MessageDestination, - #[serde(with = "base64")] + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "base64::optional_binary")] + pub topic: Option>, + #[serde(with = "base64::binary")] pub payload: Vec, } @@ -58,7 +62,11 @@ pub struct MessageReceiveInfo { pub src_pk: PublicKey, pub dst_ip: IpAddr, pub dst_pk: PublicKey, - #[serde(with = "base64")] + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "base64::optional_binary")] + pub topic: Option>, + #[serde(with = "base64::binary")] pub payload: Vec, } @@ -102,8 +110,11 @@ impl Http { struct GetMessageQuery { peek: Option, timeout: Option, - /// Optional filter for start of the message. - filter: Option, + /// Optional filter for start of the message, base64 encoded. + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "base64::optional_binary")] + topic: Option>, } impl GetMessageQuery { @@ -127,14 +138,13 @@ async fn get_message( query.peek(), query.timeout_secs() ); + // A timeout of 0 seconds essentially means get a message if there is one, and return // immediatly if there isn't. This is the result of the implementation of Timeout, which does a // poll of the internal future first, before polling the delay. tokio::time::timeout( Duration::from_secs(query.timeout_secs()), - state - .message_stack - .message(!query.peek(), query.filter.map(String::into_bytes)), + state.message_stack.message(!query.peek(), query.topic), ) .await .or(Err(StatusCode::NO_CONTENT)) @@ -145,6 +155,11 @@ async fn get_message( src_pk: m.src_pk, dst_ip: m.dst_ip, dst_pk: m.dst_pk, + topic: if m.topic.is_empty() { + None + } else { + Some(m.topic) + }, payload: m.data, }) }) @@ -160,8 +175,8 @@ pub struct MessageIdReply { #[serde(rename_all = "camelCase")] #[serde(untagged)] pub enum PushMessageResponse { - Id(MessageIdReply), Reply(MessageReceiveInfo), + Id(MessageIdReply), } #[derive(Deserialize)] @@ -192,12 +207,22 @@ async fn push_message( message_info.payload.len(), ); - let (id, sub) = state.message_stack.new_message( + let (id, sub) = match state.message_stack.new_message( dst, message_info.payload, + if let Some(topic) = message_info.topic { + topic + } else { + vec![] + }, DEFAULT_MESSAGE_TRY_DURATION, query.await_reply(), - ); + ) { + Ok((id, sub)) => (id, sub), + Err(_) => { + return Err(StatusCode::BAD_REQUEST); + } + }; if !query.await_reply() { // If we don't wait for the reply just return here. @@ -219,6 +244,7 @@ async fn push_message( src_pk: m.src_pk, dst_ip: m.dst_ip, dst_pk: m.dst_pk, + topic: if m.topic.is_empty() { None } else { Some(m.topic.clone()) }, payload: m.data.clone(), })))) } else { @@ -277,24 +303,55 @@ async fn message_status( mod base64 { use base64::alphabet; use base64::engine::{GeneralPurpose, GeneralPurposeConfig}; - use base64::Engine; - use serde::{Deserialize, Serialize}; - use serde::{Deserializer, Serializer}; const B64ENGINE: GeneralPurpose = base64::engine::general_purpose::GeneralPurpose::new( &alphabet::STANDARD, GeneralPurposeConfig::new(), ); - pub fn serialize(v: &Vec, s: S) -> Result { - let base64 = B64ENGINE.encode(v); - String::serialize(&base64, s) + pub mod binary { + use super::B64ENGINE; + use base64::Engine; + use serde::{Deserialize, Serialize}; + use serde::{Deserializer, Serializer}; + + pub fn serialize(v: &Vec, s: S) -> Result { + let base64 = B64ENGINE.encode(v); + String::serialize(&base64, s) + } + + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + let base64 = String::deserialize(d)?; + B64ENGINE + .decode(base64.as_bytes()) + .map_err(serde::de::Error::custom) + } } - pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { - let base64 = String::deserialize(d)?; - B64ENGINE - .decode(base64.as_bytes()) - .map_err(serde::de::Error::custom) + pub mod optional_binary { + use super::B64ENGINE; + use base64::Engine; + use serde::{Deserialize, Serialize}; + use serde::{Deserializer, Serializer}; + + pub fn serialize(v: &Option>, s: S) -> Result { + if let Some(v) = v { + let base64 = B64ENGINE.encode(v); + String::serialize(&base64, s) + } else { + >::serialize(&None, s) + } + } + + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result>, D::Error> { + if let Some(base64) = >::deserialize(d)? { + B64ENGINE + .decode(base64.as_bytes()) + .map_err(serde::de::Error::custom) + .map(Option::Some) + } else { + Ok(None) + } + } } } diff --git a/src/main.rs b/src/main.rs index f0d5ce0..2a46237 100644 --- a/src/main.rs +++ b/src/main.rs @@ -385,11 +385,14 @@ enum Payload { #[serde(rename_all = "camelCase")] struct CliMessage { id: MessageId, - topic: Option, src_ip: IpAddr, src_pk: PublicKey, dst_ip: IpAddr, dst_pk: PublicKey, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(serialize_with = "serialize_payload")] + topic: Option, + #[serde(skip_serializing_if = "Option::is_none")] #[serde(serialize_with = "serialize_payload")] payload: Option, } @@ -407,16 +410,9 @@ fn serialize_payload(p: &Option, s: S) -> Result>::serialize(&base64, s) } -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct ReadableCliMessage { - id: MessageId, - topic: Option, - src_ip: IpAddr, - src_pk: PublicKey, - dst_ip: IpAddr, - dst_pk: PublicKey, - payload: String, +/// Encode arbitrary data in standard base64. +pub fn encode_base64(input: &[u8]) -> String { + B64ENGINE.encode(input) } /// Send a message to a receiver. @@ -478,29 +474,8 @@ async fn send_msg( } }; - // Load msg, files have prio. If a topic is present, include that first - // The layout of these messages (in binary) is: - // - 1 byte topic length - // - topic - // - actual message - // - // Meaning a message without topic has length 0. - let mut msg_buf = if let Some(topic) = topic { - if topic.len() > 255 { - error!("{topic} is longer than the maximum allowed topic length of 255"); - return Err( - std::io::Error::new(std::io::ErrorKind::InvalidInput, "Topic too long").into(), - ); - } - let mut tmp = Vec::with_capacity(topic.len() + 1); - tmp.push(topic.len() as u8); - tmp.extend_from_slice(topic.as_bytes()); - tmp - } else { - vec![0; 1] - }; - - msg_buf.extend_from_slice(&if let Some(path) = msg_path { + // Load msg, files have prio. + let msg = if let Some(path) = msg_path { match tokio::fs::read(&path).await { Err(e) => { error!("Could not read file at {:?}: {e}", path); @@ -517,7 +492,7 @@ async fn send_msg( "Message is a required argument if `--msg-path` is not provided", ) .into()); - }); + }; let mut url = format!("http://{server_addr}/api/v1/messages"); if let Some(reply_to) = reply_to { @@ -526,15 +501,15 @@ async fn send_msg( if wait { // A year should be sufficient to wait let reply_timeout = timeout.unwrap_or(60 * 60 * 24 * 365); - url.push_str("?reply_timeout="); - url.push_str(&format!("{reply_timeout}")); + url.push_str(&format!("?reply_timeout={reply_timeout}")); } match reqwest::Client::new() .post(url) .json(&MessageSendInfo { dst: destination, - payload: msg_buf, + topic: topic.map(String::into_bytes), + payload: msg, }) .send() .await @@ -543,55 +518,57 @@ async fn send_msg( error!("Failed to send request: {e}"); return Err(e.into()); } - Ok(res) => match res.json::().await { - Err(e) => { - error!("Failed to load response body {e}"); - return Err(e.into()); + Ok(res) => { + if res.status() == STATUSCODE_NO_CONTENT { + return Ok(()); } - Ok(resp) => { - match resp { - PushMessageResponse::Id(id) => { - let _ = serde_json::to_writer(std::io::stdout(), &id); - } - PushMessageResponse::Reply(mri) => { - let filter_len = mri.payload[0] as usize; - let cm = CliMessage { - id: mri.id, - topic: if filter_len == 1 { - None - } else { - Some( - String::from_utf8(mri.payload[1..filter_len].to_vec()) - .map_err(|e| { - error!("Failed to parse topic, not valid UTF-8 ({e})"); - e - })?, - ) - }, - src_ip: mri.src_ip, - src_pk: mri.src_pk, - dst_ip: mri.dst_ip, - dst_pk: mri.dst_pk, - payload: Some({ - let p = mri.payload[filter_len..].to_vec(); - if let Ok(s) = String::from_utf8(p.clone()) { - Payload::Readable(s) - } else { - Payload::NotReadable(p) - } - }), - }; - let _ = serde_json::to_writer(std::io::stdout(), &cm); - } + match res.json::().await { + Err(e) => { + error!("Failed to load response body {e}"); + return Err(e.into()); + } + Ok(resp) => { + match resp { + PushMessageResponse::Id(id) => { + let _ = serde_json::to_writer(std::io::stdout(), &id); + } + PushMessageResponse::Reply(mri) => { + let cm = CliMessage { + id: mri.id, + + topic: mri.topic.map(|topic| { + if let Ok(s) = String::from_utf8(topic.clone()) { + Payload::Readable(s) + } else { + Payload::NotReadable(topic) + } + }), + src_ip: mri.src_ip, + src_pk: mri.src_pk, + dst_ip: mri.dst_ip, + dst_pk: mri.dst_pk, + payload: Some({ + if let Ok(s) = String::from_utf8(mri.payload.clone()) { + Payload::Readable(s) + } else { + Payload::NotReadable(mri.payload) + } + }), + }; + let _ = serde_json::to_writer(std::io::stdout(), &cm); + } + } + println!(); } - println!(); } - }, + } } Ok(()) } +const STATUSCODE_NO_CONTENT: u16 = 204; + async fn recv_msg( timeout: Option, topic: Option, @@ -601,25 +578,27 @@ async fn recv_msg( ) -> Result<(), Box> { // One year timeout should be sufficient let timeout = timeout.unwrap_or(60 * 60 * 24 * 365); - let mut url = format!("http:{server_addr}/api/v1/messages?timeout={timeout}"); - let filter_len = if let Some(ref filter) = topic { - if filter.len() > 255 { - error!("{filter} is longer than the maximum allowed topic length of 255"); + let mut url = format!("http://{server_addr}/api/v1/messages?timeout={timeout}"); + if let Some(ref topic) = topic { + if topic.len() > 255 { + error!("{topic} is longer than the maximum allowed topic length of 255"); return Err( std::io::Error::new(std::io::ErrorKind::InvalidInput, "Topic too long").into(), ); } - url.push_str(&format!("&filter={filter}")); - filter.len() + 1 - } else { - 1 - }; + url.push_str(&format!("&topic={}", encode_base64(topic.as_bytes()))); + } let mut cm = match reqwest::get(url).await { Err(e) => { error!("Failed to wait for message: {e}"); return Err(e.into()); } Ok(resp) => { + if resp.status() == STATUSCODE_NO_CONTENT { + debug!("No message ready yet"); + return Ok(()); + } + debug!("Received message response"); match resp.json::().await { Err(e) => { @@ -628,28 +607,22 @@ async fn recv_msg( } Ok(mri) => CliMessage { id: mri.id, - topic: if filter_len == 1 { - None - } else { - Some( - String::from_utf8(mri.payload[1..filter_len].to_vec()).map_err( - |e| { - error!("Failed to parse topic, not valid UTF-8 ({e})"); - e - }, - )?, - ) - }, + topic: mri.topic.map(|topic| { + if let Ok(s) = String::from_utf8(topic.clone()) { + Payload::Readable(s) + } else { + Payload::NotReadable(topic) + } + }), src_ip: mri.src_ip, src_pk: mri.src_pk, dst_ip: mri.dst_ip, dst_pk: mri.dst_pk, payload: Some({ - let p = mri.payload[filter_len..].to_vec(); - if let Ok(s) = String::from_utf8(p.clone()) { + if let Ok(s) = String::from_utf8(mri.payload.clone()) { Payload::Readable(s) } else { - Payload::NotReadable(p) + Payload::NotReadable(mri.payload) } }), }, diff --git a/src/message.rs b/src/message.rs index 8045b07..41ef03b 100644 --- a/src/message.rs +++ b/src/message.rs @@ -78,8 +78,12 @@ const FLAG_MESSAGE_ACK: u16 = 0b0000_0001_0000_0000; /// Length of a message checksum in bytes. const MESSAGE_CHECKSUM_LENGTH: usize = 32; +/// Checksum of a message used to verify received message integrity. pub type Checksum = [u8; MESSAGE_CHECKSUM_LENGTH]; +/// Response type when pushing a message. +pub type MessagePushResponse = (MessageId, Option>>); + #[derive(Clone)] pub struct MessageStack { // The DataPlane is wrappen in a Mutex since it does not implement Sync. @@ -116,6 +120,8 @@ struct ReceivedMessageInfo { dst: IpAddr, /// Length of the finished message. len: u64, + /// Optional topic of the message. + topic: Vec, chunks: Vec>, } @@ -133,6 +139,8 @@ pub struct ReceivedMessage { pub dst_ip: IpAddr, /// The public key of the receiver of the message. This is always ours. pub dst_pk: PublicKey, + /// The possible topic of the message. + pub topic: Vec, /// Actual message. pub data: Vec, } @@ -181,6 +189,12 @@ enum TransmissionState { Aborted, } +#[derive(Debug, Clone, Copy)] +pub enum PushMessageError { + /// The topic set in the message is too large. + TopicTooLarge, +} + impl MessageInbox { fn new(notify: watch::Sender<()>) -> Self { Self { @@ -374,6 +388,7 @@ impl MessageStack { src, dst, len: mi.length(), + topic: mi.topic().into(), chunks, }; @@ -468,6 +483,7 @@ impl MessageStack { id: inbound_message.id, src: inbound_message.src, dst: inbound_message.dst, + topic: inbound_message.topic.clone(), data: message_data, }; @@ -500,6 +516,7 @@ impl MessageStack { src_pk: src_pubkey, dst_ip: message.dst, dst_pk: dst_pubkey, + topic: message.topic, data: message.data, }; @@ -578,10 +595,11 @@ impl MessageStack { &self, dst: IpAddr, data: Vec, + topic: Vec, try_duration: Duration, subscribe_reply: bool, - ) -> (MessageId, Option>>) { - self.push_message(None, dst, data, try_duration, subscribe_reply) + ) -> Result { + self.push_message(None, dst, data, topic, try_duration, subscribe_reply) } /// Push a new message which is a reply to the message with [the provided id](MessageId). @@ -592,7 +610,8 @@ impl MessageStack { data: Vec, try_duration: Duration, ) -> MessageId { - self.push_message(Some(reply_to), dst, data, try_duration, false) + self.push_message(Some(reply_to), dst, data, vec![], try_duration, false) + .expect("Empty topic is never too large") .0 } @@ -616,9 +635,14 @@ impl MessageStack { id: Option, dst: IpAddr, data: Vec, + topic: Vec, try_duration: Duration, subscribe: bool, - ) -> (MessageId, Option>>) { + ) -> Result { + if topic.len() > 255 { + return Err(PushMessageError::TopicTooLarge); + } + let src = self .data_plane .lock() @@ -635,7 +659,13 @@ impl MessageStack { }; let len = data.len(); - let msg = Message { id, src, dst, data }; + let msg = Message { + id, + src, + dst, + topic, + data, + }; let created = std::time::SystemTime::now(); let deadline = created + try_duration; @@ -655,12 +685,7 @@ impl MessageStack { None }; - self.outbox - .lock() - .expect("Outbox lock isn't poisoned; qed") - .insert(obmi); - - // Already send the init packet. + // Already prepare the init packet for sending.. let mut mp = MessagePacket::new(PacketBuffer::new()); mp.header_mut().set_message_id(id); if reply { @@ -669,6 +694,14 @@ impl MessageStack { let mut mi = MessageInit::new(mp); mi.set_length(len as u64); + mi.set_topic(&obmi.msg.topic); + + self.outbox + .lock() + .expect("Outbox lock isn't poisoned; qed") + .insert(obmi); + + // Actually send the init packet match (src, dst) { (IpAddr::V6(src), IpAddr::V6(dst)) => { self.data_plane.lock().unwrap().inject_message_packet( @@ -711,6 +744,7 @@ impl MessageStack { let mut mi = MessageInit::new(mp); mi.set_length(len as u64); + mi.set_topic(&msg.msg.topic); match (msg.msg.src, msg.msg.dst) { (IpAddr::V6(src), IpAddr::V6(dst)) => { message_stack @@ -898,7 +932,7 @@ impl MessageStack { } }); - (id, subscription) + Ok((id, subscription)) } /// Get information about the status of an outbound message. @@ -948,7 +982,7 @@ impl MessageStack { /// /// If pop is false, the message is not removed and the next call of this method will return /// the same message. - pub async fn message(&self, pop: bool, filter: Option>) -> ReceivedMessage { + pub async fn message(&self, pop: bool, topic: Option>) -> ReceivedMessage { // Copy the subscriber since we need mutable access to it. let mut subscriber = self.subscriber.clone(); @@ -958,15 +992,12 @@ impl MessageStack { 'check: { let mut inbox = self.inbox.lock().unwrap(); // If a filter is set only check for those messages. - if let Some(ref filter) = filter { - if let Some((idx, _)) = - inbox.complete_msges.iter().enumerate().find(|(_, v)| { - if v.data.len() < filter.len() + 1 { - return false; - } - v.data[0] == filter.len() as u8 - && v.data[1..filter.len() + 1] == filter[..] - }) + if let Some(ref topic) = topic { + if let Some((idx, _)) = inbox + .complete_msges + .iter() + .enumerate() + .find(|(_, v)| &v.topic == topic) { return inbox.complete_msges.remove(idx).unwrap(); } else { @@ -1381,7 +1412,9 @@ pub struct Message { src: IpAddr, /// Destination IP dst: IpAddr, - /// Data + /// An optional topic of the message, usefull to differentiate messages before reading. + topic: Vec, + /// Data of the message data: Vec, } @@ -1412,6 +1445,16 @@ impl Message { } } +impl fmt::Display for PushMessageError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::TopicTooLarge => f.write_str("topic too large, topic is limitted to 255 bytes"), + } + } +} + +impl std::error::Error for PushMessageError {} + #[cfg(test)] mod tests { diff --git a/src/message/init.rs b/src/message/init.rs index 609a82b..39f718c 100644 --- a/src/message/init.rs +++ b/src/message/init.rs @@ -11,7 +11,7 @@ pub struct MessageInit { impl MessageInit { /// Create a new `MessageInit` in the provided [`MessagePacket`]. pub fn new(mut buffer: MessagePacket) -> Self { - buffer.set_used_buffer_size(8); + buffer.set_used_buffer_size(9); buffer.header_mut().flags_mut().set_init(); Self { buffer } } @@ -25,11 +25,32 @@ impl MessageInit { ) } + /// Return the topic of the message, as written in the body. + pub fn topic(&self) -> &[u8] { + let topic_len = self.buffer.buffer()[8] as usize; + &self.buffer.buffer()[9..9 + topic_len] + } + /// Set the length field of the message body. pub fn set_length(&mut self, length: u64) { self.buffer.buffer_mut()[..8].copy_from_slice(&length.to_be_bytes()) } + /// Set the topic in the message body. + /// + /// # Panics + /// + /// This function panics if the topic is longer than 255 bytes. + pub fn set_topic(&mut self, topic: &[u8]) { + assert!( + topic.len() <= u8::MAX as usize, + "Topic can be 255 bytes long at most" + ); + self.buffer.set_used_buffer_size(9 + topic.len()); + self.buffer.buffer_mut()[8] = topic.len() as u8; + self.buffer.buffer_mut()[9..9 + topic.len()].copy_from_slice(topic); + } + /// Convert the `MessageInit` into a reply. This does nothing if it is already a reply. pub fn into_reply(mut self) -> Self { self.buffer.header_mut().flags_mut().set_ack();