diff --git a/LXMF/LXMessage.py b/LXMF/LXMessage.py index 515ab11..4739f30 100644 --- a/LXMF/LXMessage.py +++ b/LXMF/LXMessage.py @@ -268,15 +268,6 @@ class LXMessage: def register_failed_callback(self, callback): self.failed_callback = callback - @staticmethod - def stamp_valid(stamp, target_cost, workblock): - target = 0b1 << 256-target_cost - result = RNS.Identity.full_hash(workblock+stamp) - if int.from_bytes(result, byteorder="big") > target: - return False - else: - return True - def validate_stamp(self, target_cost, tickets=None): if tickets != None: for ticket in tickets: @@ -293,7 +284,7 @@ class LXMessage: return False else: workblock = LXStamper.stamp_workblock(self.message_id) - if LXMessage.stamp_valid(self.stamp, target_cost, workblock): + if LXStamper.stamp_valid(self.stamp, target_cost, workblock): RNS.log(f"Stamp on {self} validated", RNS.LOG_DEBUG) # TODO: Remove at some point self.stamp_value = LXStamper.stamp_value(workblock, self.stamp) return True diff --git a/LXMF/LXStamper.py b/LXMF/LXStamper.py index bcfa95b..a9ca7d6 100644 --- a/LXMF/LXStamper.py +++ b/LXMF/LXStamper.py @@ -3,15 +3,18 @@ import RNS.vendor.umsgpack as msgpack import os import time +import math import multiprocessing -WORKBLOCK_EXPAND_ROUNDS = 3000 +WORKBLOCK_EXPAND_ROUNDS = 3000 +WORKBLOCK_EXPAND_ROUNDS_PN = 1000 +STAMP_SIZE = RNS.Identity.HASHLENGTH +PN_VALIDATION_POOL_MIN_SIZE = 256 active_jobs = {} -def stamp_workblock(message_id): +def stamp_workblock(message_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS): wb_st = time.time() - expand_rounds = WORKBLOCK_EXPAND_ROUNDS workblock = b"" for n in range(expand_rounds): workblock += RNS.Cryptography.hkdf( @@ -21,7 +24,7 @@ def stamp_workblock(message_id): context=None, ) wb_time = time.time() - wb_st - RNS.log(f"Stamp workblock size {RNS.prettysize(len(workblock))}, generated in {round(wb_time*1000,2)}ms", RNS.LOG_DEBUG) + # RNS.log(f"Stamp workblock size {RNS.prettysize(len(workblock))}, generated in {round(wb_time*1000,2)}ms", RNS.LOG_DEBUG) return workblock @@ -36,6 +39,53 @@ def stamp_value(workblock, stamp): return value +def stamp_valid(stamp, target_cost, workblock): + target = 0b1 << 256-target_cost + result = RNS.Identity.full_hash(workblock+stamp) + if int.from_bytes(result, byteorder="big") > target: return False + else: return True + +def validate_pn_stamp(transient_id, stamp): + target_cost = 8 + workblock = stamp_workblock(transient_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN) + if stamp_valid(stamp, target_cost, workblock): + RNS.log(f"Stamp on {RNS.prettyhexrep(transient_id)} validated", RNS.LOG_DEBUG) + value = stamp_value(workblock, stamp) + return True + + return False + +def validate_pn_stamps_job_simple(transient_stamps): + for entry in transient_stamps: + # Get transient ID and stamp for validation + transient_id = transient_stamps[0] + stamp = transient_stamps[1] + + # Store validation result back into list + transient_stamps[2] = validate_pn_stamp(transient_id, stamp) + + return transient_stamps + +def _validate_single_pn_stamp_entry(entry): + transient_id = entry[0] + stamp = entry[1] + entry[2] = validate_pn_stamp(transient_id, stamp) + return entry + +def validate_pn_stamps_job_multip(transient_stamps): + cores = multiprocessing.cpu_count() + pool_count = min(cores, math.ceil(len(transient_stamps) / PN_VALIDATION_POOL_MIN_SIZE)) + + RNS.log(f"Validating {len(transient_stamps)} stamps using {pool_count} processes...") + with multiprocessing.Pool(pool_count) as p: validated_entries = p.map(_validate_single_pn_stamp_entry, transient_stamps) + + return validated_entries + +def validate_pn_stamps(transient_stamps): + non_mp_platform = RNS.vendor.platformutils.is_android() + if len(transient_stamps) <= PN_VALIDATION_POOL_MIN_SIZE or non_mp_platform: validate_pn_stamps_job_simple(transient_stamps) + else: validate_pn_stamps_job_multip(transient_stamps) + def generate_stamp(message_id, stamp_cost): RNS.log(f"Generating stamp with cost {stamp_cost} for {RNS.prettyhexrep(message_id)}...", RNS.LOG_DEBUG) workblock = stamp_workblock(message_id)