diff --git a/src/codec.rs b/src/codec.rs new file mode 100644 index 0000000..66c8ae7 --- /dev/null +++ b/src/codec.rs @@ -0,0 +1,406 @@ +use crate::packet::{ + BabelPacketBody, BabelPacketHeader, BabelTLV, BabelTLVType, ControlPacket, DataPacket, Packet, + PacketType, +}; +use bytes::{Buf, BufMut, BytesMut}; +use std::{ + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, +}; +use tokio_util::codec::{Decoder, Encoder}; + +/* ********************************PAKCET*********************************** */ +pub struct PacketCodec { + packet_type: Option, + data_packet_codec: DataPacketCodec, + control_packet_codec: ControlPacketCodec, +} + +impl PacketCodec { + pub fn new() -> Self { + PacketCodec { + packet_type: None, + data_packet_codec: DataPacketCodec::new(), + control_packet_codec: ControlPacketCodec::new(), + } + } +} + +impl Decoder for PacketCodec { + type Item = Packet; + type Error = std::io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + // Determine the packet_type + let packet_type = if let Some(packet_type) = self.packet_type { + packet_type + } else { + // Check we can read the packet type (1 byte) + if src.is_empty() { + return Ok(None); + } + + let packet_type_byte = src.get_u8(); // ! This will advance the buffer 1 byte ! + let packet_type = match packet_type_byte { + 0 => PacketType::DataPacket, + 1 => PacketType::ControlPacket, + _ => { + println!("buffer: {:?}", &src[..src.remaining()]); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid packet type", + )); + } + }; + + self.packet_type = Some(packet_type); + + packet_type + }; + + // Decode packet based on determined packet_type + match packet_type { + PacketType::DataPacket => { + match self.data_packet_codec.decode(src) { + Ok(Some(p)) => { + self.packet_type = None; // Reset state + Ok(Some(Packet::DataPacket(p))) + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } + PacketType::ControlPacket => { + match self.control_packet_codec.decode(src) { + Ok(Some(p)) => { + self.packet_type = None; // Reset state + Ok(Some(Packet::ControlPacket(p))) + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } + } + } +} + +impl Encoder for PacketCodec { + type Error = std::io::Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + match item { + Packet::DataPacket(datapacket) => { + dst.put_u8(0); + self.data_packet_codec.encode(datapacket, dst) + } + Packet::ControlPacket(controlpacket) => { + dst.put_u8(1); + self.control_packet_codec.encode(controlpacket, dst) + } + } + } +} + +/* ******************************DATA PACKET********************************* */ +pub struct DataPacketCodec { + len: Option, + dest_ip: Option, +} + +impl DataPacketCodec { + pub fn new() -> Self { + DataPacketCodec { + len: None, + dest_ip: None, + } + } +} + +impl Decoder for DataPacketCodec { + type Item = DataPacket; + type Error = std::io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + // Determine the length of the data + let data_len = if let Some(data_len) = self.len { + data_len + } else { + // Check we have enough data to decode + if src.len() < 2 { + return Ok(None); + } + + let data_len = src.get_u16(); + self.len = Some(data_len); + + data_len + } as usize; + + // Determine the destination IP + let dest_ip = if let Some(dest_ip) = self.dest_ip { + dest_ip + } else { + if src.len() < 4 { + return Ok(None); + } + + // Decode octets + let mut ip_bytes = [0u8; 4]; + ip_bytes.copy_from_slice(&src[..4]); + let dest_ip = Ipv4Addr::from(ip_bytes); + src.advance(4); + + self.dest_ip = Some(dest_ip); + dest_ip + }; + + // Check we have enough data to decode + if src.len() < data_len { + return Ok(None); + } + + // Decode octets + let mut data = vec![0u8; data_len]; + data.copy_from_slice(&src[..data_len]); + src.advance(data_len); + + // Reset state + self.len = None; + self.dest_ip = None; + + Ok(Some(DataPacket { + raw_data: data, + dest_ip, + })) + } +} + +impl Encoder for DataPacketCodec { + type Error = std::io::Error; + + fn encode(&mut self, item: DataPacket, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.reserve(item.raw_data.len() + 6); + // Write the length of the data + dst.put_u16(item.raw_data.len() as u16); + // Write the destination IP + dst.put_slice(&item.dest_ip.octets()); + // Write the data + dst.extend_from_slice(&item.raw_data); + + Ok(()) + } +} + +/* ****************************CONTROL PACKET******************************** */ +pub struct ControlPacketCodec { + header: Option, +} + +impl ControlPacketCodec { + pub fn new() -> Self { + ControlPacketCodec { header: None } + } +} + +// TODO FUTURE-WISE --> HANDLE BUFFER READS THAT MIGHT NOT HAVE ARRIVED YET +impl Decoder for ControlPacketCodec { + type Item = ControlPacket; + type Error = std::io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { + let header = if let Some(header) = self.header.take() { + header + } else { + if buf.remaining() < 4 { + return Ok(None); + } + + let magic = buf.get_u8(); + let version = buf.get_u8(); + let body_length = buf.get_u16(); + + BabelPacketHeader { + magic, + version, + body_length, + } + }; + + if buf.remaining() < header.body_length as usize { + // here the self.header is actually always None (due to take function) + // so assign it again to Some(header) + self.header = Some(header); + return Ok(None); + } + + let tlv_type = match BabelTLVType::from_u8(buf.get_u8()) { + Some(t) => t, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid TLV type", + )) + } + }; + + let length = buf.get_u8(); + + let body = match tlv_type { + BabelTLVType::Hello => { + let seqno = buf.get_u16(); + let interval = buf.get_u16(); + + BabelPacketBody { + tlv_type, + length, + tlv: BabelTLV::Hello { seqno, interval }, + } + } + BabelTLVType::IHU => { + let interval = buf.get_u16(); + let address = IpAddr::V4(Ipv4Addr::new( + buf.get_u8(), + buf.get_u8(), + buf.get_u8(), + buf.get_u8(), + )); + + BabelPacketBody { + tlv_type, + length, + tlv: BabelTLV::IHU { interval, address }, + } + } + BabelTLVType::Update => { + let ae = buf.get_u8(); + let plen = buf.get_u8(); + let interval = buf.get_u16(); + let seqno = buf.get_u16(); + let metric = buf.get_u16(); + // based on the remaining bytes (ip + router_id) we can check if it's IPv4 or v6 + let prefix = match ae { + 0 => { + // 4 bytes IP + 8 bytes router_id + IpAddr::V4(Ipv4Addr::new( + buf.get_u8(), + buf.get_u8(), + buf.get_u8(), + buf.get_u8(), + )) + } + 1 => { + // 16 bytes IP + 8 bytes router_id + IpAddr::V6(Ipv6Addr::new( + buf.get_u16(), + buf.get_u16(), + buf.get_u16(), + buf.get_u16(), + buf.get_u16(), + buf.get_u16(), + buf.get_u16(), + buf.get_u16(), + )) + } + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid address length", + )) + } + }; + let router_id = buf.get_u64(); + + BabelPacketBody { + tlv_type, + length, + tlv: BabelTLV::Update { + plen, + interval, + seqno, + metric, + prefix, + router_id, + }, + } + } + BabelTLVType::AckReq => todo!(), + BabelTLVType::Ack => todo!(), + BabelTLVType::NextHop => todo!(), + BabelTLVType::RouteReq => todo!(), + BabelTLVType::SeqnoReq => todo!(), + }; + + Ok(Some(ControlPacket { header, body })) + } +} + +impl Encoder for ControlPacketCodec { + type Error = io::Error; + + fn encode(&mut self, message: ControlPacket, buf: &mut BytesMut) -> Result<(), Self::Error> { + // Write BabelPacketHeader + buf.put_u8(message.header.magic); + buf.put_u8(message.header.version); + buf.put_u16(message.header.body_length); + + // Write BabelPacketBody + buf.put_u8(message.body.tlv_type as u8); + buf.put_u8(message.body.length); + + match message.body.tlv { + BabelTLV::Hello { seqno, interval } => { + buf.put_u16(seqno); + buf.put_u16(interval); + } + BabelTLV::IHU { interval, address } => { + buf.put_u16(interval); + match address { + IpAddr::V4(ipv4) => { + buf.put_u8(ipv4.octets()[0]); + buf.put_u8(ipv4.octets()[1]); + buf.put_u8(ipv4.octets()[2]); + buf.put_u8(ipv4.octets()[3]); + } + IpAddr::V6(_ipv6) => { + println!("IPv6 not supported yet"); + } + } + } + BabelTLV::Update { + plen, + interval, + seqno, + metric, + prefix, + router_id, + } => { + buf.put_u8(if prefix.is_ipv4() { 0 } else { 1 }); + buf.put_u8(plen); + buf.put_u16(interval); + buf.put_u16(seqno); + buf.put_u16(metric); + match prefix { + IpAddr::V4(ipv4) => { + buf.put_u8(ipv4.octets()[0]); + buf.put_u8(ipv4.octets()[1]); + buf.put_u8(ipv4.octets()[2]); + buf.put_u8(ipv4.octets()[3]); + } + IpAddr::V6(_ipv6) => { + buf.put_u16(_ipv6.segments()[0]); + buf.put_u16(_ipv6.segments()[1]); + buf.put_u16(_ipv6.segments()[2]); + buf.put_u16(_ipv6.segments()[3]); + buf.put_u16(_ipv6.segments()[4]); + buf.put_u16(_ipv6.segments()[5]); + buf.put_u16(_ipv6.segments()[6]); + buf.put_u16(_ipv6.segments()[7]); + } + } + buf.put_u64(router_id); + } // Add encoding logic for other TLV types. + } + + Ok(()) + } +} diff --git a/src/main.rs b/src/main.rs index 70dfc3e..9a2e424 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,28 +1,31 @@ -use std::{error::Error, net::Ipv4Addr}; - +use crate::packet::DataPacket; +use crate::router::StaticRoute; use bytes::BytesMut; use clap::Parser; -use etherparse::{InternetSlice, IpHeader, PacketHeaders, SlicedPacket}; -use packet_control::{DataPacket, Packet, PacketCodec}; -use tokio::{io::AsyncReadExt, net::TcpListener, sync::mpsc}; +use etherparse::{IpHeader, PacketHeaders}; +use std::{ + error::Error, + net::{Ipv4Addr, SocketAddr}, +}; +use tokio::io::AsyncBufReadExt; +mod codec; mod node_setup; -mod packet_control; +mod packet; mod peer; mod peer_manager; -mod routing; +mod router; +mod routing_table; +mod source_table; -use peer::Peer; -use peer_manager::PeerManager; -use tokio::io::AsyncWriteExt; -use tokio_util::codec::Encoder; - -const LINK_MTU: usize = 1500; +const LINK_MTU: usize = 1420; #[derive(Parser)] struct Cli { #[arg(short = 'a', long = "tun-addr")] tun_addr: Ipv4Addr, + #[arg(short = 'p', long = "peers", num_args = 1..)] + static_peers: Vec, } #[tokio::main] @@ -40,178 +43,120 @@ async fn main() -> Result<(), Box> { } }; - // Create an unbounded channel for this node - let (to_tun, mut from_routing) = mpsc::unbounded_channel::(); - let (to_routing, mut from_node) = mpsc::unbounded_channel::(); + let static_peers = cli.static_peers; - // Create the PeerManager: an interface to all peers this node is connected to - // Additional static peers are obtained through the nodeconfig.toml file - let peer_manager = PeerManager::new(); - - // Create static peers from the nodeconfig.toml file - let peer_man_clone = peer_manager.clone(); - let to_routing_clone = to_routing.clone(); - tokio::spawn(async move { - peer_man_clone - .get_peers_from_config(to_routing_clone, cli.tun_addr) - .await; // --> here we create peer by TcpStream connect - }); - - let peer_man_clone = peer_manager.clone(); - let to_routing_clone = to_routing.clone(); - // listen for inbound request --> "to created the reverse peer object" --> here we reverse create peer be listener.accept'ing - tokio::spawn(async move { - match TcpListener::bind("[::]:9651").await { - Ok(listener) => { - // loop to accept the inbound requests - loop { - let to_routing_clone_clone = to_routing_clone.clone(); - match listener.accept().await { - Ok((mut stream, _)) => { - // TEMPORARY: as we do not work with Babel yet, we will send to overlay ip (= addr of TUN) manually - // The packet flow looks like this: - // Listener accepts a TCP connect call here and send it's overlay IP over the stream - // In the peer_manager.rs at the place where we are connected we should manually add the overlay IP to the peer instance - - // 1. Send own TUN address over the stream - let ip_bytes = cli.tun_addr.octets(); - stream.write_all(&ip_bytes).await.unwrap(); - - // 4. Read other node's TUN address from the stream - let mut buffer = [0u8; 4]; - stream.read_exact(&mut buffer).await.unwrap(); - let received_overlay_ip = Ipv4Addr::from(buffer); - println!( - "Received overlay IP from other node: {:?}", - received_overlay_ip - ); - - // "reverse peer add" - let peer_stream_ip = stream.peer_addr().unwrap().ip(); - match Peer::new( - peer_stream_ip, - to_routing_clone_clone, - stream, - received_overlay_ip, - ) { - Ok(new_peer) => { - //println!("adding new peer to known_peers: {:?}", new_peer); - peer_man_clone.known_peers.lock().unwrap().push(new_peer); - } - Err(e) => { - eprintln!("Error creating 'reverse' peer: {}", e); - } - } - } - Err(e) => { - eprintln!("Error accepting TCP listener: {}", e); - } - } - } - } - Err(e) => { - eprintln!("Error binding TCP listener: {}", e); - } + // Creating a new Router instance + let router = match router::Router::new( + node_tun.clone(), + vec![StaticRoute::new(cli.tun_addr.into())], + ) { + Ok(router) => { + println!("Router created. ID: {}", router.router_id()); + router } - }); + Err(e) => { + panic!("Error creating router: {}", e); + } + }; - // Loop to read the 'from_routing' receiver and foward it toward the TUN interface - // TODO: you will only get DataPackets on TUN so the channel should only accept DataPackets (and not just Packet) - let node_tun_clone = node_tun.clone(); - tokio::spawn(async move { - loop { - while let Some(packet) = from_routing.recv().await { - let data_packet = if let Packet::DataPacket(p) = packet { - println!("LENTHEEE: {}", p.raw_data.len()); - p - } else { - continue; - }; - match node_tun_clone.send(&data_packet.raw_data).await { - Ok(_) => { - println!("Sending it towards this node's TUN"); + // Creating a new PeerManager instance + let _peer_manager: peer_manager::PeerManager = + peer_manager::PeerManager::new(router.clone(), static_peers); + + // Read packets from the TUN interface (originating from the kernel) and send them to the router + // Note: we will never receive control packets from the kernel, only data packets + { + let router = router.clone(); + let node_tun = node_tun.clone(); + + tokio::spawn(async move { + loop { + let mut buf = BytesMut::zeroed(LINK_MTU); + + match node_tun.recv(&mut buf).await { + Ok(n) => { + buf.truncate(n); } Err(e) => { - eprintln!("Error sending to TUN interface: {}", e); + eprintln!("Error reading from TUN: {}", e); + continue; } } - } - } - }); - // Loop to read from node's TUN interface and send it to to_routing sender halve - let node_tun_clone = node_tun.clone(); - let to_routing_clone = to_routing.clone(); - tokio::spawn(async move { - loop { - let mut buf = BytesMut::zeroed(LINK_MTU); - match node_tun_clone.recv(&mut buf).await { - Ok(n) => { - buf.truncate(n); - - println!("Got packet on my TUN, byyes: {}", n); - - // Remainder: if we read from TUN we will only need to parse them into DataPackets - // Extract the destination IP address using Etherparse - match PacketHeaders::from_ip_slice(&buf) { - Ok(packet) => { - if let Some(IpHeader::Version4(header, _)) = packet.ip { - let dest_addr = Ipv4Addr::from(header.destination); - println!("Destination IPv4 address: {}", dest_addr); - - let data_packet = DataPacket { - dest_ip: dest_addr, - raw_data: buf.to_vec(), - }; - - println!("LEN: {}", data_packet.raw_data.len()); - - match to_routing_clone.send(Packet::DataPacket(data_packet)) { - Ok(_) => { - println!("packet sent to to_routing"); - } - Err(e) => { - eprintln!("Error sending packet to to_routing: {}", e); - } - } - } else { - println!("Non-IPv4 packet received, ignoring..."); - } - } - Err(e) => { - println!("buffer: {:?}", buf); - eprintln!("Error from_ip_slice: {e}"); - } + let packet = match PacketHeaders::from_ip_slice(&buf) { + Ok(packet) => packet, + Err(e) => { + println!("buffer: {:?}", buf); + eprintln!("Error from_ip_slice: {}", e); + continue; } + }; + + if let Some(IpHeader::Version4(header, _)) = packet.ip { + let dest_addr = Ipv4Addr::from(header.destination); + println!("Destination IPv4 address: {}", dest_addr); + + let data_packet = DataPacket { + dest_ip: dest_addr, + raw_data: buf.to_vec(), + }; + if router.router_data_tx().send(data_packet).is_err() { + eprintln!("Failed to send data_packet"); + } + } else { + println!("Non-IPv4 packet received, ignoring..."); + } + } + }); + } + + let mut reader = tokio::io::BufReader::new(tokio::io::stdin()); + let mut line = String::new(); + + let read_handle = tokio::spawn(async move { + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => return, // EOF, quit + Ok(_) => { + // Remove trailing newline + line.pop(); + println!("----------- Current selected routes -----------{}\n", line); + router.print_selected_routes(); + println!("----------- Current fallback routes -----------{}\n", line); + router.print_fallback_routes(); + + println!("\n----------- Current peers: -----------"); + for p in router.peer_interfaces() { + println!( + "Peer: {:?}, with link cost: {}", + p.overlay_ip(), + p.link_cost() + ); + } + + println!("\n----------- Current source table: -----------"); + router.print_source_table(); + + println!("\n\n"); } Err(e) => { - eprintln!("Error reading from TUN: {}", e); + eprintln!("Error reading line: {}", e); + return; } } } }); - // Loop to read from from_node reeiver and route the packet further - // the route_packet function will send the packet towards the correct to_peer (based on dest ip of packet) - // or towards this own node's TUN interface (if dest ip of packet is this node's TUN addr) - let peer_man_clone = peer_manager.clone(); - let node_tun_clone = node_tun.clone(); - let to_tun_sender_clone = to_tun.clone(); - tokio::spawn(async move { - loop { - let node_tun_inner_clone = node_tun_clone.clone(); - let to_tun_sender_inner_clone = to_tun_sender_clone.clone(); - while let Some(packet) = from_node.recv().await { - //println!("Read message from from_node, sending it to route_packet function"); - peer_man_clone.route_packet( - packet, - node_tun_inner_clone.clone(), - to_tun_sender_inner_clone.clone(), - ); - } - } + let sleep_handle = tokio::spawn(async move { + // Just die after 1 day, you've probably leaked memory by then anyway + tokio::time::sleep(tokio::time::Duration::from_secs(60 * 60 * 24)).await; }); - tokio::time::sleep(std::time::Duration::from_secs(60 * 60 * 24)).await; + tokio::select! { + _ = read_handle => { /* The read task completed (this should never happen) */ } + _ = sleep_handle => { /* The sleep task completed (the program should exit here) */ } + } + Ok(()) } diff --git a/src/node_setup.rs b/src/node_setup.rs index 8cb70a9..f308b86 100644 --- a/src/node_setup.rs +++ b/src/node_setup.rs @@ -1,15 +1,11 @@ -use tokio_tun::{Tun, TunBuilder}; -use std::{ - sync::Arc, - net::Ipv4Addr, - error::Error, -}; -use rtnetlink::Handle; use futures::stream::TryStreamExt; +use rtnetlink::Handle; +use std::{error::Error, net::Ipv4Addr, sync::Arc}; +use tokio_tun::{Tun, TunBuilder}; pub const TUN_NAME: &str = "tun0"; pub const TUN_ROUTE_DEST: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 0); -pub const TUN_ROUTE_PREFIX: u8 = 24; +pub const TUN_ROUTE_PREFIX: u8 = 16; // Create a TUN interface pub fn create_tun_interface(int_addr: Ipv4Addr) -> Result, Box> { @@ -54,31 +50,15 @@ pub async fn add_route(handle: Handle) -> Result<(), Box> { } pub async fn setup_node(tun_addr: Ipv4Addr) -> Result, Box> { - match create_tun_interface(tun_addr) { - Ok(tun) => { - println!("Interface '{}' ({}) created", TUN_NAME, tun_addr); - match rtnetlink::new_connection() { - Ok((conn, handle, _)) => { - tokio::spawn(conn); - match add_route(handle.clone()).await { - Ok(_) => { - println!("Static route created"); - }, - Err(e) => { - panic!("Error adding route: {}", e); - } - } - }, - Err(e) => { - panic!("Error creating new handle: {}", e); - } - } + let tun = create_tun_interface(tun_addr)?; + println!("Interface '{}' ({}) created", TUN_NAME, tun_addr); - - Ok(tun) - }, - Err(e) => { - panic!("Error creating TUN: {}", e); - } - } -} \ No newline at end of file + let (conn, handle, _) = rtnetlink::new_connection()?; + tokio::spawn(conn); + + add_route(handle.clone()).await?; + + println!("Static route created"); + + Ok(tun) +} diff --git a/src/packet.rs b/src/packet.rs new file mode 100644 index 0000000..89935e8 --- /dev/null +++ b/src/packet.rs @@ -0,0 +1,204 @@ +use std::net::{IpAddr, Ipv4Addr}; + +use crate::peer::Peer; + +pub const BABEL_MAGIC: u8 = 42; +pub const BABEL_VERSION: u8 = 2; + +/* ********************************PAKCET*********************************** */ +#[derive(Debug, Clone)] +pub enum Packet { + DataPacket(DataPacket), + ControlPacket(ControlPacket), +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum PacketType { + DataPacket = 0, + ControlPacket = 1, +} + +/* ******************************DATA PACKET********************************* */ +#[derive(Debug, Clone)] +pub struct DataPacket { + pub raw_data: Vec, + pub dest_ip: Ipv4Addr, +} + +impl DataPacket {} + +/* ****************************CONTROL PACKET******************************** */ + +#[derive(Debug, Clone)] +pub struct ControlStruct { + pub control_packet: ControlPacket, + pub src_overlay_ip: IpAddr, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct ControlPacket { + pub header: BabelPacketHeader, + pub body: BabelPacketBody, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct BabelPacketHeader { + pub magic: u8, + pub version: u8, + pub body_length: u16, // length of the whole BabelPacketBody (tlv_type, length and body) +} + +// A BabelPacketBody describes exactly one TLV +#[derive(Debug, PartialEq, Clone)] +pub struct BabelPacketBody { + pub tlv_type: BabelTLVType, + pub length: u8, // length of the tlv (only the tlv, not tlv_type and length itself) + pub tlv: BabelTLV, +} + +impl BabelPacketHeader { + pub fn new(body_length: u16) -> Self { + Self { + magic: BABEL_MAGIC, + version: BABEL_VERSION, + body_length, + } + } +} + +impl ControlPacket { + pub fn new_hello(dest_peer: &mut Peer, interval: u16) -> Self { + let header_length = (BabelTLVType::Hello.get_tlv_length(false) + 2) as u16; + dest_peer.increment_hello_seqno(); + Self { + header: BabelPacketHeader::new(header_length), + body: BabelPacketBody { + tlv_type: BabelTLVType::Hello, + length: BabelTLVType::Hello.get_tlv_length(false), + tlv: BabelTLV::Hello { + seqno: dest_peer.hello_seqno(), + interval, + }, + }, + } + } + + pub fn new_ihu(interval: u16, dest_address: IpAddr) -> Self { + let uses_ipv6 = dest_address.is_ipv6(); + let header_length = (BabelTLVType::IHU.get_tlv_length(uses_ipv6) + 2) as u16; + Self { + header: BabelPacketHeader::new(header_length), + body: BabelPacketBody { + tlv_type: BabelTLVType::IHU, + length: BabelTLVType::IHU.get_tlv_length(uses_ipv6), + tlv: BabelTLV::IHU { + interval, + address: dest_address, + }, + }, + } + } + + pub fn new_update( + plen: u8, + interval: u16, + seqno: u16, + metric: u16, + prefix: IpAddr, + router_id: u64, + ) -> Self { + let uses_ipv6 = prefix.is_ipv6(); + let header_length = (BabelTLVType::Update.get_tlv_length(uses_ipv6) + 2) as u16; + Self { + header: BabelPacketHeader::new(header_length), + body: BabelPacketBody { + tlv_type: BabelTLVType::Update, + length: BabelTLVType::Update.get_tlv_length(uses_ipv6), + tlv: BabelTLV::Update { + plen, + interval, + seqno, + metric, + prefix, + router_id, + }, + }, + } + } +} + +#[derive(Debug, PartialEq, Clone)] +pub enum BabelTLVType { + // Pad1 = 0, + // PadN = 1, + AckReq = 2, + Ack = 3, + Hello = 4, + IHU = 5, + // RouterID = 6, + NextHop = 7, + Update = 8, + RouteReq = 9, + SeqnoReq = 10, +} + +impl BabelTLVType { + pub fn from_u8(value: u8) -> Option { + match value { + // 0 => Some(Self::Pad1), + // 1 => Some(Self::PadN), + 2 => Some(Self::AckReq), + 3 => Some(Self::Ack), + 4 => Some(Self::Hello), + 5 => Some(Self::IHU), + // 6 => Some(Self::RouterID), + 7 => Some(Self::NextHop), + 8 => Some(Self::Update), + 9 => Some(Self::RouteReq), + 10 => Some(Self::SeqnoReq), + _ => None, + } + } + + pub fn get_tlv_length(self, uses_ipv6: bool) -> u8 { + let (ipv6, ipv4) = match self { + Self::AckReq => (4, 4), + Self::Ack => (2, 2), + Self::Hello => (4, 4), + Self::IHU => (18, 6), + Self::NextHop => (16, 4), + Self::Update => (31 + 1, 19 + 1), // +1 for ae + Self::RouteReq => (17, 5), + Self::SeqnoReq => (21, 9), + }; + if uses_ipv6 { + ipv6 + } else { + ipv4 + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum BabelTLV { + // These TLVs are not implemented as they are used for padding when sending multiple TLVs in one packet. + // Pad1, + // PadN(u8), + Hello { + seqno: u16, + interval: u16, + }, + IHU { + interval: u16, + address: IpAddr, + }, + Update { + plen: u8, + interval: u16, + seqno: u16, + metric: u16, + prefix: IpAddr, + router_id: u64, + }, +} diff --git a/src/packet_control.rs b/src/packet_control.rs deleted file mode 100644 index e308b72..0000000 --- a/src/packet_control.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::net::Ipv4Addr; - -use etherparse::{PacketHeaders, IpHeader}; -use tokio_util::codec::{Decoder, Encoder}; -use bytes::{BytesMut, Buf, BufMut}; - -#[derive(Clone)] -pub enum Packet { - DataPacket(DataPacket), // packet coming from kernel - //ControlPacket(ControlPacket), // babel related packets -} - -// create function to extract destip from Packet type -impl Packet { - pub fn get_dest_ip(&self) -> std::net::Ipv4Addr { - match self { - Packet::DataPacket(packet) => packet.dest_ip, - //Packet::ControlPacket(packet) => packet.dest_ip, - } - } -} - -#[derive(Clone)] -pub struct DataPacket { - pub raw_data: Vec, - pub dest_ip: std::net::Ipv4Addr, -} - -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum PacketType { - DataPacket = 0, - _ControlPacket = 1, -} - -pub struct PacketCodec { - packet_type: Option, - data_packet_codec: DataPacketCodec, - //control_packet_codec: ControlPacketCodec, -} - -impl PacketCodec { - pub fn new() -> Self { - PacketCodec {packet_type: None, data_packet_codec: DataPacketCodec::new()} - } -} - -pub struct DataPacketCodec { - len: Option, - dest_ip: Option, -} - -impl DataPacketCodec{ - pub fn new() -> Self { - DataPacketCodec { len: None , dest_ip: None } - } -} - - -impl Decoder for DataPacketCodec { - type Item = DataPacket; - type Error = std::io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - let data_len = if let Some(data_len) = self.len { - data_len - } else { - - // check we have enough data to decode - if src.len() < 2 { - return Ok(None); - } - - let data_len = src.get_u16(); - self.len = Some(data_len); - - data_len - } as usize; - - let dest_ip = if let Some(dest_ip) = self.dest_ip { - dest_ip - } else { - if src.len() < 4 { - return Ok(None); - } - - // decode octets - let mut ip_bytes = [0u8; 4]; - ip_bytes.copy_from_slice(&src[..4]); - let dest_ip = Ipv4Addr::from(ip_bytes); - src.advance(4); - - self.dest_ip = Some(dest_ip); - dest_ip - }; - - if src.len() < data_len { - - src.reserve(data_len - src.len()); - - return Ok(None); - } - - // we have enough data - let data = src[..data_len].to_vec(); - src.advance(data_len); - - // Reset state - self.len = None; - self.dest_ip = None; - - Ok(Some(DataPacket { raw_data: data, dest_ip })) - } -} - -impl Encoder for DataPacketCodec { - type Error = std::io::Error; - - fn encode(&mut self, item: DataPacket, dst: &mut BytesMut) -> Result<(), Self::Error> { - // implies that length is never more than u16 - - dst.reserve(item.raw_data.len() + 6); - dst.put_u16(item.raw_data.len() as u16); - // dest ip wegschrijven - dst.put_slice(&item.dest_ip.octets()); - - dst.extend_from_slice(&item.raw_data); - - - Ok(()) - } -} - -impl Decoder for PacketCodec { - type Item = Packet; - type Error = std::io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - let packet_type = if let Some(packet_type) = self.packet_type { - packet_type - } else { - - // Check if we have enough bytes to read one byte (which shows to packet type) - if src.len() < 1 { - return Ok(None); - } - - let raw_packet_type = src.get_u8(); // Beware: this advances src by 1 u8 - let packet_type = match raw_packet_type { - 0 => { PacketType::DataPacket } - // 1 => { PacketType::ControlPacket } - _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unrecognized packet type")) - }; - - packet_type - }; - - match packet_type { - PacketType::DataPacket => { - match self.data_packet_codec.decode(src) { - Ok(Some(p)) => { - self.packet_type = None; // necessary otherwise we would have the situation where assume the packet_type already exists and just read further - Ok(Some(Packet::DataPacket(p))) - }, - Ok(None) => { - Ok(None) - }, - Err(e) => { - Err(e) - } - } - } - PacketType::_ControlPacket => { - unimplemented!() - } - } - } -} - -impl Encoder for PacketCodec { - type Error = std::io::Error; - - fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { - match item { - Packet::DataPacket(datapacket) => { - dst.put_u8(0); - self.data_packet_codec.encode(datapacket, dst) - } - // PacketType::ControlPacket(controlpacket) => { - // dst.put_u8(1); - // self.control_packet.codec.encode(controlpacket); - // } - } - } -} - - diff --git a/src/peer.rs b/src/peer.rs index 87d9c36..1423185 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -1,73 +1,213 @@ use futures::{SinkExt, StreamExt}; -use std::{error::Error, net::{IpAddr, Ipv4Addr}}; -use tokio::{ - net::TcpStream, - select, - sync::{mpsc}, +use std::{ + error::Error, + net::IpAddr, + sync::{Arc, RwLock}, }; -use tokio_util::codec::{Framed, Decoder}; +use tokio::{net::TcpStream, select, sync::mpsc}; +use tokio_util::codec::Framed; -use crate::packet_control::{DataPacket, Packet, PacketCodec}; +use crate::packet::{ControlPacket, ControlStruct, DataPacket}; +use crate::{codec::PacketCodec, packet::Packet}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Peer { - pub stream_ip: IpAddr, - pub to_peer: mpsc::UnboundedSender, - pub overlay_ip: Ipv4Addr, + inner: Arc>, } impl Peer { - pub fn new(stream_ip: IpAddr, to_routing: mpsc::UnboundedSender, stream: TcpStream, overlay_ip: Ipv4Addr) -> Result> { + pub fn new( + stream_ip: IpAddr, + router_data_tx: mpsc::UnboundedSender, + router_control_tx: mpsc::UnboundedSender, + stream: TcpStream, + overlay_ip: IpAddr, + ) -> Result> { + Ok(Peer { + inner: Arc::new(RwLock::new(PeerInner::new( + stream_ip, + router_data_tx, + router_control_tx, + stream, + overlay_ip, + )?)), + }) + } - // Create a Framed for each peer + /// Get current sequence number for this peer. + pub fn hello_seqno(&self) -> u16 { + self.inner.read().unwrap().hello_seqno + } + + /// Adds 1 to the sequence number of this peer . + pub fn increment_hello_seqno(&self) { + self.inner.write().unwrap().hello_seqno += 1; + } + + pub fn time_last_received_hello(&self) -> tokio::time::Instant { + self.inner.read().unwrap().time_last_received_hello + } + + pub fn set_time_last_received_hello(&self, time: tokio::time::Instant) { + self.inner.write().unwrap().time_last_received_hello = time + } + + /// Get overlay IP for this peer + pub fn overlay_ip(&self) -> IpAddr { + self.inner.read().unwrap().overlay_ip + } + + /// For sending data packets towards a peer instance on this node. + /// It's send over the to_peer_data channel and read from the corresponding receiver. + /// The receiver sends the packet over the TCP stream towards the destined peer instance on another node + pub fn send_data_packet(&self, data_packet: DataPacket) -> Result<(), Box> { + Ok(self.inner.write().unwrap().to_peer_data.send(data_packet)?) + } + + /// For sending control packets towards a peer instance on this node. + /// It's send over the to_peer_control channel and read from the corresponding receiver. + /// The receiver sends the packet over the TCP stream towards the destined peer instance on another node + pub fn send_control_packet(&self, control_packet: ControlPacket) -> Result<(), Box> { + Ok(self + .inner + .write() + .unwrap() + .to_peer_control + .send(control_packet)?) + } + + pub fn link_cost(&self) -> u16 { + self.inner.read().unwrap().link_cost + } + + pub fn set_link_cost(&self, link_cost: u16) { + self.inner.write().unwrap().link_cost = link_cost + } + + pub fn underlay_ip(&self) -> IpAddr { + self.inner.read().unwrap().stream_ip + } + + pub fn time_last_received_ihu(&self) -> tokio::time::Instant { + self.inner.read().unwrap().time_last_received_ihu + } + + pub fn set_time_last_received_ihu(&self, time: tokio::time::Instant) { + self.inner.write().unwrap().time_last_received_ihu = time + } +} + +impl PartialEq for Peer { + fn eq(&self, other: &Self) -> bool { + self.overlay_ip() == other.overlay_ip() + } +} + +#[derive(Debug)] +struct PeerInner { + stream_ip: IpAddr, + to_peer_data: mpsc::UnboundedSender, + to_peer_control: mpsc::UnboundedSender, + overlay_ip: IpAddr, + hello_seqno: u16, + time_last_received_hello: tokio::time::Instant, + link_cost: u16, + time_last_received_ihu: tokio::time::Instant, +} + +impl PeerInner { + pub fn new( + stream_ip: IpAddr, + router_data_tx: mpsc::UnboundedSender, + router_control_tx: mpsc::UnboundedSender, + stream: TcpStream, + overlay_ip: IpAddr, + ) -> Result> { + // Framed for peer + // Used to send and receive packets from a TCP stream let mut framed = Framed::new(stream, PacketCodec::new()); - // Create an unbounded channel for each peer - let (to_peer, mut from_routing) = mpsc::unbounded_channel::(); + // Data channel for peer + let (to_peer_data, mut from_routing_data) = mpsc::unbounded_channel::(); + // Control channel for peer + let (to_peer_control, mut from_routing_control) = + mpsc::unbounded_channel::(); + + // Initialize last_sent_hello_seqno to 0 + let hello_seqno = 0; + // Initialize last_path_cost to infinity - 1 + let link_cost = u16::MAX - 1; + // Initialize time_last_received_hello to now + let time_last_received_hello = tokio::time::Instant::now(); + // Initialiwe time_last_send_ihu + let time_last_received_ihu = tokio::time::Instant::now(); + + // Intialize the timers + // let ihu_timer = Timer::new_ihu_timer(IHU_INTERVAL); tokio::spawn(async move { loop { select! { - // received from peer - frame = framed.next() => { - match frame { - Some(Ok(packet)) => { - // Send to TUN interface - // toekomst: nog een een tussenstap - println!("3: I'm the peer instance that got the message from the TCP stream"); - match packet { - Packet::DataPacket(packet) => { - if let Err(error) = to_routing.send(Packet::DataPacket(packet)){ - eprintln!("Error sending to TUN: {}", error); - } - + // Received over the TCP stream + frame = framed.next() => { + match frame { + Some(Ok(packet)) => { + match packet { + Packet::DataPacket(packet) => { + if let Err(error) = router_data_tx.send(packet){ + eprintln!("Error sending to to_routing_data: {}", error); } - // Packet::ControlPacket(packet) => { - // TODO: control packet - // } } + Packet::ControlPacket(packet) => { + // Parse the DataPacket into a ControlStruct as the to_routing_control channel expects + let control_struct = ControlStruct { + control_packet: packet, + src_overlay_ip: overlay_ip, + // Note: although this control packet is received from the TCP stream + // we set the src_overlay_ip to the overlay_ip of the peer + // as we 'arrived' in the peer instance of representing the sending node on this current node + }; + if let Err(error) = router_control_tx.send(control_struct) { + eprintln!("Error sending to to_routing_control: {}", error); + } - }, - Some(Err(e)) => { - eprintln!("Error from framed: {}", e); - }, - None => { - println!("Stream is closed."); - return + } } } - } - // receive from from_routing - Some(packet) = from_routing.recv() => { - println!("Receiver from from_routing, sending it over the TCP stream"); - // Send it over the TCP stream - if let Err(e) = framed.send(packet).await { - eprintln!("Error writing to stream: {}", e); + Some(Err(e)) => { + eprintln!("Error from framed: {}", e); + }, + None => { + println!("Stream is closed."); + return } } } + + Some(packet) = from_routing_data.recv() => { + // Send it over the TCP stream + if let Err(e) = framed.send(Packet::DataPacket(packet)).await { + eprintln!("Error writing to stream: {}", e); + } + } + + Some(packet) = from_routing_control.recv() => { + // Send it over the TCP stream + if let Err(e) = framed.send(Packet::ControlPacket(packet)).await { + eprintln!("Error writing to stream: {}", e); + } + } + } } }); - - Ok(Self { stream_ip, to_peer, overlay_ip }) + Ok(Self { + stream_ip, + to_peer_data, + to_peer_control, + overlay_ip, + hello_seqno, + link_cost, + time_last_received_ihu, + time_last_received_hello, + }) } } diff --git a/src/peer_manager.rs b/src/peer_manager.rs index ac8e2db..829ab6c 100644 --- a/src/peer_manager.rs +++ b/src/peer_manager.rs @@ -1,10 +1,10 @@ -use crate::{packet_control::Packet, peer::Peer}; +use crate::peer::Peer; +use crate::router::Router; use serde::Deserialize; -use std::net::{Ipv4Addr, SocketAddr}; -use std::sync::{Arc, Mutex}; +use std::net::{IpAddr, SocketAddr}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::{net::TcpStream, sync::mpsc::UnboundedSender}; -use tokio_tun::Tun; +use tokio::net::TcpListener; +use tokio::net::TcpStream; pub const NODE_CONFIG_FILE_PATH: &str = "nodeconfig.toml"; @@ -13,109 +13,271 @@ struct PeersConfig { peers: Vec, } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct PeerManager { - pub known_peers: Arc>>, + pub router: Router, + pub initial_peers: Vec, } impl PeerManager { - pub fn new() -> Self { - let mut known_peers: Vec = Vec::new(); + pub fn new(router: Router, static_peers_sockets: Vec) -> Self { + let peer_manager = PeerManager { + router, + initial_peers: static_peers_sockets.clone(), + }; + // Start a TCP listener. When a new connection is accepted, the reverse peer exchange is performed. + tokio::spawn(PeerManager::start_listener(peer_manager.clone())); + // Reads the nodeconfig.toml file and connects to the peers in the file. + tokio::spawn(PeerManager::get_peers_from_config(peer_manager.clone())); + // Remote nodes can also be read from CLI arg + tokio::spawn(PeerManager::get_peers_from_cli( + peer_manager.clone(), + static_peers_sockets, + )); - Self { - known_peers: Arc::new(Mutex::new(known_peers)), + tokio::spawn(PeerManager::reconnect_to_initial_peers( + peer_manager.clone(), + )); + + peer_manager + } + + async fn get_peers_from_config(self) { + if let Ok(file_content) = std::fs::read_to_string(NODE_CONFIG_FILE_PATH) { + let config: PeersConfig = toml::from_str(&file_content).unwrap(); + + for peer_addr in config.peers { + if let Ok(mut peer_stream) = TcpStream::connect(peer_addr).await { + let mut buffer = [0u8; 17]; + peer_stream.read_exact(&mut buffer).await.unwrap(); + let received_overlay_ip = match buffer[0] { + 0 => IpAddr::from( + <&[u8] as TryInto<[u8; 4]>>::try_into(&buffer[1..5]).unwrap(), + ), + 1 => IpAddr::from( + <&[u8] as TryInto<[u8; 16]>>::try_into(&buffer[1..]).unwrap(), + ), + _ => { + eprintln!("Invalid address encoding byte"); + continue; + } + }; + + println!( + "Received overlay IP from other node: {:?}", + received_overlay_ip + ); + + let mut buf = [0u8; 17]; + match self.router.node_tun_addr() { + IpAddr::V4(tun_addr) => { + buf[0] = 0; + buf[1..5].copy_from_slice(&tun_addr.octets()[..]); + } + IpAddr::V6(tun_addr) => { + buf[0] = 1; + buf[1..].copy_from_slice(&tun_addr.octets()[..]); + } + } + peer_stream.write_all(&buf).await.unwrap(); + + let peer_stream_ip = peer_addr.ip(); + if let Ok(new_peer) = Peer::new( + peer_stream_ip, + self.router.router_data_tx(), + self.router.router_control_tx(), + peer_stream, + received_overlay_ip, + ) { + self.router.add_peer_interface(new_peer); + } + } + } + } else { + eprintln!("Error reading nodeconfig.toml file"); } } - pub async fn get_peers_from_config( - &self, - to_routing: UnboundedSender, - tun_addr_own_node: Ipv4Addr, - ) { - // Read from the nodeconfig.toml file - match std::fs::read_to_string(NODE_CONFIG_FILE_PATH) { - Ok(file_content) => { - // Create a PeersConfig based on the file content - let config: PeersConfig = toml::from_str(&file_content).unwrap(); - for peer_addr in config.peers { - match TcpStream::connect(peer_addr).await { - Ok(mut peer_stream) => { - //println!("TCP stream connected: {}", peer_addr); + async fn get_peers_from_cli(self, socket_addresses: Vec) { + for peer_addr in socket_addresses { + println!("connecting to: {}", peer_addr); - // 2. Read other node's TUN address from the stream - let mut buffer = [0u8; 4]; - peer_stream.read_exact(&mut buffer).await.unwrap(); - let received_overlay_ip = Ipv4Addr::from(buffer); - println!( - "Received overlay IP from other node: {:?}", - received_overlay_ip - ); + if let Ok(mut peer_stream) = TcpStream::connect(peer_addr).await { + println!("stream established"); - // 3. Send own TUN address over the stream - let ip_bytes = tun_addr_own_node.octets(); - peer_stream.write_all(&ip_bytes).await.unwrap(); + let mut buffer = [0u8; 17]; + peer_stream.read_exact(&mut buffer).await.unwrap(); + let received_overlay_ip = match buffer[0] { + 0 => { + IpAddr::from(<&[u8] as TryInto<[u8; 4]>>::try_into(&buffer[1..5]).unwrap()) + } + 1 => { + IpAddr::from(<&[u8] as TryInto<[u8; 16]>>::try_into(&buffer[1..]).unwrap()) + } + _ => { + eprintln!("Invalid address encoding byte"); + continue; + } + }; + println!( + "3: Received overlay IP from other node: {:?}", + received_overlay_ip + ); - // Create peer instance - let peer_stream_ip = peer_addr.ip(); - match Peer::new( - peer_stream_ip, - to_routing.clone(), - peer_stream, - received_overlay_ip, - ) { - Ok(new_peer) => { - // Add peer to known_peers - let mut known_peers = self.known_peers.lock().unwrap(); - known_peers.push(new_peer); - } - Err(e) => { - eprintln!("Error creating peer: {}", e); - } + let mut buf = [0u8; 17]; + + match self.router.node_tun_addr() { + IpAddr::V4(tun_addr) => { + buf[0] = 0; + buf[1..5].copy_from_slice(&tun_addr.octets()[..]); + } + IpAddr::V6(tun_addr) => { + buf[0] = 1; + buf[1..].copy_from_slice(&tun_addr.octets()[..]); + } + } + + peer_stream.write_all(&buf).await.unwrap(); + + let peer_stream_ip = peer_addr.ip(); + if let Ok(new_peer) = Peer::new( + peer_stream_ip, + self.router.router_data_tx(), + self.router.router_control_tx(), + peer_stream, + received_overlay_ip, + ) { + self.router.add_peer_interface(new_peer); + } + } + } + } + + // this is used to reconnect to the provided static peers in case the connection is lost + async fn reconnect_to_initial_peers(self) { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + + // check if there is an entry for the peer in the router's peer list + for peer in self.initial_peers.iter() { + if !self.router.peer_exists(peer.ip()) { + if let Ok(mut peer_stream) = TcpStream::connect(peer).await { + let mut buffer = [0u8; 17]; + peer_stream.read_exact(&mut buffer).await.unwrap(); + let received_overlay_ip = match buffer[0] { + 0 => IpAddr::from( + <&[u8] as TryInto<[u8; 4]>>::try_into(&buffer[1..5]).unwrap(), + ), + 1 => IpAddr::from( + <&[u8] as TryInto<[u8; 16]>>::try_into(&buffer[1..]).unwrap(), + ), + _ => { + eprintln!("Invalid address encoding byte"); + continue; + } + }; + + println!( + "Received overlay IP from other node: {:?}", + received_overlay_ip + ); + + let mut buf = [0u8; 17]; + match self.router.node_tun_addr() { + IpAddr::V4(tun_addr) => { + buf[0] = 0; + buf[1..5].copy_from_slice(&tun_addr.octets()[..]); + } + IpAddr::V6(tun_addr) => { + buf[0] = 1; + buf[1..].copy_from_slice(&tun_addr.octets()[..]); } } - Err(e) => { - eprintln!( - "Error connecting to TCP stream for {}: {}", - peer_addr.to_string(), - e - ); + peer_stream.write_all(&buf).await.unwrap(); + + let peer_stream_ip = peer.ip(); + if let Ok(new_peer) = Peer::new( + peer_stream_ip, + self.router.router_data_tx(), + self.router.router_control_tx(), + peer_stream, + received_overlay_ip, + ) { + self.router.add_peer_interface(new_peer); } } } } + } + } + + async fn start_listener(self) { + match TcpListener::bind("[::]:9651").await { + Ok(listener) => loop { + match listener.accept().await { + Ok((stream, _)) => { + PeerManager::start_reverse_peer_exchange(stream, self.router.clone()).await; + } + Err(e) => { + eprintln!("Error accepting connection: {}", e); + } + } + }, Err(e) => { - eprintln!("Error reading nodeconfig.toml file: {}", e); + eprintln!("Error starting listener: {}", e); } } } - pub fn route_packet( - &self, - packet: Packet, - own_node_tun: Arc, - to_tun_sender: UnboundedSender, - ) { - // We first extract the IP from the Packet and look if the destination IP is our own overlay IP - // So if --> forward packet to our own TUN interface - // If not --> look in known_peers which peer's overlay_ip matches with destination IP + async fn start_reverse_peer_exchange(mut stream: TcpStream, router: Router) { + // Steps: + // 1. Send own TUN address over the stream + // 2. Read other node's TUN address from the stream - let packet_dest_ip = packet.get_dest_ip(); + let mut buf = [0u8; 17]; - // Packet towards own node's TUN interface - if packet_dest_ip == own_node_tun.address().unwrap() { - println!("Packet got address of our own TUN --> so sending it to my own TUN"); - to_tun_sender.send(packet); - // Packet towards other peer - } else { - let mut known_peers = self.known_peers.lock().unwrap(); - for peer in known_peers.iter_mut() { - if peer.overlay_ip == packet_dest_ip { - println!("Routing packet towards: {}", peer.overlay_ip.to_string()); - peer.to_peer.send(packet); - break; - } else { - println!("No peer match found"); - } + match router.node_tun_addr() { + IpAddr::V4(tun_addr) => { + buf[0] = 0; + buf[1..5].copy_from_slice(&tun_addr.octets()[..]); + } + IpAddr::V6(tun_addr) => { + buf[0] = 1; + buf[1..].copy_from_slice(&tun_addr.octets()[..]); + } + } + + stream.write_all(&buf).await.unwrap(); + + stream.read_exact(&mut buf).await.unwrap(); + let received_overlay_ip = match buf[0] { + 0 => IpAddr::from(<&[u8] as TryInto<[u8; 4]>>::try_into(&buf[1..5]).unwrap()), + 1 => IpAddr::from(<&[u8] as TryInto<[u8; 16]>>::try_into(&buf[1..]).unwrap()), + _ => { + eprintln!("Invalid address encoding byte"); + return; + } + }; + println!( + "Received overlay IP from other node: {:?}", + received_overlay_ip + ); + + // Create new Peer instance + let peer_stream_ip = stream.peer_addr().unwrap().ip(); + let new_peer = Peer::new( + peer_stream_ip, + router.router_data_tx(), + router.router_control_tx(), + stream, + received_overlay_ip, + ); + match new_peer { + Ok(new_peer) => { + router.add_peer_interface(new_peer); + } + Err(e) => { + eprintln!("Error creating peer: {}", e); } } } diff --git a/src/router.rs b/src/router.rs new file mode 100644 index 0000000..46780c7 --- /dev/null +++ b/src/router.rs @@ -0,0 +1,715 @@ +use crate::{ + packet::{BabelTLV, BabelTLVType, ControlPacket, ControlStruct, DataPacket}, + peer::Peer, + routing_table::{RouteEntry, RouteKey, RoutingTable}, + source_table::{self, FeasibilityDistance, SourceKey, SourceTable}, +}; +use rand::Rng; +use std::{ + error::Error, + fmt::Debug, + net::IpAddr, + sync::{Arc, RwLock}, +}; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use tokio_tun::Tun; + +const HELLO_INTERVAL: u16 = 4; +const IHU_INTERVAL: u16 = HELLO_INTERVAL * 3; +const UPDATE_INTERVAL: u16 = HELLO_INTERVAL * 4; + +#[derive(Debug, Clone, Copy)] +pub struct StaticRoute { + plen: u8, + prefix: IpAddr, + seqno: u16, +} + +impl StaticRoute { + pub fn new(prefix: IpAddr) -> Self { + Self { + plen: 32, + prefix, + seqno: 0, + } + } +} + +#[derive(Clone)] +pub struct Router { + inner: Arc>, +} + +impl Router { + pub fn new( + node_tun: Arc, + static_routes: Vec, + ) -> Result> { + // Tx is passed onto each new peer instance. This enables peers to send control packets to the router. + let (router_control_tx, router_control_rx) = mpsc::unbounded_channel::(); + // Tx is passed onto each new peer instance. This enables peers to send data packets to the router. + let (router_data_tx, router_data_rx) = mpsc::unbounded_channel::(); + + let router = Router { + inner: Arc::new(RwLock::new(RouterInner::new( + node_tun, + static_routes, + router_data_tx, + router_control_tx, + )?)), + }; + + tokio::spawn(Router::start_periodic_hello_sender(router.clone())); + tokio::spawn(Router::handle_incoming_control_packet( + router.clone(), + router_control_rx, + )); + tokio::spawn(Router::handle_incoming_data_packet( + router.clone(), + router_data_rx, + )); + tokio::spawn(Router::propagate_static_route(router.clone())); + tokio::spawn(Router::propagate_selected_routes(router.clone())); + + tokio::spawn(Router::check_for_dead_peers(router.clone())); + + Ok(router) + } + + pub fn router_id(&self) -> u64 { + self.inner.read().unwrap().router_id + } + + pub fn router_control_tx(&self) -> UnboundedSender { + self.inner.read().unwrap().router_control_tx.clone() + } + + pub fn router_data_tx(&self) -> UnboundedSender { + self.inner.read().unwrap().router_data_tx.clone() + } + + pub fn node_tun_addr(&self) -> IpAddr { + IpAddr::V4(self.inner.read().unwrap().node_tun.address().unwrap()) + } + + pub fn node_tun(&self) -> Arc { + self.inner.read().unwrap().node_tun.clone() + } + + pub fn router_seqno(&self) -> u16 { + self.inner.read().unwrap().router_seqno + } + + pub fn increment_router_seqno(&self) { + self.inner.write().unwrap().router_seqno += 1; + } + + pub fn peer_interfaces(&self) -> Vec { + self.inner.read().unwrap().peer_interfaces.clone() + } + + pub fn add_peer_interface(&self, peer: Peer) { + self.inner.write().unwrap().peer_interfaces.push(peer); + } + + pub fn remove_peer_interface(&self, peer: Peer) { + self.inner.write().unwrap().remove_peer_interface(peer); + } + + pub fn static_routes(&self) -> Vec { + self.inner.read().unwrap().static_routes.clone() + } + + pub fn peer_by_ip(&self, peer_ip: IpAddr) -> Option { + self.inner.read().unwrap().peer_by_ip(peer_ip) + } + + pub fn peer_exists(&self, peer_underlay_ip: IpAddr) -> bool { + self.inner.read().unwrap().peer_exists(peer_underlay_ip) + } + + pub fn source_peer_from_control_struct(&self, control_struct: ControlStruct) -> Option { + let peers = self.peer_interfaces(); + let matching_peer = peers + .iter() + .find(|peer| peer.overlay_ip() == control_struct.src_overlay_ip); + + matching_peer.map(Clone::clone) + } + + + pub fn print_selected_routes(&self) { + let inner = self.inner.read().unwrap(); + + let routing_table = &inner.selected_routing_table; + for route in routing_table.table.iter() { + println!("Route key: {:?}", route.0); + println!( + "Route: {:?}/{:?} (with next-hop: {:?}, metric: {}, selected: {})", + route.0.prefix, route.0.plen, route.1.next_hop, route.1.metric, route.1.selected + ); + println!("As advertised by: {:?}", route.1.source.router_id); + } + } + + pub fn print_fallback_routes(&self) { + let inner = self.inner.read().unwrap(); + + let routing_table = &inner.fallback_routing_table; + for route in routing_table.table.iter() { + println!("Route key: {:?}", route.0); + println!( + "Route: {:?}/{:?} (with next-hop: {:?}, metric: {}, selected: {})", + route.0.prefix, route.0.plen, route.1.next_hop, route.1.metric, route.1.selected + ); + println!("As advertised by: {:?}", route.1.source.router_id); + } + } + + pub fn print_source_table(&self) { + let inner = self.inner.read().unwrap(); + + let source_table = &inner.source_table; + for (sk, se) in source_table.table.iter() { + println!("Source key: {:?}", sk); + println!("Source entry: {:?}", se); + println!("\n"); + } + } + + async fn check_for_dead_peers(self) { + + let ihu_threshold = tokio::time::Duration::from_secs(8); + + loop { + + // check for dead peers every second + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + let mut inner = self.inner.write().unwrap(); + + let dead_peers = { + // a peer is assumed dead when the peer's last sent ihu exceeds a threshold + let mut dead_peers = Vec::new(); + for peer in inner.peer_interfaces.iter() { + // check if the peer's last_received_ihu is greater than the threshold + if peer.time_last_received_ihu().elapsed() > ihu_threshold { + // peer is dead + println!("Peer {:?} is dead", peer.overlay_ip()); + dead_peers.push(peer.clone()); + } + } + dead_peers + }; + + // vec to store retraction update that need to be sent + let mut retraction_updates = Vec::::new(); + + // remove the peer from the peer_interfaces and the routes + for dead_peer in dead_peers { + inner.remove_peer_interface(dead_peer.clone()); + // remove the peer's routes from all routing tables (= all the peers that use the peer as next-hop) + inner.selected_routing_table.table.retain(|_, route_entry| { + route_entry.next_hop != dead_peer.overlay_ip() + }); + inner.fallback_routing_table.table.retain(|_, route_entry| { + route_entry.next_hop != dead_peer.overlay_ip() + }); + + // create retraction update for each dead peer + let retraction_update = ControlPacket::new_update( + 32, + UPDATE_INTERVAL as u16, + inner.router_seqno, + 0xFFFF, + dead_peer.overlay_ip(), // todo: fix to use actual prefix, not IP + inner.router_id, + ); + retraction_updates.push(retraction_update); + } + + // send retraction update for the dead peer + // when other nodes receive this update (with metric 0XFFFF), they should also remove the routing tables entries with that peer as neighbor + for peer in inner.peer_interfaces.iter() { + for ru in retraction_updates.iter() { + if let Err(e) = peer.send_control_packet(ru.clone()) { + eprintln!("Error sending retraction update to peer"); + } + } + } + + } + } + + + + async fn handle_incoming_control_packet( + self, + mut router_control_rx: UnboundedReceiver, + ) { + while let Some(control_struct) = router_control_rx.recv().await { + match control_struct.control_packet.body.tlv_type { + BabelTLVType::AckReq => todo!(), + BabelTLVType::Ack => todo!(), + BabelTLVType::Hello => Self::handle_incoming_hello(&self, control_struct), + BabelTLVType::IHU => Self::handle_incoming_ihu(&self, control_struct), + BabelTLVType::NextHop => todo!(), + BabelTLVType::Update => Self::handle_incoming_update(&self, control_struct), + BabelTLVType::RouteReq => todo!(), + BabelTLVType::SeqnoReq => todo!(), + } + } + } + + fn handle_incoming_hello(&self, control_struct: ControlStruct) { + // let destination_ip = control_struct.src_overlay_ip; + // control_struct.reply(ControlPacket::new_ihu(IHU_INTERVAL, destination_ip)); + + // Upon receiving and Hello message from a peer, this node has to send a IHU back + if let Some(source_peer) = self.source_peer_from_control_struct(control_struct) { + let ihu = ControlPacket::new_ihu(IHU_INTERVAL, source_peer.overlay_ip()); + match source_peer.send_control_packet(ihu) { + Ok(()) => { + }, + Err(e) => { + eprintln!("Error sending IHU to peer: {e}"); + } + } + } + } + + fn handle_incoming_ihu(&self, control_struct: ControlStruct) { + if let Some(source_peer) = self.source_peer_from_control_struct(control_struct) { + // reset the IHU timer associated with the peer + // source_peer.reset_ihu_timer(tokio::time::Duration::from_secs(IHU_INTERVAL as u64)); + // measure time between Hello and and IHU and set the link cost + let time_diff = tokio::time::Instant::now() + .duration_since(source_peer.time_last_received_hello()) + .as_millis(); + + source_peer.set_link_cost(time_diff as u16); + + + + // set the last_received_ihu for this peer + source_peer.set_time_last_received_ihu(tokio::time::Instant::now()); + } + } + + // incoming update can only be received by a Peer this node has a direct link to + fn handle_incoming_update(&self, update: ControlStruct) { + match update.control_packet.body.tlv { + BabelTLV::Update { plen, interval: _, seqno, metric, prefix, router_id } => { + + // create route key from incoming update control struct + // we need the address of the neighbour; this corresponds to the source ip of the control struct as the update is received from the neighbouring peer + let neighbor_ip = update.src_overlay_ip; + let route_key_from_update = RouteKey { + neighbor: neighbor_ip, + plen, + prefix, + }; + + // used later to filter out static route + if self.route_key_is_from_static_route(&route_key_from_update) { + return; + } + + let mut inner = self.inner.write().unwrap(); + + // check if a route entry with the same route key exists in both routing tables + let route_entry_exists = inner.selected_routing_table.table.contains_key(&route_key_from_update) || inner.fallback_routing_table.table.contains_key(&route_key_from_update); + + // if no entry exists (based on prefix, plen AND neighbor field) + if !route_entry_exists { + // if the update is unfeasible, or the metric is inifinite, we ignore the update + if metric == u16::MAX || !self.update_feasible(&update, &inner.source_table) { + return; + } + else { + // this means that the update is feasible and the metric is not infinite + // create a new route entry and add it to the routing table (which requires a new source entry to be created as well) + + let source_key = SourceKey { prefix, plen, router_id }; + let fd = FeasibilityDistance{ metric, seqno }; + inner.source_table.insert(source_key, fd); + + let route_key = RouteKey { + prefix, + plen, + neighbor: neighbor_ip, + }; + let route_entry = RouteEntry { + source: source_key, + neighbor: inner.peer_by_ip(neighbor_ip).unwrap(), + metric, + seqno, + next_hop: neighbor_ip, + selected: true, + }; + + // Collect keys of routes to be removed + let mut to_remove = Vec::new(); + for r in inner.selected_routing_table.table.iter() { + // filter based on prefix and plen, skipping neighbor + if r.0.plen == plen && r.0.prefix == prefix { + // metric of update is smaller than entry's metric + if metric < r.1.metric { + // this means we should remove the entry from the selected routing table + to_remove.push(r.0.clone()); + break; // we can break, as there will be max 1 better route in selected table at any point in time (hence 'selected') + // metric of update is greater than entry's metric + } else if metric >= r.1.metric { + // this means that there is already a better route in our selected routing table, + // so we should add it to fallback instead + inner.fallback_routing_table.table.insert(route_key.clone(), route_entry.clone()); + return; // quit the function, work is done here + } + } + } + // Remove better routes from selected and insert into fallback + for rk in to_remove { + if let Some(old_selected) = inner.selected_routing_table.remove(&rk) { + inner.fallback_routing_table.insert(rk, old_selected); + } + } + // insert the route into selected (we might have placed one other route, that was previously the best, in the fallback) + inner.selected_routing_table.table.insert(route_key, route_entry); + + } + } + // entry exists + else { + // check if update is a retraction + if self.update_feasible(&update, &inner.source_table) && metric == u16::MAX { + // if the update is a retraction, we remove the entry from the routing tables + // we also remove the corresponding source entry??? + if inner.selected_routing_table.table.contains_key(&route_key_from_update) { + inner.selected_routing_table.remove(&route_key_from_update); + } + if inner.fallback_routing_table.table.contains_key(&route_key_from_update) { + inner.fallback_routing_table.remove(&route_key_from_update); + } + // remove the corresponding source entry + let source_key = SourceKey { prefix, plen, router_id }; + inner.source_table.remove(&source_key); + + return; + } + // if the entry is currently selected, the update is unfeasible, and the router-id of the update is equal + // to the router-id of the entry, then we ignore the update + if inner.selected_routing_table.table.contains_key(&route_key_from_update) { + let route_entry = inner.selected_routing_table.table.get(&route_key_from_update).unwrap(); + if !self.update_feasible(&update, &inner.source_table) && route_entry.source.router_id == router_id { + return; + } + // update the entry's seqno, metric and router-id + let route_entry = inner.selected_routing_table.table.get_mut(&route_key_from_update).unwrap(); + route_entry.update_seqno(seqno); + route_entry.update_metric(metric); + route_entry.update_router_id(router_id); + } + // otherwise + else { + let route_entry = inner.fallback_routing_table.table.get_mut(&route_key_from_update).unwrap(); + // update the entry's seqno, metric and router-id + route_entry.update_seqno(seqno); + route_entry.update_metric(metric); + route_entry.update_router_id(router_id); + + if !self.update_feasible(&update, &inner.source_table) { + // if the update is unfeasible, we remove the entry from the selected routing table + inner.selected_routing_table.table.remove(&route_key_from_update); + // should we remove it from the selected and add it to fallback here??? + } + } + } + }, + _ => { + panic!("Received update with wrong TLV type"); + } + } + + } + + fn route_key_is_from_static_route(&self, route_key: &RouteKey) -> bool { + let inner = self.inner.read().unwrap(); + + for sr in inner.static_routes.iter() { + if sr.plen == route_key.plen && sr.prefix == route_key.prefix { + return true; + } + } + return false; + } + + // we gebruiken self niet in de functie --> daarop functie eigenllijk beter op de source table implementeren + fn update_feasible(&self, update: &ControlStruct, source_table: &SourceTable) -> bool { + // Before an update is accepted it should be checked against the feasbility condition + // If an entry in the source table with the same source key exists, we perform the feasbility check + // If no entry exists yet, the update is accepted as there is no better alternative available (yet) + match update.control_packet.body.tlv { + BabelTLV::Update { + plen, + interval: _, + seqno, + metric, + prefix, + router_id, + } => { + let source_key = SourceKey { + prefix, + plen, + router_id, + }; + match source_table.get(&source_key) { + Some(&entry) => { + return (seqno > entry.seqno|| (seqno == entry.seqno && metric < entry.metric)) || metric == 0xFFFF; + } + None => return true, + } + } + _ => { + eprintln!("Error accepting update, control struct did not match update packet"); + return false; + } + } + } + + async fn handle_incoming_data_packet(self, mut router_data_rx: UnboundedReceiver) { + // If destination IP of packet is same as TUN interface IP, send to TUN interface + // If destination IP of packet is not same as TUN interface IP, send to peer with matching overlay IP + let node_tun = self.node_tun(); + let node_tun_addr = node_tun.address().unwrap(); + loop { + while let Some(data_packet) = router_data_rx.recv().await { + match data_packet.dest_ip { + x if x == node_tun_addr => match node_tun.send(&data_packet.raw_data).await { + Ok(_) => {} + Err(e) => { + eprintln!("Error sending data packet to TUN interface: {:?}", e) + } + }, + _ => { + let best_route = self.select_best_route(IpAddr::V4(data_packet.dest_ip)); + match best_route { + Some (route_entry) => { + let peer = self.peer_by_ip(route_entry.next_hop).unwrap(); + if let Err(e) = peer.send_data_packet(data_packet) { + eprintln!("Error sending data packet to peer: {:?}", e); + } + }, + None => { + eprintln!("Error sending data packet, no route found"); + } + } + } + } + } + } + } + + pub fn select_best_route(&self, dest_ip: IpAddr) -> Option { + + let inner = self.inner.read().unwrap(); + let mut best_route = None; + // first look in the selected routing table for a match on the prefix of dest_ip + for (route_key, route_entry) in inner.selected_routing_table.table.iter() { + if route_key.prefix == dest_ip { + best_route = Some(route_entry.clone()); + } + } + // if no match was found, look in the fallback routing table + if best_route.is_none() { + println!("no match in selected routing table, looking in fallback routing table"); + for (route_key, route_entry) in inner.fallback_routing_table.table.iter() { + if route_key.prefix == dest_ip { + best_route = Some(route_entry.clone()); + } + } + } + + println!("\n\n best route towards {}: {:?}", dest_ip, best_route); + + return best_route + } + + pub async fn propagate_static_route(self) { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + + let mut inner = self.inner.write().unwrap(); + inner.propagate_static_route(); + } + } + + pub async fn propagate_selected_routes(self) { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + + let mut inner = self.inner.write().unwrap(); + inner.propagate_selected_routes(); + } + } + + async fn start_periodic_hello_sender(self) { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(HELLO_INTERVAL as u64)).await; + + for peer in self.peer_interfaces().iter_mut() { + let hello = ControlPacket::new_hello(peer, HELLO_INTERVAL); + peer.set_time_last_received_hello(tokio::time::Instant::now()); + + if let Err(error) = peer.send_control_packet(hello) { + eprintln!("Error sending hello to peer: {}", error); + } + } + } + } +} + +pub struct RouterInner { + pub router_id: u64, + peer_interfaces: Vec, + router_control_tx: UnboundedSender, + router_data_tx: UnboundedSender, + node_tun: Arc, + selected_routing_table: RoutingTable, + fallback_routing_table: RoutingTable, + source_table: SourceTable, + router_seqno: u16, + static_routes: Vec, +} + +impl RouterInner { + pub fn new( + node_tun: Arc, + static_routes: Vec, + router_data_tx: UnboundedSender, + router_control_tx: UnboundedSender, + ) -> Result> { + let router_inner = RouterInner { + router_id: rand::thread_rng().gen(), + peer_interfaces: Vec::new(), + router_control_tx, + router_data_tx, + node_tun: node_tun, + selected_routing_table: RoutingTable::new(), + fallback_routing_table: RoutingTable::new(), + source_table: SourceTable::new(), + router_seqno: 0, + static_routes: static_routes, + }; + + Ok(router_inner) + } + fn remove_peer_interface(&mut self, peer: Peer) { + self.peer_interfaces.retain(|p| p != &peer); + } + + fn peer_by_ip(&self, peer_ip: IpAddr) -> Option { + let matching_peer = self + .peer_interfaces + .iter() + .find(|peer| peer.overlay_ip() == peer_ip); + + matching_peer.map(Clone::clone) + } + + fn send_update(&mut self, peer: &Peer, update: ControlPacket) { + // before sending an update, the source table might need to be updated + match update.body.tlv { + BabelTLV::Update { + plen, + interval: _, + seqno, + metric, + prefix, + router_id, + } => { + let source_key = SourceKey { + prefix, + plen, + router_id, + }; + + if let Some(source_entry) = self.source_table.get(&source_key) { + // if seqno of the update is greater than the seqno in the source table, update the source table + if seqno > source_entry.metric { + self.source_table + .insert(source_key, FeasibilityDistance::new(metric, seqno)); + } + // if seqno of the update is equal to the seqno in the source table, update the source table if the metric (of the update) is lower + else if seqno == source_entry.seqno && source_entry.metric > metric { + self.source_table.insert( + source_key, + FeasibilityDistance::new(metric, source_entry.seqno), + ); + } + } + // no entry for this source key, so insert it + else { + self.source_table + .insert(source_key, FeasibilityDistance::new(metric, seqno)); + } + + // send the update to the peer + if let Err(e) = peer.send_control_packet(update) { + println!("Error sending update to peer: {:?}", e); + } + } + _ => { + panic!("Control packet is not a correct Update packet"); + } + } + } + + fn propagate_static_route(&mut self) { + let mut updates = vec![]; + for sr in self.static_routes.iter() { + for peer in self.peer_interfaces.iter() { + let update = ControlPacket::new_update( + sr.plen, // static routes have plen 32 + UPDATE_INTERVAL as u16, + self.router_seqno, // updates receive the seqno of the router + peer.link_cost(), // direct connection to other peer, so the only cost is the cost towards the peer + sr.prefix, // the prefix of a static route corresponds to the TUN addr of the node + self.router_id, + ); + updates.push((peer.clone(), update)); + } + } + for (peer, update) in updates { + self.send_update(&peer, update); + } + } + + fn propagate_selected_routes(&mut self) { + let mut updates = vec![]; + for sr in self.selected_routing_table.table.iter() { + for peer in self.peer_interfaces.iter() { + + let peer_link_cost = peer.link_cost(); + + let update = ControlPacket::new_update( + sr.0.plen, + UPDATE_INTERVAL as u16, + self.router_seqno, // updates receive the seqno of the router + if sr.1.metric > u16::MAX -1 - peer_link_cost {u16::MAX - 1 } else { sr.1.metric + peer_link_cost }, // the cost of the route is the cost of the route + the cost of the link to the peer + sr.0.prefix, // the prefix of a static route corresponds to the TUN addr of the node + self.router_id, + ); + updates.push((peer.clone(), update)); + } + } + + for (peer, update) in updates { + self.send_update(&peer, update); + } + } + + fn peer_exists(&self, peer_underlay_ip: IpAddr) -> bool { + self.peer_interfaces + .iter() + .any(|peer| peer.underlay_ip() == peer_underlay_ip) + } +} diff --git a/src/routing.rs b/src/routing.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/routing_table.rs b/src/routing_table.rs new file mode 100644 index 0000000..728edf5 --- /dev/null +++ b/src/routing_table.rs @@ -0,0 +1,83 @@ +use crate::{peer::Peer, source_table::SourceKey}; +use std::{collections::BTreeMap, net::IpAddr}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct RouteKey { + pub prefix: IpAddr, + pub plen: u8, + pub neighbor: IpAddr, +} + +#[derive(Debug, Clone)] +pub struct RouteEntry { + pub source: SourceKey, + pub neighbor: Peer, + pub metric: u16, // If metric is 0xFFFF, the route has recently been retracted + pub seqno: u16, + pub next_hop: IpAddr, // This is the Peer's address + pub selected: bool, + //pub route_expiry_timer: Timer, +} + +impl RouteEntry { + /* + pub fn new( + source: SourceKey, + neighbor: Peer, + metric: u16, + seqno: u16, + next_hop: IpAddr, + selected: bool, + ) -> Self { + Self { + source, + neighbor, + metric, + seqno, + next_hop, + selected, + } + } + + pub fn retracted(&mut self) { + self.metric = 0xFFFF; + } + + pub fn is_retracted(&self) -> bool { + self.metric == 0xFFFF + } + */ + pub fn update_metric(&mut self, metric: u16) { + self.metric = metric; + } + + pub fn update_seqno(&mut self, seqno: u16) { + self.seqno = seqno; + } + + pub fn update_router_id(&mut self, router_id: u64) { + self.source.router_id = router_id; + } +} + +#[derive(Debug, Clone)] +pub struct RoutingTable { + pub table: BTreeMap, +} + +impl RoutingTable { + pub fn new() -> Self { + Self { + table: BTreeMap::new(), + } + } + + pub fn insert(&mut self, key: RouteKey, entry: RouteEntry) { + self.table.insert(key, entry); + //println!("Added route to routing table: {:?}", self.table); + } + + pub fn remove(&mut self, key: &RouteKey) -> Option { + self.table.remove(key) + } +} diff --git a/src/source_table.rs b/src/source_table.rs new file mode 100644 index 0000000..2486a0d --- /dev/null +++ b/src/source_table.rs @@ -0,0 +1,121 @@ +use std::{collections::HashMap, net::IpAddr}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] +pub struct SourceKey { + pub prefix: IpAddr, + pub plen: u8, + pub router_id: u64, // We temporarily use 100 for all router IDs +} + +#[derive(Debug, Clone, Copy)] +pub struct FeasibilityDistance { + pub metric: u16, + pub seqno: u16, +} + +impl FeasibilityDistance { + pub fn new(metric: u16, seqno: u16) -> Self { + FeasibilityDistance { metric, seqno } + } +} + +// Store (prefix, plen, router_id) -> feasibility distance mapping + +#[derive(Debug)] +pub struct SourceTable { + pub table: HashMap, +} + +impl SourceTable { + pub fn new() -> Self { + Self { + table: HashMap::new(), + } + } + + pub fn insert(&mut self, key: SourceKey, feas_dist: FeasibilityDistance) { + self.table.insert(key, feas_dist); + } + + pub fn remove(&mut self, key: &SourceKey) { + self.table.remove(key); + } + + pub fn get(&self, key: &SourceKey) -> Option<&FeasibilityDistance> { + self.table.get(key) + } + + // pub fn update(&mut self, update: &ControlStruct) { + // match update.control_packet.body.tlv { + // BabelTLV::Update { + // plen, + // interval, + // seqno, + // metric, + // prefix, + // router_id, + // } => { + // // first check if the update is feasible + // if !self.is_feasible(update) { + // return; + // } + + // let key = SourceKey { + // prefix: prefix, + // plen: plen, + // router_id: router_id, + // }; + + // let new_distance = FeasibilityDistance(metric, seqno); + // let old_distance = self.table.get(&key).cloned(); + // match old_distance { + // Some(old_distance) => { + // if new_distance.0 < old_distance.0 { + // self.table + // .insert(key, FeasibilityDistance(new_distance.0, new_distance.1)); + // } + // } + // None => { + // self.table + // .insert(key, FeasibilityDistance(new_distance.0, new_distance.1)); + // } + // } + // } + // _ => { + // panic!("not an update"); + // } + // } + // } + + // pub fn is_feasible(&self, update: &ControlStruct) -> bool { + // match update.control_packet.body.tlv { + // BabelTLV::Update { + // plen, + // interval: _, + // seqno, + // metric, + // prefix, + // router_id, + // } => { + // let key = SourceKey { + // prefix: prefix, + // plen: plen, + // router_id: router_id, + // }; + + // match self.table.get(&key) { + // Some(&source_entry) => { + // let metric_2 = source_entry.0; + // let seqno_2 = source_entry.1; + + // seqno > seqno_2 || (seqno == seqno_2 && metric < metric_2) + // } + // None => true, + // } + // } + // _ => { + // panic!("not an update"); + // } + // } + // } +}