#!/usr/bin/env python3

# Reticulum License
#
# Copyright (c) 2016-2025 Mark Qvist
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# - The Software shall not be used in any kind of system which includes amongst
#   its functions the ability to purposefully do harm to human beings.
#
# - The Software shall not be used, directly or indirectly, in the creation of
#   an artificial intelligence, machine learning or language model training
#   dataset, including but not limited to any use that contributes to the
#   training or development of such a model or algorithm.
#
# - The above copyright notice and this permission notice shall be included in
#   all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import RNS
import os
import sys
import time
import argparse

from RNS._version import __version__

# Link to a remote transport instance. It starts as None and is assigned from
# the link-established callback created in connect_remote(); program_setup()
# polls it until it becomes non-None.
remote_link = None
# Carriage return, a run of spaces, and another carriage return. Printed with
# end="" to blank the current terminal line before writing a new status line.
output_rst_str = "\r                                                          \r"
def connect_remote(destination_hash, auth_identity, timeout, no_output = False, purpose="management"):
    """
    Request a path to a remote transport instance if needed, then start
    establishing a link to one of its rnstransport destinations.

    This returns as soon as the link has been initiated. The module global
    ``remote_link`` is only assigned once the link-established callback runs,
    so callers must poll it.

    :param destination_hash: Hash of the remote destination to link to.
    :param auth_identity: Identity to identify with on the link when
        ``purpose`` is "management"; unused for other purposes.
    :param timeout: Seconds to wait for a path before exiting with status 12.
    :param no_output: Suppress progress and status messages.
    :param purpose: "management" links to rnstransport.remote.management,
        "blackhole" links to rnstransport.info.blackhole.
    :raises ValueError: If ``purpose`` is not one of the supported values.
    """
    global remote_link, reticulum

    # Resolve the destination aspects first, so an unsupported purpose fails
    # before any path request is sent (previously this ended in a NameError).
    if purpose == "management": aspects = ("remote", "management")
    elif purpose == "blackhole": aspects = ("info", "blackhole")
    else: raise ValueError("Unsupported remote connection purpose: "+str(purpose))

    if not RNS.Transport.has_path(destination_hash):
        if not no_output:
            print("Path to "+RNS.prettyhexrep(destination_hash)+" requested", end=" ")
            sys.stdout.flush()
        RNS.Transport.request_path(destination_hash)
        pr_time = time.time()
        while not RNS.Transport.has_path(destination_hash):
            time.sleep(0.1)
            if time.time() - pr_time > timeout:
                if not no_output:
                    print(output_rst_str, end="")
                    print("Path request timed out")
                # Exit even when output is suppressed; otherwise a silent
                # caller would keep polling for a path forever.
                sys.exit(12)

    remote_identity = RNS.Identity.recall(destination_hash)

    def remote_link_closed(link):
        # Report why the link went down, then terminate with status 10.
        if not no_output:
            print(output_rst_str, end="")
            if link.teardown_reason == RNS.Link.TIMEOUT:
                print("The link timed out, exiting now")
            elif link.teardown_reason == RNS.Link.DESTINATION_CLOSED:
                print("The link was closed by the server, exiting now")
            else:
                print("Link closed unexpectedly, exiting now")
        sys.exit(10)

    def remote_link_established(link):
        # Management links identify before being published to pollers of
        # the remote_link global.
        global remote_link
        if purpose == "management": link.identify(auth_identity)
        remote_link = link

    if not no_output:
        print(output_rst_str, end="")
        print("Establishing link with remote transport instance...", end=" ")
        sys.stdout.flush()

    remote_destination = RNS.Destination(remote_identity, RNS.Destination.OUT, RNS.Destination.SINGLE, "rnstransport", *aspects)
    link = RNS.Link(remote_destination)
    link.set_link_established_callback(remote_link_established)
    link.set_link_closed_callback(remote_link_closed)

