# Extract the taxonomy id of each UniRef100 database sequence from the ColabFold
# uniref30_2302_db_seq_h headers file, whose lines look like
#
# UniRef100_A0A5A9P0L4 peptidylprolyl isomerase n=1 Tax=Triplophysa tibetana TaxID=1572043 RepID=A0A5A9P0L4_9TELE

def create_uniref100_to_taxonomy_id_map(uniref100_headers_path, tax_map_npz_path,
                                        reuse_text_map = True):
    '''
    Write a .npz file mapping hashed UniRef100 ids to taxonomy ids.

    Works in two passes.  First the headers file is parsed into a tab-separated
    text file (tax_map_npz_path with a .txt suffix) holding one
    "uniref100_id<TAB>tax_id" line per sequence.  Then that text file is turned
    into two hash-sorted numpy arrays saved as "uid_hashes" and "tax_ids".

    If reuse_text_map is true and the text file already exists, the slow
    header-parsing pass is skipped and the existing text file is used as is
    (it is the caller's responsibility that it is not stale).
    '''
    from os.path import splitext, exists
    tax_map_text_path = splitext(tax_map_npz_path)[0] + '.txt'
    if reuse_text_map and exists(tax_map_text_path):
        # Count the entries instead of hard-coding them so a text map from a
        # different UniRef release still sizes the arrays correctly.
        count = _count_lines(tax_map_text_path)
    else:
        count = uniref100_headers_to_taxonomy_map(uniref100_headers_path, tax_map_text_path)

    # Then make two numpy arrays of SHA1-hashed UniRef100 ids and integer
    # taxonomy ids for memory-efficient loading and lookup with numpy searchsorted().
    hashes, tax_ids = uniref100_to_taxonomy_npz_map(tax_map_text_path, count)
    from numpy import savez
    savez(tax_map_npz_path, uid_hashes = hashes, tax_ids = tax_ids)

def _count_lines(path, chunk_size = 1 << 24):
    '''Return the number of newline characters in the file at path.'''
    count = 0
    # Binary chunked reads: much faster than decoding and iterating lines.
    with open(path, 'rb') as f:
        while chunk := f.read(chunk_size):
            count += chunk.count(b'\n')
    return count

def uniref100_headers_to_taxonomy_map(db_seq_h_path, tax_map_text_path):
    '''
    Parse a headers file and write a tab-separated text file with one
    "uniref100_id<TAB>tax_id" line per UniRef100 header line.

    The "UniRef100_" prefix is stripped from the ids.  Headers without a
    usable integer "TaxID=" field get tax id -1.  Blank lines are skipped and
    lines not starting with "UniRef100_" are reported and skipped.
    Returns the number of lines written to tax_map_text_path.
    '''
    prefix = 'UniRef100_'
    count = 0
    # Explicit utf-8 instead of the locale default; errors='replace' so one odd
    # byte in a free-text description cannot abort a long run (only the
    # id and TaxID fields are written out).
    with open(db_seq_h_path, 'r', encoding = 'utf-8', errors = 'replace') as h, \
         open(tax_map_text_path, 'w', encoding = 'utf-8') as tm:
        for line in h:
            # NOTE(review): presumably each database entry is NUL-terminated, so
            # the NUL shows up at the start of the following line and the file
            # ends with a lone NUL line.
            if line.startswith('\0'):
                line = line[1:]
            if not line.strip():
                continue
            if not line.startswith(prefix):
                print(f'line did not start with UniRef100_: {line.rstrip()}')
                continue
            uniref_id = line.split(maxsplit=1)[0][len(prefix):]
            tm.write(f'{uniref_id}\t{_parse_tax_id(line)}\n')
            count += 1
    return count

def _parse_tax_id(line):
    '''Return the integer following "TaxID=" in a header line, or -1 if absent or malformed.'''
    i = line.find('TaxID=')
    if i == -1:
        return -1
    fields = line[i+6:].split(maxsplit=1)
    try:
        return int(fields[0])
    except (IndexError, ValueError):
        # "TaxID=" at end of line or followed by a non-integer.
        print(f'could not parse TaxID in line: {line.rstrip()}')
        return -1

def uniref100_to_taxonomy_npz_map(tax_map_text_path, count):
    '''
    Read a "uniref100_id<TAB>tax_id" text map and return two parallel numpy
    arrays (uid_hashes as uint64, tax_ids as int32) sorted by hash, ready for
    binary-search lookup with numpy searchsorted().

    count is the expected number of lines and sizes the arrays up front.  If
    the file has fewer lines the arrays are trimmed to the lines actually read;
    if it has more, ValueError is raised.  If two ids hash to the same 64-bit
    value a message is printed and the process exits with status 1.
    '''
    from numpy import empty, uint64, int32, argsort, count_nonzero
    hashes = empty((count,), uint64)
    taxids = empty((count,), int32)
    n = 0
    with open(tax_map_text_path, 'r', encoding = 'utf-8') as tax:
        for line in tax:
            if n >= count:
                raise ValueError(f'{tax_map_text_path} has more than the expected {count} lines')
            uid, taxid = line.split()
            hashes[n] = uid_hash(uid)
            taxids[n] = int(taxid)
            n += 1
    if n < count:
        # empty() leaves unfilled slots uninitialized; drop them instead of saving garbage.
        hashes = hashes[:n]
        taxids = taxids[:n]

    # Sort hashes since lookup will use binary search.
    order = argsort(hashes)
    h = hashes[order]
    t = taxids[order]

    # Check for hash collisions.  Duplicates are adjacent once sorted, so this
    # reuses the sort above rather than sorting again inside numpy.unique().
    # The count equals len(h) - len(unique(h)).
    nc = int(count_nonzero(h[1:] == h[:-1]))
    if nc > 0:
        # Birthday bound ~ n^2 / 2^65: about 1 in 600 for 250 million and
        # about 1 in 300 for 350 million 64-bit hashes.
        print(f'There were {nc} hash collisions for {n} UniRef100 ids.  Could not make taxonomy map.')
        import sys
        sys.exit(1)

    return h, t

from hashlib import sha1
def uid_hash(uid):
    # Reproducible 64-bit hash of an id string: the first 8 bytes of its
    # SHA-1 digest read as a big-endian unsigned integer.  Python's built-in
    # hash() is about 10x faster, but it is salted per interpreter run, so its
    # values would not match between the process that writes the map and the
    # one that reads it.
    digest = sha1(uid.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")

from sys import argv
# Only run when executed as a script, so the functions above can be imported
# without side effects (or a crash when argv does not have exactly 2 args).
if __name__ == '__main__':
    if len(argv) != 3:
        import sys
        sys.exit(f'Usage: python3 {argv[0]} <uniref30_db_seq_h path> <output .npz path>')
    uniref100_headers_path, tax_map_path = argv[1:]
    create_uniref100_to_taxonomy_id_map(uniref100_headers_path, tax_map_path)

# Example command:
#   python3 uniref100_taxonomy_ids.py uniref30_2302_db_seq_h uniref100_ncbi_taxonomy_ids.npz
