diff --git a/platformio.ini b/platformio.ini index b08854e5..47cc0ab8 100644 --- a/platformio.ini +++ b/platformio.ini @@ -165,5 +165,6 @@ test_build_src = yes build_src_filter = -<*> +<../src/Utils.cpp> + +<../src/Packet.cpp> lib_deps = google/googletest @ 1.17.0 diff --git a/src/Mesh.cpp b/src/Mesh.cpp index e9b92262..c11f37ca 100644 --- a/src/Mesh.cpp +++ b/src/Mesh.cpp @@ -55,7 +55,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint16_t offset = (uint16_t)pkt->path_len << path_sz; if (offset >= len) { // TRACE has reached end of given path onTraceRecv(pkt, trace_tag, auth_code, flags, pkt->path, &pkt->payload[i], len); - } else if (self_id.isHashMatch(&pkt->payload[i + offset], 1 << path_sz) && allowPacketForward(pkt) && !_tables->hasSeen(pkt)) { + } else if (self_id.isHashMatch(&pkt->payload[i + offset], 1 << path_sz) && allowPacketForward(pkt) && !_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); // append SNR (Not hash!) pkt->path[pkt->path_len++] = (int8_t) (pkt->getSNR()*4); @@ -89,14 +90,16 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { if (pkt->getPayloadType() == PAYLOAD_TYPE_MULTIPART) { return forwardMultipartDirect(pkt); } else if (pkt->getPayloadType() == PAYLOAD_TYPE_ACK) { - if (!_tables->hasSeen(pkt)) { // don't retransmit! + if (!_tables->wasSeen(pkt)) { // don't retransmit! + _tables->markSeen(pkt); removeSelfFromPath(pkt); routeDirectRecvAcks(pkt, 0); } return ACTION_RELEASE; } - if (!_tables->hasSeen(pkt)) { + if (!_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); removeSelfFromPath(pkt); uint32_t d = getDirectRetransmitDelay(pkt); @@ -117,7 +120,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { memcpy(&ack_crc, &pkt->payload[i], 4); i += 4; if (i > pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete ACK packet", getLogDateTime()); - } else if (!_tables->hasSeen(pkt)) { + } else if (!_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); onAckRecv(pkt, ack_crc); action = routeRecvPacket(pkt); } @@ -134,7 +138,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + CIPHER_MAC_SIZE >= pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime()); - } else if (!_tables->hasSeen(pkt)) { + } else if (!_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); // NOTE: this is a 'first packet wins' impl. When receiving from multiple paths, the first to arrive wins. // For flood mode, the path may not be the 'best' in terms of hops. // FUTURE: could send back multiple paths, using createPathReturn(), and let sender choose which to use(?) @@ -197,7 +202,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime()); - } else if (!_tables->hasSeen(pkt)) { + } else if (!_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); if (self_id.isHashMatch(&dest_hash)) { Identity sender(sender_pub_key); @@ -224,7 +230,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { uint8_t* macAndData = &pkt->payload[i]; // MAC + encrypted data if (i + 2 >= pkt->payload_len) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete data packet", getLogDateTime()); - } else if (!_tables->hasSeen(pkt)) { + } else if (!_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); // scan channels DB, for all matching hashes of 'channel_hash' (max 4 matches supported ATM) GroupChannel channels[4]; int num = searchChannelsByHash(&channel_hash, channels, 4); @@ -255,7 +262,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): incomplete advertisement packet", getLogDateTime()); } else if (self_id.matches(id.pub_key)) { MESH_DEBUG_PRINTLN("%s Mesh::onRecvPacket(): receiving SELF advert packet", getLogDateTime()); - } else if (!_tables->hasSeen(pkt)) { + } else if (!_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); uint8_t* app_data = &pkt->payload[i]; int app_data_len = pkt->payload_len - i; if (app_data_len > MAX_ADVERT_DATA_SIZE) { app_data_len = MAX_ADVERT_DATA_SIZE; } @@ -282,7 +290,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { break; } case PAYLOAD_TYPE_RAW_CUSTOM: { - if (pkt->isRouteDirect() && !_tables->hasSeen(pkt)) { + if (pkt->isRouteDirect() && !_tables->wasSeen(pkt)) { + _tables->markSeen(pkt); onRawDataRecv(pkt); //action = routeRecvPacket(pkt); don't flood route these (yet) } @@ -300,7 +309,8 @@ DispatcherAction Mesh::onRecvPacket(Packet* pkt) { tmp.payload_len = pkt->payload_len - 1; memcpy(tmp.payload, &pkt->payload[1], tmp.payload_len); - if (!_tables->hasSeen(&tmp)) { + if (!_tables->wasSeen(&tmp)) { + _tables->markSeen(&tmp); uint32_t ack_crc; memcpy(&ack_crc, tmp.payload, 4); @@ -357,7 +367,8 @@ DispatcherAction Mesh::forwardMultipartDirect(Packet* pkt) { tmp.payload_len = pkt->payload_len - 1; memcpy(tmp.payload, &pkt->payload[1], tmp.payload_len); - if (!_tables->hasSeen(&tmp)) { // don't retransmit! + if (!_tables->wasSeen(&tmp)) { // don't retransmit! + _tables->markSeen(&tmp); removeSelfFromPath(&tmp); routeDirectRecvAcks(&tmp, ((uint32_t)remaining + 1) * 300); // expect multipart ACKs 300ms apart (x2) } @@ -637,7 +648,7 @@ void Mesh::sendFlood(Packet* packet, uint32_t delay_millis, uint8_t path_hash_si packet->header |= ROUTE_TYPE_FLOOD; packet->setPathHashSizeAndCount(path_hash_size, 0); - _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us uint8_t pri; if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) { @@ -666,7 +677,7 @@ void Mesh::sendFlood(Packet* packet, uint16_t* transport_codes, uint32_t delay_m packet->transport_codes[1] = transport_codes[1]; packet->setPathHashSizeAndCount(path_hash_size, 0); - _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us uint8_t pri; if (packet->getPayloadType() == PAYLOAD_TYPE_PATH) { @@ -699,7 +710,7 @@ void Mesh::sendDirect(Packet* packet, const uint8_t* path, uint8_t path_len, uin pri = 0; } } - _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us sendPacket(packet, pri, delay_millis); } @@ -709,7 +720,7 @@ void Mesh::sendZeroHop(Packet* packet, uint32_t delay_millis) { packet->path_len = 0; // path_len of zero means Zero Hop - _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us sendPacket(packet, 0, delay_millis); } @@ -722,7 +733,7 @@ void Mesh::sendZeroHop(Packet* packet, uint16_t* transport_codes, uint32_t delay packet->path_len = 0; // path_len of zero means Zero Hop - _tables->hasSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us + _tables->markSeen(packet); // mark this packet as already sent in case it is rebroadcast back to us sendPacket(packet, 0, delay_millis); } diff --git a/src/Mesh.h b/src/Mesh.h index 932541db..49a299a6 100644 --- a/src/Mesh.h +++ b/src/Mesh.h @@ -15,8 +15,9 @@ public: */ class MeshTables { public: - virtual bool hasSeen(const Packet* packet) = 0; - virtual void clear(const Packet* packet) = 0; // remove this packet hash from table + virtual bool wasSeen(const Packet* packet) = 0; + virtual void markSeen(const Packet* packet) = 0; + virtual void clear(const Packet* packet) = 0; // remove this packet hash from table }; /** diff --git a/src/helpers/SimpleMeshTables.h b/src/helpers/SimpleMeshTables.h index 0b79cfb4..956f36fa 100644 --- a/src/helpers/SimpleMeshTables.h +++ b/src/helpers/SimpleMeshTables.h @@ -31,27 +31,31 @@ public: } #endif - bool hasSeen(const mesh::Packet* packet) override { + bool wasSeen(const mesh::Packet* packet) override { uint8_t hash[MAX_HASH_SIZE]; packet->calculatePacketHash(hash); const uint8_t* sp = _hashes; for (int i = 0; i < MAX_PACKET_HASHES; i++, sp += MAX_HASH_SIZE) { - if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) { + if (memcmp(hash, sp, MAX_HASH_SIZE) == 0) { if (packet->isRouteDirect()) { - _direct_dups++; // keep some stats + _direct_dups++; } else { _flood_dups++; } return true; } } - - memcpy(&_hashes[_next_idx*MAX_HASH_SIZE], hash, MAX_HASH_SIZE); - _next_idx = (_next_idx + 1) % MAX_PACKET_HASHES; // cyclic table return false; } + void markSeen(const mesh::Packet* packet) override { + uint8_t hash[MAX_HASH_SIZE]; + packet->calculatePacketHash(hash); + memcpy(&_hashes[_next_idx * MAX_HASH_SIZE], hash, MAX_HASH_SIZE); + _next_idx = (_next_idx + 1) % MAX_PACKET_HASHES; + } + void clear(const mesh::Packet* packet) override { uint8_t hash[MAX_HASH_SIZE]; packet->calculatePacketHash(hash); diff --git a/src/helpers/bridges/BridgeBase.cpp b/src/helpers/bridges/BridgeBase.cpp index d2e2e5e0..8093d3cb 100644 --- a/src/helpers/bridges/BridgeBase.cpp +++ b/src/helpers/bridges/BridgeBase.cpp @@ -39,7 +39,8 @@ void BridgeBase::handleReceivedPacket(mesh::Packet *packet) { return; } - if (!_seen_packets.hasSeen(packet)) { + if (!_seen_packets.wasSeen(packet)) { + _seen_packets.markSeen(packet); // bridge_delay provides a buffer to prevent immediate processing conflicts in the mesh network. _mgr->queueInbound(packet, millis() + _prefs->bridge_delay); } else { diff --git a/src/helpers/bridges/BridgeBase.h b/src/helpers/bridges/BridgeBase.h index 04c1564b..8bbe6466 100644 --- a/src/helpers/bridges/BridgeBase.h +++ b/src/helpers/bridges/BridgeBase.h @@ -110,7 +110,7 @@ protected: * @brief Common packet handling for received packets * * Implements the standard pattern used by all bridges: - * - Check if packet was seen before using _seen_packets.hasSeen() + * - Check if packet was seen before using _seen_packets.wasSeen() * - Queue packet for mesh processing if not seen before * - Free packet if already seen to prevent duplicates * diff --git a/src/helpers/bridges/ESPNowBridge.cpp b/src/helpers/bridges/ESPNowBridge.cpp index 808e9df4..3c094e7c 100644 --- a/src/helpers/bridges/ESPNowBridge.cpp +++ b/src/helpers/bridges/ESPNowBridge.cpp @@ -167,7 +167,8 @@ void ESPNowBridge::sendPacket(mesh::Packet *packet) { return; } - if (!_seen_packets.hasSeen(packet)) { + if (!_seen_packets.wasSeen(packet)) { + _seen_packets.markSeen(packet); // Create a temporary buffer just for size calculation and reuse for actual writing uint8_t sizingBuffer[MAX_PAYLOAD_SIZE]; uint16_t meshPacketLen = packet->writeTo(sizingBuffer); diff --git a/src/helpers/bridges/RS232Bridge.cpp b/src/helpers/bridges/RS232Bridge.cpp index 0024f6f2..f719d342 100644 --- a/src/helpers/bridges/RS232Bridge.cpp +++ b/src/helpers/bridges/RS232Bridge.cpp @@ -115,7 +115,8 @@ void RS232Bridge::sendPacket(mesh::Packet *packet) { return; } - if (!_seen_packets.hasSeen(packet)) { + if (!_seen_packets.wasSeen(packet)) { + _seen_packets.markSeen(packet); uint8_t buffer[MAX_SERIAL_PACKET_SIZE]; uint16_t len = packet->writeTo(buffer + 4); diff --git a/test/mocks/SHA256.h b/test/mocks/SHA256.h index b6e551a0..0cefe5c6 100644 --- a/test/mocks/SHA256.h +++ b/test/mocks/SHA256.h @@ -3,12 +3,33 @@ #include #include -// Mock SHA256 class for testing -// Provides minimal interface to allow Utils.cpp to compile +// Mock SHA256 for native testing — deterministic but not cryptographic. +// finalize() writes real (non-garbage) output so calculatePacketHash() produces +// distinguishable results for packets with different payloads. +#include + class SHA256 { + uint8_t _state[32]; + size_t _len; public: - void update(const uint8_t* data, size_t len) {} - void finalize(uint8_t* hash, size_t hashLen) {} + SHA256() : _len(0) { memset(_state, 0, sizeof(_state)); } + + void update(const void* data, size_t len) { + const uint8_t* bytes = static_cast(data); + for (size_t i = 0; i < len; i++) { + uint8_t b = bytes[i]; + _state[_len % 32] ^= b; + _state[(_len + 1) % 32] += (uint8_t)((b >> 1) | (b << 7)); + _len++; + } + } + + void finalize(uint8_t* hash, size_t hashLen) { + for (size_t i = 0; i < hashLen; i++) { + hash[i] = _state[i % 32]; + } + } + void resetHMAC(const uint8_t* key, size_t keyLen) {} void finalizeHMAC(const uint8_t* key, size_t keyLen, uint8_t* hash, size_t hashLen) {} }; diff --git a/test/test_mesh_tables/test_simple_mesh_tables.cpp b/test/test_mesh_tables/test_simple_mesh_tables.cpp new file mode 100644 index 00000000..46b477d9 --- /dev/null +++ b/test/test_mesh_tables/test_simple_mesh_tables.cpp @@ -0,0 +1,103 @@ +#include +#include "helpers/SimpleMeshTables.h" + +using namespace mesh; + +// Build a packet that calculatePacketHash() distinguishes by payload content. +// header selects ROUTE_TYPE_FLOOD so isRouteDirect() returns false. +static Packet makeFloodPacket(uint8_t seed) { + Packet p; + p.header = ROUTE_TYPE_FLOOD | (PAYLOAD_TYPE_ACK << PH_TYPE_SHIFT); + p.payload[0] = seed; + p.payload_len = 1; + p.path_len = 0; + return p; +} + +static Packet makeDirectPacket(uint8_t seed) { + Packet p; + p.header = ROUTE_TYPE_DIRECT | (PAYLOAD_TYPE_ACK << PH_TYPE_SHIFT); + p.payload[0] = seed; + p.payload_len = 1; + p.path_len = 0; + return p; +} + +// ── wasSeen: pure query ─────────────────────────────────────────────────────── + +TEST(SimpleMeshTables, WasSeen_ReturnsFalseForUnseen) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + EXPECT_FALSE(t.wasSeen(&p)); +} + +// wasSeen shouldn't change state +TEST(SimpleMeshTables, WasSeen_IsPureQuery_DoesNotInsert) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + EXPECT_FALSE(t.wasSeen(&p)); + EXPECT_FALSE(t.wasSeen(&p)); +} + +// ── markSeen + wasSeen ─────────────────────────────────────────────────────── + +TEST(SimpleMeshTables, MarkSeen_MakesWasSeenReturnTrue) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + t.markSeen(&p); + EXPECT_TRUE(t.wasSeen(&p)); +} + +TEST(SimpleMeshTables, MarkSeen_DoesNotAffectOtherPackets) { + SimpleMeshTables t; + Packet p1 = makeFloodPacket(0x01); + Packet p2 = makeFloodPacket(0x02); + t.markSeen(&p1); + EXPECT_FALSE(t.wasSeen(&p2)); +} + +// Canonical pattern used at every onRecvPacket call site: +// if (!wasSeen(pkt)) { markSeen(pkt); process(pkt); } +TEST(SimpleMeshTables, QueryThenMark_WorksCorrectly) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + EXPECT_FALSE(t.wasSeen(&p)); + t.markSeen(&p); + EXPECT_TRUE(t.wasSeen(&p)); +} + +// ── dup stats ──────────────────────────────────────────────────────────────── + +TEST(SimpleMeshTables, WasSeen_IncrementsFloodDupStat) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + t.markSeen(&p); + t.wasSeen(&p); + EXPECT_EQ(1u, t.getNumFloodDups()); + EXPECT_EQ(0u, t.getNumDirectDups()); +} + +TEST(SimpleMeshTables, WasSeen_IncrementsDirectDupStat) { + SimpleMeshTables t; + Packet p = makeDirectPacket(0x01); + t.markSeen(&p); + t.wasSeen(&p); + EXPECT_EQ(0u, t.getNumFloodDups()); + EXPECT_EQ(1u, t.getNumDirectDups()); +} + +// ── clear ──────────────────────────────────────────────────────────────────── + +TEST(SimpleMeshTables, Clear_RemovesSeenPacket) { + SimpleMeshTables t; + Packet p = makeFloodPacket(0x01); + t.markSeen(&p); + ASSERT_TRUE(t.wasSeen(&p)); + t.clear(&p); + EXPECT_FALSE(t.wasSeen(&p)); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}