def parse_hash(input_str):
    """
    Decode a hexadecimal hash string into bytes.

    The string must contain exactly two hex characters per byte of a
    truncated Reticulum hash.

    :param input_str: Hexadecimal string to decode.
    :returns: The decoded hash as ``bytes``.
    :raises ValueError: If the length is wrong or the text is not valid hex.
    """
    expected_chars = 2 * (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8)
    if len(input_str) != expected_chars:
        raise ValueError(
            "Hash length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(
                hex=expected_chars, byte=expected_chars // 2
            )
        )

    try:
        decoded = bytes.fromhex(input_str)
    except Exception:
        raise ValueError("Invalid hash entered. Check your input.")
    return decoded

def _parse_destination_or_exit(destination_hexhash):
    """
    Decode a hexadecimal destination hash given on the command line.

    Prints the problem and exits with status 1 if the value has the wrong
    length, is missing, or is not valid hexadecimal.
    """
    try:
        dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
        if len(destination_hexhash) != dest_len: raise ValueError("Destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
        try: return bytes.fromhex(destination_hexhash)
        except Exception: raise ValueError("Invalid destination entered. Check your input.")
    except Exception as e:
        print(str(e))
        sys.exit(1)


def _request_remote_table(request_data, no_output):
    """
    Send a "/path" request with ``request_data`` over the established
    ``remote_link`` and return the response.

    Exits with status 10 if the response is empty or missing.
    """
    if not no_output:
        print(output_rst_str, end="")
        print("Sending request...", end=" ")
        sys.stdout.flush()
    receipt = remote_link.request("/path", data = request_data)
    while not receipt.concluded(): time.sleep(0.1)
    response = receipt.get_response()
    if not response:
        if not no_output:
            print(output_rst_str, end="")
            print("The remote request failed. Likely authentication failure.")
        sys.exit(10)

    if not no_output: print(output_rst_str, end="")
    return response


def _print_json_table(table):
    """
    Print a list of dict entries as JSON. Any bytes values are replaced in
    place by their hex representation first, since JSON cannot encode bytes.
    """
    import json
    for entry in table:
        for key in entry:
            if isinstance(entry[key], bytes): entry[key] = RNS.hexrep(entry[key], delimit=False)

    print(json.dumps(table))


def program_setup(configdir, table, rates, drop, destination_hexhash, verbosity, timeout, drop_queues,
                  drop_via, max_hops, remote=None, management_identity=None, remote_timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
                  blackholed=False, blackhole=False, unblackhole=False, blackhole_duration=None, blackhole_reason=None,
                  remote_blackhole_list=False, remote_blackhole_list_filter=None, no_output=False, json=False):
    """
    Start Reticulum and perform the single action selected on the command line.

    Only the first requested action runs, checked in this order:
    blackhole listing (local ``blackholed`` or remote ``remote_blackhole_list``),
    ``blackhole``, ``unblackhole``, path ``table``, announce ``rates``,
    ``drop_queues``, ``drop``, ``drop_via``, and otherwise a path lookup for
    ``destination_hexhash``.

    If ``remote`` is given, a management link to that transport instance is
    set up first. Over that link only the path table and rate table actions
    are implemented; the others exit with status 255.

    Exit statuses: 1 for invalid input or nothing found, 10 for failed remote
    requests, 20 for setup and blackhole errors, 255 for actions not
    implemented remotely.

    :param timeout: Seconds to wait for a path in the path lookup action.
    :param remote_timeout: Seconds to wait for a path to a remote instance.
    :param blackhole_duration: Blackhole duration in hours, or None for no
        expiry.
    :param json: Print the path/rate tables as JSON. This parameter shadows
        the ``json`` module inside this function.
    """
    global remote_link, reticulum
    reticulum = RNS.Reticulum(configdir = configdir, loglevel = 3+verbosity)
    if remote:
        try:
            dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
            if len(remote) != dest_len: raise ValueError("Destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
            try:
                identity_hash = bytes.fromhex(remote)
                remote_hash = RNS.Destination.hash_from_name_and_identity("rnstransport.remote.management", identity_hash)
            except Exception: raise ValueError("Invalid destination entered. Check your input.")

            # Without this check, a missing -i surfaced as a TypeError from expanduser(None)
            if management_identity == None: raise ValueError("No management identity specified, use -i to provide one")
            identity = RNS.Identity.from_file(os.path.expanduser(management_identity))
            if identity == None: raise ValueError("Could not load management identity from "+str(management_identity))

            connect_remote(remote_hash, identity, remote_timeout, no_output)

        except Exception as e:
            print(str(e))
            sys.exit(20)

        # remote_link is assigned from the link-established callback set up in connect_remote()
        while remote_link == None: time.sleep(0.1)

    if blackholed or remote_blackhole_list:
        blackholed_list = None
        if blackholed:
            if remote_link:
                if not no_output:
                    print(output_rst_str, end="")
                    print("Listing blackholed identities on remote instances not yet implemented")
                sys.exit(255)

            try: blackholed_list = reticulum.get_blackholed_identities()
            except Exception as e:
                print(f"Could not get blackholed identities from RNS instance: {e}")
                sys.exit(20)

        elif remote_blackhole_list:
            try: identity_hash = parse_hash(destination_hexhash)
            except Exception as e:
                print(f"{e}")
                sys.exit(20)

            remote_hash = RNS.Destination.hash_from_name_and_identity("rnstransport.info.blackhole", identity_hash)
            # A management link opened via -R may already be stored in
            # remote_link. Clear it, so we wait for the new blackhole-info
            # link instead of sending "/list" over the management link.
            remote_link = None
            connect_remote(remote_hash, None, remote_timeout, no_output, purpose="blackhole")
            while remote_link == None: time.sleep(0.1)

            if not no_output:
                print(output_rst_str, end="")
                print("Sending request...", end=" ")
                sys.stdout.flush()
            receipt = remote_link.request("/list")
            while not receipt.concluded(): time.sleep(0.1)
            response = receipt.get_response()
            if isinstance(response, dict):
                blackholed_list = response
                if not no_output: print(output_rst_str, end="")
            else:
                if not no_output:
                    print(output_rst_str, end="")
                    print("The remote request failed.")
                sys.exit(10)

        else:
            # Unreachable given the enclosing condition; kept as a safeguard
            print("Nowhere to fetch blackhole list from")
            sys.exit(255)

        if not blackholed_list:
            print("No blackholed identity data available")
            sys.exit(20)

        rmlen = 64
        def trunc(input_str):
            # Cap reason strings at rmlen characters, ending in an ellipsis
            if len(input_str) <= rmlen: return input_str
            else: return f"{input_str[:rmlen-1]}…"

        try:
            now = time.time()
            for identity_hash in blackholed_list:
                until      = blackholed_list[identity_hash]["until"]
                reason     = blackholed_list[identity_hash]["reason"]
                source     = blackholed_list[identity_hash]["source"]
                until_str  = f"for {RNS.prettytime(max(0, until-now))}" if until else "indefinitely"
                reason_str = f" ({trunc(reason)})" if reason else ""
                by_str     = f" by {RNS.prettyhexrep(source)}" if source != RNS.Transport.identity.hash else ""
                filter_str = f"{RNS.prettyhexrep(identity_hash)} {until_str} {reason_str} {by_str}"

                # Local listings filter on the destination argument, remote
                # listings on the separate list_filter argument; both match
                # as substrings of the rendered entry.
                if not remote_blackhole_list:
                    if destination_hexhash and destination_hexhash not in filter_str: continue
                else:
                    if remote_blackhole_list_filter and remote_blackhole_list_filter not in filter_str: continue

                print(f"{RNS.prettyhexrep(identity_hash)} blackholed {until_str}{reason_str}{by_str}")

        except Exception as e:
            print(f"Error while displaying collected blackhole data: {e}")
            sys.exit(20)

    elif blackhole:
        if remote_link:
            if not no_output:
                print(output_rst_str, end="")
                print("Blackholing identity on remote instances not yet implemented")
            sys.exit(255)

        try:
            identity_hash = parse_hash(destination_hexhash)
            # blackhole_duration is given in hours
            until = time.time()+blackhole_duration*60*60 if blackhole_duration else None
            result = reticulum.blackhole_identity(identity_hash, until=until, reason=blackhole_reason)
            if result == True: print(f"Blackholed identity {destination_hexhash}")
            elif result == None: print(f"Identity {destination_hexhash} already blackholed")
            else: print(f"Could not blackhole identity {destination_hexhash}")

        except Exception as e:
            print(f"Could not blackhole identity: {e}")
            sys.exit(20)

    elif unblackhole:
        if remote_link:
            if not no_output:
                print(output_rst_str, end="")
                print("Unblackholing identity on remote instances not yet implemented")
            sys.exit(255)

        try:
            identity_hash = parse_hash(destination_hexhash)
            result = reticulum.unblackhole_identity(identity_hash)
            if result == True: print(f"Lifted blackhole for identity {destination_hexhash}")
            elif result == None: print(f"Identity {destination_hexhash} not blackholed")
            else: print(f"Could not unblackhole identity {destination_hexhash}")

        except Exception as e:
            print(f"Could not unblackhole identity: {e}")
            sys.exit(20)

    elif table:
        destination_hash = None
        if destination_hexhash != None:
            destination_hash = _parse_destination_or_exit(destination_hexhash)

        if not remote_link: table = sorted(reticulum.get_path_table(max_hops=max_hops), key=lambda e: (e["interface"], e["hops"]) )
        else: table = _request_remote_table(["table", destination_hash, max_hops], no_output)

        if json:
            _print_json_table(table)
            sys.exit()

        displayed = 0
        for path in table:
            if destination_hash == None or destination_hash == path["hash"]:
                displayed += 1
                exp_str = RNS.timestamp_str(path["expires"])
                # Presumably padded so "hop " lines up with "hops" in the column
                if path["hops"] == 1: m_str = " "
                else: m_str = "s"
                print(RNS.prettyhexrep(path["hash"])+" is "+str(path["hops"])+" hop"+m_str+" away via "+RNS.prettyhexrep(path["via"])+" on "+path["interface"]+" expires "+exp_str)

        if destination_hash != None and displayed == 0:
            print("No path known")
            sys.exit(1)

    elif rates:
        destination_hash = None
        if destination_hexhash != None:
            destination_hash = _parse_destination_or_exit(destination_hexhash)

        if not remote_link: table = reticulum.get_rate_table()
        else: table = _request_remote_table(["rates", destination_hash], no_output)

        table = sorted(table, key=lambda e: e["last"])
        if json:
            _print_json_table(table)
            sys.exit()

        if len(table) == 0: print("No information available")
        else:
            displayed = 0
            for entry in table:
                if destination_hash == None or destination_hash == entry["hash"]:
                    displayed += 1
                    try:
                        last_str = pretty_date(int(entry["last"]))
                        start_ts = entry["timestamps"][0]
                        # Average over at least one hour so a few recent announces don't inflate the rate
                        span = max(time.time() - start_ts, 3600.0)
                        span_hours = span/3600.0
                        span_str = pretty_date(int(entry["timestamps"][0]))
                        hour_rate = round(len(entry["timestamps"])/span_hours, 3)
                        if hour_rate-int(hour_rate) == 0:
                            hour_rate = int(hour_rate)

                        if entry["rate_violations"] > 0:
                            s_str = "" if entry["rate_violations"] == 1 else "s"
                            rv_str = ", "+str(entry["rate_violations"])+" active rate violation"+s_str
                        else:
                            rv_str = ""

                        if entry["blocked_until"] > time.time():
                            # pretty_date() measures time elapsed since a timestamp, so
                            # mirror the remaining block time into the past to display it
                            bli = time.time()-(int(entry["blocked_until"])-time.time())
                            bl_str = ", new announces allowed in "+pretty_date(int(bli))
                        else:
                            bl_str = ""

                        print(RNS.prettyhexrep(entry["hash"])+" last heard "+last_str+" ago, "+str(hour_rate)+" announces/hour in the last "+span_str+rv_str+bl_str)

                    except Exception as e:
                        print("Error while processing entry for "+RNS.prettyhexrep(entry["hash"]))
                        print(str(e))

            if destination_hash != None and displayed == 0:
                print("No information available")
                sys.exit(1)

    elif drop_queues:
        if remote_link:
            if not no_output:
                print(output_rst_str, end="")
                print("Dropping announce queues on remote instances not yet implemented")
            sys.exit(255)

        print("Dropping announce queues on all interfaces...")
        reticulum.drop_announce_queues()

    elif drop:
        if remote_link:
            if not no_output:
                print(output_rst_str, end="")
                print("Dropping path on remote instances not yet implemented")
            sys.exit(255)

        destination_hash = _parse_destination_or_exit(destination_hexhash)
        if reticulum.drop_path(destination_hash): print("Dropped path to "+RNS.prettyhexrep(destination_hash))
        else:
            print("Unable to drop path to "+RNS.prettyhexrep(destination_hash)+". Does it exist?")
            sys.exit(1)

    elif drop_via:
        if remote_link:
            if not no_output:
                print(output_rst_str, end="")
                print("Dropping all paths via specific transport instance on remote instances yet not implemented")
            sys.exit(255)

        destination_hash = _parse_destination_or_exit(destination_hexhash)
        if reticulum.drop_all_via(destination_hash): print("Dropped all paths via "+RNS.prettyhexrep(destination_hash))
        else:
            print("Unable to drop paths via "+RNS.prettyhexrep(destination_hash)+". Does the transport instance exist?")
            sys.exit(1)

    else:
        if remote_link:
            if not no_output:
                print(output_rst_str, end="")
                print("Requesting paths on remote instances not implemented")
            sys.exit(255)

        destination_hash = _parse_destination_or_exit(destination_hexhash)

        if not RNS.Transport.has_path(destination_hash):
            RNS.Transport.request_path(destination_hash)
            print("Path to "+RNS.prettyhexrep(destination_hash)+" requested  ", end=" ")
            sys.stdout.flush()

        # Spinner: each step backspaces over the previous symbol and draws the next
        i = 0
        syms = "⢄⢂⢁⡁⡈⡐⡠"
        limit = time.time()+timeout
        while not RNS.Transport.has_path(destination_hash) and time.time()<limit:
            time.sleep(0.1)
            print(("\b\b"+syms[i]+" "), end="")
            sys.stdout.flush()
            i = (i+1)%len(syms)

        if RNS.Transport.has_path(destination_hash):
            hops = RNS.Transport.hops_to(destination_hash)
            next_hop_bytes = reticulum.get_next_hop(destination_hash)
            if next_hop_bytes == None:
                print("\r                                                       \rError: Invalid path data returned")
                sys.exit(1)
            else:
                next_hop = RNS.prettyhexrep(next_hop_bytes)
                next_hop_interface = reticulum.get_next_hop_if_name(destination_hash)
                ms = "s" if hops != 1 else ""
                print("\rPath found, destination "+RNS.prettyhexrep(destination_hash)+" is "+str(hops)+" hop"+ms+" away via "+next_hop+" on "+next_hop_interface)
        else:
            print("\r                                                       \rPath not found")
            sys.exit(1)


def main():
    """
    Command-line entry point for rnpath.

    Parses the arguments. If at least one action was requested, runs
    program_setup() and exits with status 0; otherwise prints the usage help.
    A keyboard interrupt prints an empty line and exits.
    """
    try:
        parser = argparse.ArgumentParser(description="Reticulum Path Management Utility")
        parser.add_argument("--config", action="store", default=None, help="path to alternative Reticulum config directory", type=str)
        parser.add_argument("--version", action="version", version="rnpath {version}".format(version=__version__))
        parser.add_argument("-t", "--table", action="store_true", help="show all known paths", default=False)
        parser.add_argument("-m", "--max", action="store", metavar="hops", type=int, help="maximum hops to filter path table by", default=None)
        parser.add_argument("-r", "--rates", action="store_true", help="show announce rate info", default=False)
        parser.add_argument("-d", "--drop", action="store_true", help="remove the path to a destination", default=False)
        parser.add_argument("-D", "--drop-announces", action="store_true", help="drop all queued announces", default=False)
        parser.add_argument("-x", "--drop-via", action="store_true", help="drop all paths via specified transport instance", default=False)
        parser.add_argument("-w", action="store", metavar="seconds", type=float, help="timeout before giving up", default=RNS.Transport.PATH_REQUEST_TIMEOUT)
        parser.add_argument("-R", action="store", metavar="hash", help="transport identity hash of remote instance to manage", default=None, type=str)
        parser.add_argument("-i", action="store", metavar="path", help="path to identity used for remote management", default=None, type=str)
        parser.add_argument("-W", action="store", metavar="seconds", type=float, help="timeout before giving up on remote queries", default=RNS.Transport.PATH_REQUEST_TIMEOUT)
        parser.add_argument("-b", "--blackholed", action="store_true", help="list blackholed identities", default=False)
        parser.add_argument("-B", "--blackhole", action="store_true", help="blackhole identity", default=False)
        parser.add_argument("-U", "--unblackhole", action="store_true", help="unblackhole identity", default=False)
        parser.add_argument(      "--duration", action="store", type=float, help="duration of blackhole enforcement in hours", default=None)
        parser.add_argument(      "--reason", action="store", type=str, help="reason for blackholing identity", default=None)
        parser.add_argument("-p", "--blackholed-list", action="store_true", help="view published blackhole list for remote transport instance", default=False)
        parser.add_argument("-j", "--json", action="store_true", help="output in JSON format", default=False)
        parser.add_argument("destination", nargs="?", default=None, help="hexadecimal hash of the destination", type=str)
        parser.add_argument("list_filter", nargs="?", default=None, help="filter for remote blackhole list view", type=str)
        parser.add_argument('-v', '--verbose', action='count', default=0)

        args = parser.parse_args()

        # An empty --config value is treated the same as not passing one
        configarg = args.config if args.config else None

        # Flags that operate on a destination (-d, -B, -U, -p) are triggered
        # through the args.destination check below.
        if not args.drop_announces and not args.table and not args.rates and not args.destination and not args.drop_via and not args.blackholed:
            print("")
            parser.print_help()
            print("")
        else:
            program_setup(configdir = configarg, table = args.table, rates = args.rates, drop = args.drop, destination_hexhash = args.destination,
                          verbosity = args.verbose, timeout = args.w, drop_queues = args.drop_announces, drop_via = args.drop_via, max_hops = args.max,
                          remote=args.R, management_identity=args.i, remote_timeout=args.W, blackholed=args.blackholed, blackhole=args.blackhole,
                          unblackhole=args.unblackhole, blackhole_duration=args.duration, blackhole_reason=args.reason, remote_blackhole_list=args.blackholed_list,
                          remote_blackhole_list_filter=args.list_filter, json=args.json)

            sys.exit(0)

    except KeyboardInterrupt:
        print("")
        # sys.exit() rather than the site-provided exit(), which is meant
        # for interactive use and may be absent (e.g. python -S).
        sys.exit()

def pretty_date(time=False):
    """
    Describe how long ago a point in time was, e.g. "5 minutes" or "2 days".

    :param time: A Unix timestamp (int or float), a naive local ``datetime``,
        or a falsy non-numeric value such as the default ``False`` or ``None``,
        which is treated as "now".
    :returns: A coarse human-readable duration, or an empty string if
        ``time`` lies in the future.
    :raises TypeError: If ``time`` is of an unsupported type. Previously this
        case failed with an UnboundLocalError.
    """
    from datetime import datetime
    now = datetime.now()
    if isinstance(time, datetime):
        diff = now - time
    elif isinstance(time, (int, float)) and not isinstance(time, bool):
        # bool is an int subclass; keep False meaning "now" rather than the epoch
        diff = now - datetime.fromtimestamp(time)
    elif not time:
        diff = now - now
    else:
        raise TypeError("pretty_date() expects a timestamp, a datetime or a falsy value, not "+type(time).__name__)

    second_diff = diff.seconds
    day_diff = diff.days
    # A negative timedelta has days < 0, i.e. the time is in the future
    if day_diff < 0: return ''
    if day_diff == 0:
        if second_diff < 60: return str(second_diff) + " seconds"
        if second_diff < 120: return "1 minute"
        if second_diff < 3600: return str(int(second_diff / 60)) + " minutes"
        if second_diff < 7200: return "an hour"
        if second_diff < 86400: return str(int(second_diff / 3600)) + " hours"
    if day_diff == 1: return "1 day"
    if day_diff < 7: return str(day_diff) + " days"
    if day_diff < 31: return str(int(day_diff / 7)) + " weeks"
    if day_diff < 365: return str(int(day_diff / 30)) + " months"
    return str(int(day_diff / 365)) + " years"

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()