From 20fccac2b7a623b75dc892d25ccc452a30f08332 Mon Sep 17 00:00:00 2001 From: Scott Powell Date: Tue, 14 Jan 2025 06:43:03 +1100 Subject: [PATCH] * refactored the hasSeen(Packet) stuff. --- examples/ping_client/main.cpp | 15 +++----- examples/ping_server/main.cpp | 10 +++--- examples/simple_repeater/main.cpp | 13 ++----- examples/simple_secure_chat/main.cpp | 14 +++----- examples/test_admin/main.cpp | 15 +++----- src/Mesh.cpp | 20 ++++++----- src/Mesh.h | 13 +++++-- src/MeshTables.h | 16 --------- src/helpers/SimpleMeshTables.h | 52 ++++++++++++---------------- src/helpers/SimpleSeenTable.h | 33 ------------------ 10 files changed, 68 insertions(+), 133 deletions(-) delete mode 100644 src/MeshTables.h delete mode 100644 src/helpers/SimpleSeenTable.h diff --git a/examples/ping_client/main.cpp b/examples/ping_client/main.cpp index 7d7305ae..8259cf58 100644 --- a/examples/ping_client/main.cpp +++ b/examples/ping_client/main.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include /* ------------------------------ Config -------------------------------- */ @@ -21,7 +21,6 @@ /* ------------------------------ Code -------------------------------- */ class MyMesh : public mesh::Mesh { - SimpleSeenTable* _table; uint32_t last_advert_timestamp = 0; mesh::Identity server_id; uint8_t server_secret[PUB_KEY_SIZE]; @@ -54,8 +53,6 @@ protected: void onPeerDataRecv(mesh::Packet* packet, uint8_t type, int sender_idx, uint8_t* data, size_t len) override { if (type == PAYLOAD_TYPE_RESPONSE) { - if (_table->hasSeenPacket(packet)) return; - Serial.println("Received PING Reply!"); if (packet->isRouteFlood()) { @@ -67,8 +64,6 @@ protected: } void onPeerPathRecv(mesh::Packet* packet, int sender_idx, uint8_t* path, uint8_t path_len, uint8_t extra_type, uint8_t* extra, uint8_t extra_len) override { - if (_table->hasSeenPacket(packet)) return; - // must be from server_id Serial.printf("PATH to server, path_len=%d\n", (uint32_t) path_len); @@ -86,8 +81,8 @@ protected: } public: - MyMesh(mesh::Radio& radio, mesh::RNG& rng, mesh::RTCClock& rtc, SimpleSeenTable& table) - : mesh::Mesh(radio, *new ArduinoMillis(), rng, rtc, *new StaticPoolPacketManager(16)), _table(&table) + MyMesh(mesh::Radio& radio, mesh::RNG& rng, mesh::RTCClock& rtc, mesh::MeshTables& tables) + : mesh::Mesh(radio, *new ArduinoMillis(), rng, rtc, *new StaticPoolPacketManager(16), tables) { } @@ -110,9 +105,9 @@ public: SPIClass spi; StdRNG fast_rng; -SimpleSeenTable table; +SimpleMeshTables tables; SX1262 radio = new Module(P_LORA_NSS, P_LORA_DIO_1, P_LORA_RESET, P_LORA_BUSY, spi); -MyMesh the_mesh(*new RadioLibWrapper(radio, board), fast_rng, *new VolatileRTCClock(), table); +MyMesh the_mesh(*new RadioLibWrapper(radio, board), fast_rng, *new VolatileRTCClock(), tables); unsigned long nextPing; void halt() { diff --git a/examples/ping_server/main.cpp b/examples/ping_server/main.cpp index 4db9d238..61fef4c3 100644 --- a/examples/ping_server/main.cpp +++ b/examples/ping_server/main.cpp @@ -7,6 +7,7 @@ #include #include #include +#include /* ------------------------------ Config -------------------------------- */ @@ -56,7 +57,7 @@ protected: auto client = putClient(sender); // add to known clients (if not already known) if (client == NULL || timestamp <= client->last_timestamp) { - return; // FATAL: client table is full -OR- replay attack -OR- have seen this packet before + return; // FATAL: client table is full -OR- replay attack } client->last_timestamp = timestamp; @@ -123,8 +124,8 @@ protected: } public: - MyMesh(mesh::Radio& radio, mesh::MillisecondClock& ms, mesh::RNG& rng, mesh::RTCClock& rtc) - : mesh::Mesh(radio, ms, rng, rtc, *new StaticPoolPacketManager(16)) + MyMesh(mesh::Radio& radio, mesh::MillisecondClock& ms, mesh::RNG& rng, mesh::RTCClock& rtc, mesh::MeshTables& tables) + : mesh::Mesh(radio, ms, rng, rtc, *new StaticPoolPacketManager(16), tables) { num_clients = 0; } @@ -132,8 +133,9 @@ public: SPIClass spi; StdRNG fast_rng; +SimpleMeshTables tables; SX1262 radio = new Module(P_LORA_NSS, P_LORA_DIO_1, P_LORA_RESET, P_LORA_BUSY, spi); -MyMesh the_mesh(*new RadioLibWrapper(radio, board), *new ArduinoMillis(), fast_rng, *new VolatileRTCClock()); +MyMesh the_mesh(*new RadioLibWrapper(radio, board), *new ArduinoMillis(), fast_rng, *new VolatileRTCClock(), tables); unsigned long nextAnnounce; diff --git a/examples/simple_repeater/main.cpp b/examples/simple_repeater/main.cpp index b2d87696..4bd4361e 100644 --- a/examples/simple_repeater/main.cpp +++ b/examples/simple_repeater/main.cpp @@ -53,7 +53,6 @@ struct ClientInfo { class MyMesh : public mesh::Mesh { RadioLibWrapper* my_radio; - mesh::MeshTables* _tables; float airtime_factor; uint8_t reply_data[MAX_PACKET_PAYLOAD]; int num_clients; @@ -145,12 +144,6 @@ protected: } bool allowPacketForward(const mesh::Packet* packet) override { - uint8_t hash[MAX_HASH_SIZE]; - packet->calculatePacketHash(hash); - - if (_tables->hasForwarded(hash)) return false; // has already been forwarded - - _tables->setHasForwarded(hash); // mark packet as forwarded return true; // Yes, allow packet to be forwarded } @@ -162,7 +155,7 @@ protected: if (memcmp(&data[4], ADMIN_PASSWORD, 8) == 0) { // check for valid password auto client = putClient(sender); // add to known clients (if not already known) if (client == NULL || timestamp <= client->last_timestamp) { - return; // FATAL: client table is full -OR- replay attack -OR- have seen this packet before + return; // FATAL: client table is full -OR- replay attack } client->last_timestamp = timestamp; @@ -222,7 +215,7 @@ protected: uint32_t timestamp; memcpy(×tamp, data, 4); - if (timestamp > client->last_timestamp) { // prevent replay attacks AND receiving via multiple paths + if (timestamp > client->last_timestamp) { // prevent replay attacks int reply_len = handleRequest(client, &data[4], len - 4); if (reply_len == 0) return; // invalid command @@ -267,7 +260,7 @@ protected: public: MyMesh(RadioLibWrapper& radio, mesh::MillisecondClock& ms, mesh::RNG& rng, mesh::RTCClock& rtc, mesh::MeshTables& tables) - : mesh::Mesh(radio, ms, rng, rtc, *new StaticPoolPacketManager(32)), _tables(&tables) + : mesh::Mesh(radio, ms, rng, rtc, *new StaticPoolPacketManager(32), tables) { my_radio = &radio; airtime_factor = 5.0; // 1/6th diff --git a/examples/simple_secure_chat/main.cpp b/examples/simple_secure_chat/main.cpp index 2cbbdabc..f0e9285d 100644 --- a/examples/simple_secure_chat/main.cpp +++ b/examples/simple_secure_chat/main.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include /* ---------------------------------- CONFIGURATION ------------------------------------- */ @@ -52,7 +52,6 @@ struct ContactInfo { class MyMesh : public mesh::Mesh { public: - SimpleSeenTable* _table; ContactInfo contacts[MAX_CONTACTS]; int num_contacts; @@ -111,7 +110,6 @@ protected: // NOTE: this is a 'first packet wins' impl. When receiving from multiple paths, the first to arrive wins. // For flood mode, the path may not be the 'best' in terms of hops. // FUTURE: could send back multiple paths, using createPathReturn(), and let sender choose which to use(?) - if (_table->hasSeenPacket(packet)) return; int i = matching_peer_indexes[sender_idx]; if (i < 0 && i >= num_contacts) { @@ -152,8 +150,6 @@ protected: } void onPeerPathRecv(mesh::Packet* packet, int sender_idx, uint8_t* path, uint8_t path_len, uint8_t extra_type, uint8_t* extra, uint8_t extra_len) override { - if (_table->hasSeenPacket(packet)) return; - int i = matching_peer_indexes[sender_idx]; if (i < 0 && i >= num_contacts) { MESH_DEBUG_PRINTLN("onPeerPathRecv: Invalid sender idx: %d", i); @@ -199,8 +195,8 @@ protected: public: uint32_t expected_ack_crc; - MyMesh(mesh::Radio& radio, mesh::RNG& rng, mesh::RTCClock& rtc, SimpleSeenTable& table) - : mesh::Mesh(radio, *new ArduinoMillis(), rng, rtc, *new StaticPoolPacketManager(16)), _table(&table) + MyMesh(mesh::Radio& radio, mesh::RNG& rng, mesh::RTCClock& rtc, SimpleMeshTables& tables) + : mesh::Mesh(radio, *new ArduinoMillis(), rng, rtc, *new StaticPoolPacketManager(16), tables) { num_contacts = 0; } @@ -233,9 +229,9 @@ public: SPIClass spi; StdRNG fast_rng; -SimpleSeenTable table; +SimpleMeshTables tables; SX1262 radio = new Module(P_LORA_NSS, P_LORA_DIO_1, P_LORA_RESET, P_LORA_BUSY, spi); -MyMesh the_mesh(*new RadioLibWrapper(radio, board), fast_rng, *new VolatileRTCClock(), table); +MyMesh the_mesh(*new RadioLibWrapper(radio, board), fast_rng, *new VolatileRTCClock(), tables); void halt() { while (1) ; diff --git a/examples/test_admin/main.cpp b/examples/test_admin/main.cpp index 4816150f..3cc964ae 100644 --- a/examples/test_admin/main.cpp +++ b/examples/test_admin/main.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include /* ---------------------------------- CONFIGURATION ------------------------------------- */ @@ -41,7 +41,6 @@ struct RepeaterStats { }; class MyMesh : public mesh::Mesh { - SimpleSeenTable* _table; uint32_t last_advert_timestamp = 0; mesh::Identity server_id; uint8_t server_secret[PUB_KEY_SIZE]; @@ -105,8 +104,6 @@ protected: void onPeerDataRecv(mesh::Packet* packet, uint8_t type, int sender_idx, uint8_t* data, size_t len) override { if (type == PAYLOAD_TYPE_RESPONSE) { - if (_table->hasSeenPacket(packet)) return; - handleResponse(data, len); if (packet->isRouteFlood()) { @@ -118,8 +115,6 @@ protected: } void onPeerPathRecv(mesh::Packet* packet, int sender_idx, uint8_t* path, uint8_t path_len, uint8_t extra_type, uint8_t* extra, uint8_t extra_len) override { - if (_table->hasSeenPacket(packet)) return; - // must be from server_id Serial.printf("PATH to repeater, path_len=%d\n", (uint32_t) path_len); @@ -137,8 +132,8 @@ protected: } public: - MyMesh(mesh::Radio& radio, mesh::RNG& rng, mesh::RTCClock& rtc, SimpleSeenTable& table) - : mesh::Mesh(radio, *new ArduinoMillis(), rng, rtc, *new StaticPoolPacketManager(16)), _table(&table) + MyMesh(mesh::Radio& radio, mesh::RNG& rng, mesh::RTCClock& rtc, mesh::MeshTables& tables) + : mesh::Mesh(radio, *new ArduinoMillis(), rng, rtc, *new StaticPoolPacketManager(16), tables) { } @@ -206,14 +201,14 @@ public: }; StdRNG fast_rng; -SimpleSeenTable table; +SimpleMeshTables tables; #if defined(P_LORA_SCLK) SPIClass spi; CustomSX1262 radio = new Module(P_LORA_NSS, P_LORA_DIO_1, P_LORA_RESET, P_LORA_BUSY, spi); #else CustomSX1262 radio = new Module(P_LORA_NSS, P_LORA_DIO_1, P_LORA_RESET, P_LORA_BUSY); #endif -MyMesh the_mesh(*new CustomSX1262Wrapper(radio, board), fast_rng, *new VolatileRTCClock(), table); +MyMesh the_mesh(*new CustomSX1262Wrapper(radio, board), fast_rng, *new VolatileRTCClock(), tables); void halt() { while (1) ; diff --git a/src/Mesh.cpp b/src/Mesh.cpp index 7ff55e9f..329eaf5d 100644 --- a/src/Mesh.cpp +++ b/src/Mesh.cpp @@ -36,6 +36,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { if (pkt->isRouteDirect() && pkt->path_len >= PATH_HASH_SIZE) { if (self_id.isHashMatch(pkt->path) && allowPacketForward(pkt)) { + if (_tables->hasSeen(pkt)) return ACTION_RELEASE; // don't retransmit! + // remove our hash from 'path', then re-broadcast pkt->path_len -= PATH_HASH_SIZE; memcpy(pkt->path, &pkt->path[PATH_HASH_SIZE], pkt->path_len); @@ -53,7 +55,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { memcpy(&ack_crc, &pkt->payload[i], 4); i += 4; if (i > pkt->payload_len) { MESH_DEBUG_PRINTLN("Mesh::onRecvPacket(): incomplete ACK packet"); - } else { + } else if (!_tables->hasSeen(pkt)) { onAckRecv(pkt, ack_crc); action = routeRecvPacket(pkt); } @@ -70,7 +72,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("Mesh::onRecvPacket(): incomplete data packet"); - } else { + } else if (!_tables->hasSeen(pkt)) { if (self_id.isHashMatch(&dest_hash)) { // scan contacts DB, for all matching hashes of 'src_hash' (max 4 matches supported ATM) int num = searchPeersByHash(&src_hash); @@ -115,7 +117,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("Mesh::onRecvPacket(): incomplete data packet"); - } else { + } else if (!_tables->hasSeen(pkt)) { if (self_id.isHashMatch(&dest_hash)) { Identity sender(sender_pub_key); @@ -141,7 +143,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("Mesh::onRecvPacket(): incomplete data packet"); - } else { + } else if (!_tables->hasSeen(pkt)) { // scan channels DB, for all matching hashes of 'channel_hash' (max 2 matches supported ATM) GroupChannel channels[2]; int num = searchChannelsByHash(&channel_hash, channels, 2); @@ -170,7 +172,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { if (i > pkt->payload_len) { MESH_DEBUG_PRINTLN("Mesh::onRecvPacket(): incomplete advertisement packet"); - } else { + } else if (!_tables->hasSeen(pkt)) { uint8_t* app_data = &pkt->payload[i]; int app_data_len = pkt->payload_len - i; if (app_data_len > MAX_ADVERT_DATA_SIZE) { app_data_len = MAX_ADVERT_DATA_SIZE; } @@ -198,7 +200,7 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { } default: MESH_DEBUG_PRINTLN("Mesh::onRecvPacket(): unknown payload type, header: %d", (int) pkt->header); - action = routeRecvPacket(pkt); + // Don't flood route unknown packet types! action = routeRecvPacket(pkt); break; } return action; @@ -382,7 +384,7 @@ void Mesh::sendFlood(Packet* packet, uint32_t delay_millis) { packet->header |= ROUTE_TYPE_FLOOD; packet->path_len = 0; - allowPacketForward(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us uint8_t pri; if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) { @@ -401,7 +403,7 @@ void Mesh::sendDirect(Packet* packet, const uint8_t* path, uint8_t path_len, uin memcpy(packet->path, path, packet->path_len = path_len); - allowPacketForward(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us sendPacket(packet, 0, delay_millis); } @@ -412,7 +414,7 @@ void Mesh::sendZeroHop(Packet* packet, uint32_t delay_millis) { packet->path_len = 0; // path_len of zero means Zero Hop - allowPacketForward(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us sendPacket(packet, 0, delay_millis); } diff --git a/src/Mesh.h b/src/Mesh.h index 746179d5..06a47790 100644 --- a/src/Mesh.h +++ b/src/Mesh.h @@ -26,6 +26,14 @@ public: uint8_t secret[PUB_KEY_SIZE]; }; +/** + * An abstraction of the data tables needed to be maintained +*/ +class MeshTables { +public: + virtual bool hasSeen(const Packet* packet) = 0; +}; + /** * \brief The next layer in the basic Dispatcher task, Mesh recognises the particular Payload TYPES, * and provides virtual methods for sub-classes on handling incoming, and also preparing outbound Packets. @@ -33,6 +41,7 @@ public: class Mesh : public Dispatcher { RTCClock* _rtc; RNG* _rng; + MeshTables* _tables; protected: DispatcherAction onRecvPacket(Packet* pkt) override; @@ -117,8 +126,8 @@ protected: */ virtual void onAckRecv(Packet* packet, uint32_t ack_crc) { } - Mesh(Radio& radio, MillisecondClock& ms, RNG& rng, RTCClock& rtc, PacketManager& mgr) - : Dispatcher(radio, ms, mgr), _rng(&rng), _rtc(&rtc) + Mesh(Radio& radio, MillisecondClock& ms, RNG& rng, RTCClock& rtc, PacketManager& mgr, MeshTables& tables) + : Dispatcher(radio, ms, mgr), _rng(&rng), _rtc(&rtc), _tables(&tables) { } diff --git a/src/MeshTables.h b/src/MeshTables.h deleted file mode 100644 index 24553e16..00000000 --- a/src/MeshTables.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include - -namespace mesh { - -/** - * An abstraction of the data tables needed to be maintained, for the routing engine. -*/ -class MeshTables { -public: - virtual bool hasForwarded(const uint8_t* packet_hash) const = 0; - virtual void setHasForwarded(const uint8_t* packet_hash) = 0; -}; - -} \ No newline at end of file diff --git a/src/helpers/SimpleMeshTables.h b/src/helpers/SimpleMeshTables.h index 32073957..c98704a7 100644 --- a/src/helpers/SimpleMeshTables.h +++ b/src/helpers/SimpleMeshTables.h @@ -1,56 +1,48 @@ #pragma once -#include +#include #ifdef ESP32 #include #endif -#define MAX_PACKET_HASHES 64 +#define MAX_PACKET_HASHES 128 class SimpleMeshTables : public mesh::MeshTables { - uint8_t _fwd_hashes[MAX_PACKET_HASHES*MAX_HASH_SIZE]; - int _next_fwd_idx; - - int lookupHashIndex(const uint8_t* hash) const { - const uint8_t* sp = _fwd_hashes; - for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { - if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) return i; - } - return -1; - } + uint8_t _hashes[MAX_PACKET_HASHES*MAX_HASH_SIZE]; + int _next_idx; public: SimpleMeshTables() { - memset(_fwd_hashes, 0, sizeof(_fwd_hashes)); - _next_fwd_idx = 0; + memset(_hashes, 0, sizeof(_hashes)); + _next_idx = 0; } #ifdef ESP32 void restoreFrom(File f) { - f.read(_fwd_hashes, sizeof(_fwd_hashes)); - f.read((uint8_t *) &_next_fwd_idx, sizeof(_next_fwd_idx)); + f.read(_hashes, sizeof(_hashes)); + f.read((uint8_t *) &_next_idx, sizeof(_next_idx)); } void saveTo(File f) { - f.write(_fwd_hashes, sizeof(_fwd_hashes)); - f.write((const uint8_t *) &_next_fwd_idx, sizeof(_next_fwd_idx)); + f.write(_hashes, sizeof(_hashes)); + f.write((const uint8_t *) &_next_idx, sizeof(_next_idx)); } #endif - bool hasForwarded(const uint8_t* packet_hash) const override { - int i = lookupHashIndex(packet_hash); - return i >= 0; - } + bool hasSeen(const mesh::Packet* packet) override { + uint8_t hash[MAX_HASH_SIZE]; + packet->calculatePacketHash(hash); - void setHasForwarded(const uint8_t* packet_hash) override { - int i = lookupHashIndex(packet_hash); - if (i >= 0) { - // already in table - } else { - memcpy(&_fwd_hashes[_next_fwd_idx*MAX_HASH_SIZE], packet_hash, MAX_HASH_SIZE); - - _next_fwd_idx = (_next_fwd_idx + 1) % MAX_PACKET_HASHES; // cyclic table + const uint8_t* sp = _hashes; + for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { + if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) return true; } + + memcpy(&_hashes[_next_idx*MAX_HASH_SIZE], hash, MAX_HASH_SIZE); + _next_idx = (_next_idx + 1) % MAX_PACKET_HASHES; // cyclic table + + return false; } + }; diff --git a/src/helpers/SimpleSeenTable.h b/src/helpers/SimpleSeenTable.h deleted file mode 100644 index a33824ec..00000000 --- a/src/helpers/SimpleSeenTable.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#include -#include - -#define MAX_PACKET_HASHES 64 - -class SimpleSeenTable { - uint8_t _hashes[MAX_PACKET_HASHES*MAX_HASH_SIZE]; - int _next_idx; - -public: - SimpleSeenTable() { - memset(_hashes, 0, sizeof(_hashes)); - _next_idx = 0; - } - - bool hasSeenPacket(const mesh::Packet* packet) { - uint8_t hash[MAX_HASH_SIZE]; - packet->calculatePacketHash(hash); - - const uint8_t* sp = _hashes; - for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { - if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) return true; - } - - memcpy(&_hashes[_next_idx*MAX_HASH_SIZE], hash, MAX_HASH_SIZE); - _next_idx = (_next_idx + 1) % MAX_PACKET_HASHES; // cyclic table - - return false; - } - -};