# Read .a3m alignment files, one per protein, and pair sequences across them
# using the taxonomy id given by the "TX=<id>" field (e.g. TX=9606) in each
# sequence header. Write one FASTA-format file of the concatenated paired
# sequences to stdout.

def _taxonomy_id(header):
    """Return the integer following the first 'TX=' in *header*, or None if absent."""
    i = header.find('TX=')
    if i < 0:
        return None
    fields = header[i + 3:].split(maxsplit=1)
    # A bare 'TX=' with nothing after it is treated as "no taxonomy id".
    return int(fields[0]) if fields else None


def _read_a3m(a3m_path):
    """Parse one a3m file into parallel lists (taxonomy_ids, sequence_ids, sequences).

    Sequences may span several lines. Blank lines and any lines before the
    first '>' header are ignored. Sequences are returned exactly as written,
    so lower-case insertion residues are kept.

    Raises:
        ValueError: if the file contains no '>' header records.
    """
    headers = []
    chunks = []
    with open(a3m_path, 'r') as a3m:
        for raw_line in a3m:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                headers.append(line)
                chunks.append([])
            elif headers:
                chunks[-1].append(line)
    if not headers:
        raise ValueError(f'{a3m_path}: no sequence records found')

    taxonomy_ids = [_taxonomy_id(header) for header in headers]
    # Sequence id is the first whitespace-separated token after '>'.
    sequence_ids = [(header[1:].split(maxsplit=1) or [''])[0]
                    for header in headers]
    sequences = [''.join(parts) for parts in chunks]
    return taxonomy_ids, sequence_ids, sequences


def paired_fasta(a3m_paths):
    """Pair per-protein a3m alignments by taxonomy id and build paired FASTA lines.

    Each path in *a3m_paths* is the a3m file for one protein. The first record
    of each file is taken as that protein's reference sequence. Records whose
    header contains ``TX=<int>`` are grouped by that taxonomy id. A taxonomy id
    is kept only if it occurs in more than one of the files. For each kept id,
    the first matching record from each file is used, with lower-case
    (insertion) residues removed. A file with no record for that id
    contributes a run of '-' as long as its reference sequence.

    Args:
        a3m_paths: a3m file paths, one per protein, in the order the chains
            are to be concatenated.

    Returns:
        A list of FASTA lines without newline characters:
        - a '>reference' header followed by the concatenated reference
          sequences;
        - one header/sequence pair per shared taxonomy id, in ascending id
          order. Each header lists the per-file sequence ids ('none' for a
          gap) followed by ``TX=<id>``.

    Raises:
        ValueError: if an a3m file contains no sequence records.
    """
    import re
    from collections import Counter

    lowercase = re.compile(r'[a-z]')  # a3m insertion columns are lower case

    alignments = [_read_a3m(path) for path in a3m_paths]
    ref_seqs = [seqs[0] for _, _, seqs in alignments]

    # Count the number of files (not records) in which each taxonomy id occurs.
    tax_counts = Counter()
    for tax_ids, _, _ in alignments:
        tax_counts.update(set(tax_ids))

    # mmseqs pairaln outputs in taxonomy id order.
    paired_tax_ids = sorted(tax_id for tax_id, count in tax_counts.items()
                            if count > 1 and tax_id is not None)

    # For each shared taxonomy id, keep the first (seq_id, seq) seen per file.
    n = len(alignments)
    paired = {tax_id: [None] * n for tax_id in paired_tax_ids}
    for j, (tax_ids, seq_ids, seqs) in enumerate(alignments):
        for tax_id, seq_id, seq in zip(tax_ids, seq_ids, seqs):
            slots = paired.get(tax_id)
            if slots is not None and slots[j] is None:
                slots[j] = (seq_id, seq)

    out_lines = ['>reference', ''.join(ref_seqs)]
    for tax_id in paired_tax_ids:
        ids = []
        parts = []
        for ref_seq, entry in zip(ref_seqs, paired[tax_id]):
            if entry is None:
                ids.append('none')
                parts.append('-' * len(ref_seq))
            else:
                seq_id, seq = entry
                ids.append(seq_id)
                parts.append(lowercase.sub('', seq))
        out_lines.append(f'>{" ".join(ids)} TX={tax_id}')
        out_lines.append(''.join(parts))
    return out_lines

from sys import argv, stdout

if __name__ == '__main__':
    # Usage: python <script> chain1.a3m [chain2.a3m ...] > paired.fasta
    if len(argv) < 2:
        raise SystemExit(f'usage: {argv[0]} chain1.a3m [chain2.a3m ...]')
    lines = paired_fasta(argv[1:])
    # Terminate the last line so the output is a well-formed text file.
    stdout.write('\n'.join(lines) + '\n')
