def add_taxonomy_to_a3m(a3m_path, tax_id_map):
    """Rewrite an a3m file in place, tagging UniRef100 headers with a taxonomy id.

    Every line starting with ``>UniRef100_`` is reduced to its first
    whitespace-separated field followed by `` TX=<tax_id>``, where the id is
    whatever ``tax_id_map.find`` returns for the text between the first and
    second underscore of that field.  Headers for which ``find`` returns
    ``None``, and all other lines, are written back unchanged.

    Args:
        a3m_path: Path of the a3m file; it is read completely and then
            overwritten.
        tax_id_map: Object with a ``find(uid)`` method returning a taxonomy
            id or ``None``.
    """
    header_prefix = '>UniRef100_'
    annotated = []
    with open(a3m_path, 'r') as source:
        for record in source:
            if not record.startswith(header_prefix):
                annotated.append(record)
                continue
            # Only the first header field is kept once a taxonomy id is found.
            name = record.split()[0]
            accession = name.split('_')[1]
            taxonomy = tax_id_map.find(accession)
            if taxonomy is None:
                annotated.append(record)
            else:
                annotated.append(f'{name} TX={taxonomy}\n')

    # The whole file was buffered above, so it is safe to truncate it now.
    with open(a3m_path, 'w') as target:
        target.writelines(annotated)

class UniprotToTaxonomyMap:
    """Map UniProt identifiers to taxonomy ids via a table of hashed ids.

    The table is read from a NumPy ``.npz`` archive containing two arrays,
    ``uid_hashes`` and ``tax_ids``; ``tax_ids[i]`` is the value returned for
    the identifier whose hash equals ``uid_hashes[i]``.  Lookups use
    ``numpy.searchsorted``, which requires ``uid_hashes`` to be sorted in
    ascending order.
    """

    def __init__(self, npz_path):
        """Load the hash and taxonomy arrays from the archive at *npz_path*."""
        from numpy import load
        # numpy.load returns a lazily-read NpzFile that keeps the archive
        # open; read both members inside the context manager so the file
        # handle is released instead of leaking for the life of the process.
        with load(npz_path) as taxids:
            self._uid_hashes = taxids['uid_hashes']
            self._tax_ids = taxids['tax_ids']

    def find(self, uniprot_id):
        """Return the taxonomy id stored for *uniprot_id*, or ``None``.

        When the identifier's hash is not in the table a message is printed
        to standard output and ``None`` is returned.
        """
        from numpy import searchsorted, uint64
        # searchsorted becomes 100,000 times slower if hash is not cast to uint64 and key is not in array.
        h = uint64(self._hash_uniprot_id(uniprot_id))
        hashes = self._uid_hashes
        i = searchsorted(hashes, h)
        # searchsorted yields an insertion point; it is a hit only if that
        # slot exists and actually holds the hash.
        if i < len(hashes) and hashes[i] == h:
            return self._tax_ids[i]
        print('no taxonomy for', uniprot_id)
        return None

    def _hash_uniprot_id(self, uniprot_id):
        """Return the first 64 bits of the SHA-1 of *uniprot_id* as an int."""
        from hashlib import sha1
        # 16 hex digits == 8 bytes, so the value fits in an unsigned 64-bit integer.
        return int(sha1(uniprot_id.encode("utf-8")).hexdigest()[:16], 16)

# Script driver; these statements run whenever the file is executed (there is
# no __main__ guard).
# Usage: <script> TAX_MAP_NPZ [A3M_FILE ...]
from sys import argv
# First argument: path handed to UniprotToTaxonomyMap. Indexing argv[1]
# raises IndexError when the script is started without arguments.
tax_map_npz_path = argv[1]
tax_id_map = UniprotToTaxonomyMap(tax_map_npz_path)

# Every remaining argument is an a3m file that add_taxonomy_to_a3m rewrites
# in place; the map is built once and shared across all of them.
for a3m_path in argv[2:]:
    add_taxonomy_to_a3m(a3m_path, tax_id_map)
