# Create synthetic MSA .csv files for Boltz, with a specified number of paired
# and unpaired sequences, for testing memory use.
# Usage: python <script> YAML_PATH N_PAIRED [N_UNPAIRED]

def _sequence_from_line(line):
    """Return the value of a single-line YAML ``sequence:`` entry, else None.

    Only the simple ``sequence: VALUE`` form is recognised (after stripping
    indentation). Two things are removed from VALUE so they do not leak into
    the MSA query: an inline ``# comment``, and one pair of matching
    surrounding quotes. For example, ``sequence: "ACDE"`` yields ``ACDE``.
    """
    stripped = line.strip()
    if not stripped.startswith('sequence:'):
        return None
    # partition (not split) so a value containing ':' is not truncated.
    value = stripped.partition(':')[2]
    # Residue strings never contain '#', so anything after it is a comment.
    value = value.partition('#')[0].strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in '"\'':
        value = value[1:-1]
    return value


def _mutate(seq, rate, alphabet='ACDEFGHIKLMNPQRSTVWY'):
    """Return a copy of *seq* with random substitutions.

    Each position is replaced, with probability *rate*, by a residue drawn
    uniformly from *alphabet*. The drawn residue may equal the original one.
    """
    from random import choice, random
    return ''.join(choice(alphabet) if random() < rate else c for c in seq)


def make_msas(yaml_path, n_paired, n_unpaired, mutation_rate=0.05):
    """Write one synthetic MSA CSV per ``sequence:`` entry in a Boltz YAML file.

    For the k-th ``sequence:`` line in *yaml_path* (0-based, file order), the
    file ``msa{k}.csv`` is written to the current working directory. It has
    the header ``key,sequence`` followed by these rows:

    * ``0,<query>`` — the sequence exactly as it appears in the YAML;
    * keys ``1 .. n_paired-1`` — randomly mutated copies of the query, so
      *n_paired* counts the query row itself;
    * *n_unpaired* rows with key ``-1`` — further mutated copies.

    Every ``sequence:`` entry is processed, including any nucleic-acid
    chains. Substitutions are always drawn from the 20 standard amino acids.
    Randomness comes from the global ``random`` module state, so call
    ``random.seed`` first if you need reproducible output.

    Args:
        yaml_path: Path to the Boltz input YAML.
        n_paired: Number of paired rows, including the query row.
        n_unpaired: Number of unpaired (key ``-1``) rows.
        mutation_rate: Per-residue substitution probability. Defaults to
            0.05.
    """
    with open(yaml_path, encoding='utf-8') as f:
        sequences = (s for s in map(_sequence_from_line, f) if s is not None)
        for i, seq in enumerate(sequences):
            with open(f'msa{i}.csv', 'w', encoding='utf-8') as msa:
                msa.write('key,sequence\n')
                msa.write(f'0,{seq}\n')
                for j in range(1, n_paired):
                    msa.write(f'{j},{_mutate(seq, mutation_rate)}\n')
                for _ in range(n_unpaired):
                    msa.write(f'-1,{_mutate(seq, mutation_rate)}\n')

from sys import argv

# Guarded so that importing this module (e.g. to reuse make_msas) does not
# parse sys.argv and start writing files.
if __name__ == '__main__':
    if len(argv) < 3:
        raise SystemExit(f'usage: {argv[0]} yaml_path n_paired [n_unpaired]')
    yaml_path = argv[1]
    n_paired = int(argv[2])
    # The unpaired count is optional and defaults to 0.
    n_unpaired = int(argv[3]) if len(argv) >= 4 else 0
    make_msas(yaml_path, n_paired, n_unpaired)
