diff --git a/examples/simple_repeater/MyMesh.cpp b/examples/simple_repeater/MyMesh.cpp index 1afad18b..21f5a76b 100644 --- a/examples/simple_repeater/MyMesh.cpp +++ b/examples/simple_repeater/MyMesh.cpp @@ -404,6 +404,23 @@ uint32_t MyMesh::getDirectRetransmitDelay(const mesh::Packet *packet) { return getRNG()->nextInt(0, 5*t + 1); } +mesh::DispatcherAction MyMesh::onRecvPacket(mesh::Packet* pkt) { + if (pkt->getRouteType() == ROUTE_TYPE_TRANSPORT_FLOOD) { + auto region = region_map.findMatch(pkt, REGION_ALLOW_FLOOD); + if (region == NULL) { + MESH_DEBUG_PRINTLN("onRecvPacket: unknown transport code for FLOOD packet"); + return ACTION_RELEASE; + } + } else if (pkt->getRouteType() == ROUTE_TYPE_FLOOD) { + if ((region_map.getWildcard().flags & REGION_ALLOW_FLOOD) == 0) { + MESH_DEBUG_PRINTLN("onRecvPacket: wildcard FLOOD packet not allowed"); + return ACTION_RELEASE; + } + } + // otherwise do normal processing + return mesh::Mesh::onRecvPacket(pkt); +} + void MyMesh::onAnonDataRecv(mesh::Packet *packet, const uint8_t *secret, const mesh::Identity &sender, uint8_t *data, size_t len) { if (packet->getPayloadType() == PAYLOAD_TYPE_ANON_REQ) { // received an initial request by a possible admin @@ -593,7 +610,7 @@ bool MyMesh::onPeerPathRecv(mesh::Packet *packet, int sender_idx, const uint8_t MyMesh::MyMesh(mesh::MainBoard &board, mesh::Radio &radio, mesh::MillisecondClock &ms, mesh::RNG &rng, mesh::RTCClock &rtc, mesh::MeshTables &tables) : mesh::Mesh(radio, ms, rng, rtc, *new StaticPoolPacketManager(32), tables), - _cli(board, rtc, sensors, &_prefs, this), telemetry(MAX_PACKET_PAYLOAD - 4) + _cli(board, rtc, sensors, &_prefs, this), telemetry(MAX_PACKET_PAYLOAD - 4), region_map(key_store) #if defined(WITH_RS232_BRIDGE) , bridge(&_prefs, WITH_RS232_BRIDGE, _mgr, &rtc) #endif @@ -652,8 +669,9 @@ void MyMesh::begin(FILESYSTEM *fs) { _fs = fs; // load persisted prefs _cli.loadPrefs(_fs); - acl.load(_fs); + // TODO: key_store.begin(); + region_map.load(_fs); #if defined(WITH_BRIDGE) if (_prefs.bridge_enabled) { diff --git a/examples/simple_repeater/MyMesh.h b/examples/simple_repeater/MyMesh.h index 694e8ff9..7e34ef5f 100644 --- a/examples/simple_repeater/MyMesh.h +++ b/examples/simple_repeater/MyMesh.h @@ -32,6 +32,7 @@ #include #include #include +#include #ifdef WITH_BRIDGE extern AbstractBridge* bridge; @@ -87,6 +88,8 @@ class MyMesh : public mesh::Mesh, public CommonCLICallbacks { CommonCLI _cli; uint8_t reply_data[MAX_PACKET_PAYLOAD]; ClientACL acl; + TransportKeyStore key_store; + RegionMap region_map; unsigned long dirty_contacts_expiry; #if MAX_NEIGHBOURS NeighbourInfo neighbours[MAX_NEIGHBOURS]; @@ -144,6 +147,8 @@ protected: } #endif + mesh::DispatcherAction onRecvPacket(mesh::Packet* pkt) override; + void onAnonDataRecv(mesh::Packet* packet, const uint8_t* secret, const mesh::Identity& sender, uint8_t* data, size_t len) override; int searchPeersByHash(const uint8_t* hash) override; void getPeerSharedSecret(uint8_t* dest_secret, int peer_idx) override; diff --git a/src/helpers/RegionMap.cpp b/src/helpers/RegionMap.cpp new file mode 100644 index 00000000..4c270ff8 --- /dev/null +++ b/src/helpers/RegionMap.cpp @@ -0,0 +1,35 @@ +#include "RegionMap.h" +#include + +void RegionMap::load(FILESYSTEM* _fs) { + // TODO +} +void RegionMap::save(FILESYSTEM* _fs) { + // TODO +} + +RegionEntry* RegionMap::findMatch(mesh::Packet* packet, uint8_t mask) { + for (int i = 0; i < num_regions; i++) { + auto region = ®ions[i]; + if (region->flags & mask) { // does region allow this? (per 'mask' param) + TransportKey keys[4]; + int num = _store->loadKeysFor(region->name, region->id, keys, 4); + for (int j = 0; j < num; j++) { + uint16_t code = keys[j].calcTransportCode(packet); + if (packet->transport_codes[0] == code) { // a match!! + return region; + } + } + } + } + return NULL; // no matches +} + +const RegionEntry* RegionMap::findName(const char* name) const { + for (int i = 0; i < num_regions; i++) { + auto region = ®ions[i]; + if (strcmp(name, region->name) == 0) return region; + } + return NULL; // not found +} + diff --git a/src/helpers/RegionMap.h b/src/helpers/RegionMap.h new file mode 100644 index 00000000..4620a745 --- /dev/null +++ b/src/helpers/RegionMap.h @@ -0,0 +1,39 @@ +#pragma once + +#include // needed for PlatformIO +#include +#include "TransportKeyStore.h" + +#ifndef MAX_REGION_ENTRIES + #define MAX_REGION_ENTRIES 32 +#endif + +#define REGION_ALLOW_FLOOD 0x01 + +struct RegionEntry { + uint16_t id; + uint16_t parent; + uint8_t flags; + char name[31]; +}; + +class RegionMap { + TransportKeyStore* _store; + uint16_t next_id; + uint16_t num_regions; + RegionEntry regions[MAX_REGION_ENTRIES]; + RegionEntry wildcard; + +public: + RegionMap(TransportKeyStore& store) : _store(&store) { + next_id = 1; num_regions = 0; + wildcard.id = wildcard.parent = 0; + wildcard.flags = REGION_ALLOW_FLOOD; // default behaviour, allow flood + } + void load(FILESYSTEM* _fs); + void save(FILESYSTEM* _fs); + + RegionEntry* findMatch(mesh::Packet* packet, uint8_t mask); + const RegionEntry& getWildcard() const { return wildcard; } + const RegionEntry* findName(const char* name) const; +}; diff --git a/src/helpers/TransportKeyStore.cpp b/src/helpers/TransportKeyStore.cpp new file mode 100644 index 00000000..973ad703 --- /dev/null +++ b/src/helpers/TransportKeyStore.cpp @@ -0,0 +1,44 @@ +#include "TransportKeyStore.h" +#include + +uint16_t TransportKey::calcTransportCode(const mesh::Packet* packet) const { + uint16_t code; + SHA256 sha; + sha.resetHMAC(key, sizeof(key)); + uint8_t type = packet->getPayloadType(); + sha.update(&type, 1); + sha.update(packet->payload, packet->payload_len); + sha.finalizeHMAC(key, sizeof(key), &code, 2); + return code; +} + +int TransportKeyStore::loadKeysFor(const char* name, uint16_t id, TransportKey keys[], int max_num) { + int n = 0; + for (int i = 0; i < num_cache && n < max_num; i++) { // first, check cache + if (cache_ids[i] == id) { + keys[n++] = cache_keys[i]; + } + } + if (n > 0) return n; // cache hit! + + if (*name == '#') { // is a publicly-known hashtag region + SHA256 sha; + sha.update(name, strlen(name)); + sha.finalize(&keys[0], sizeof(keys[0].key)); + n = 1; + } else { + // TODO: retrieve from difficult-to-copy keystore + } + + // store in cache (if room) + for (int i = 0; i < n; i++) { + if (num_cache < MAX_TKS_ENTRIES) { + cache_ids[num_cache] = id; + cache_keys[num_cache] = keys[i]; + num_cache++; + } else { + // TODO: evict oldest cache entry + } + } + return n; +} diff --git a/src/helpers/TransportKeyStore.h b/src/helpers/TransportKeyStore.h new file mode 100644 index 00000000..f25ed53f --- /dev/null +++ b/src/helpers/TransportKeyStore.h @@ -0,0 +1,23 @@ +#pragma once + +#include // needed for PlatformIO +#include +#include + +struct TransportKey { + uint8_t key[16]; + + uint16_t calcTransportCode(const mesh::Packet* packet) const; +}; + +#define MAX_TKS_ENTRIES 16 + +class TransportKeyStore { + uint16_t cache_ids[MAX_TKS_ENTRIES]; + TransportKey cache_keys[MAX_TKS_ENTRIES]; + int num_cache; + +public: + TransportKeyStore() { num_cache = 0; } + int loadKeysFor(const char* name, uint16_t id, TransportKey keys[], int max_num); +};