# Read .a3m files, one file per protein, and use the TX=123 sequence headers
# giving taxonomy id to pair sequences.  Writes out Boltz format .csv files
# for each of the proteins.

def boltz_pairing(yaml_dir, msa_dir):
    '''
    For each prediction file "name1.name2.yaml" in yaml_dir, write Boltz
    paired-MSA .csv files from the .a3m alignments in msa_dir.
    A homomer (name1 == name2) gets a single unpaired csv; a heteromer
    gets one paired csv per chain.
    '''
    from os import listdir
    yaml_files = [filename for filename in listdir(yaml_dir)
                  if filename.endswith('.yaml')]
    for yaml_file in yaml_files:
        # File name encodes the two chain names as "name1.name2.yaml".
        name1, name2 = yaml_file.split('.')[:2]
        if name1 == name2:
            write_unpaired_csv(msa_dir, name1)
        else:
            write_paired_csv(msa_dir, name1, name2)

def write_paired_csv(msa_dir, name1, name2, max_paired = 8192, max_seqs = 16384):
    '''
    Write Boltz csv files "<name1>_<name2>_<j>.csv" (j = 0, 1) pairing
    sequences from "<name>.pre_paired.a3m" by taxonomy id (TX= header tag)
    and padding with unpaired sequences from "<name>.a3m".

    In the csv, paired rows share a positive key across the two files,
    key 0 is the query sequence, and key -1 marks unpaired rows.

    max_paired - maximum number of paired sequences kept per chain.
    max_seqs   - maximum total (paired + unpaired) sequences per chain.
    '''
    from os.path import join
    pre_paired_a3m_paths = [join(msa_dir, f'{name}.pre_paired.a3m')
                            for name in (name1, name2)]
    unpaired_a3m_paths = [join(msa_dir, f'{name}.a3m')
                          for name in (name1, name2)]

    # Read each pre-paired a3m, collecting per-sequence taxonomy ids.
    # a3m layout assumed: alternating header / sequence lines.
    nres = 0
    tax_seqs = []
    for a3m_path in pre_paired_a3m_paths:
        with open(a3m_path, 'r') as a3m:
            lines = a3m.readlines()
        sequences = lines[1::2]
        # NOTE(review): this measures the *second* sequence, not the query
        # (sequences[0]); the -1 strips the trailing newline.  Presumably
        # the query length was meant, and this raises IndexError on an a3m
        # with fewer than two sequences — confirm intent.
        nres += len(sequences[1])-1
        taxonomy_ids = []
        for header in lines[::2]:
            i = header.find('TX=')
            # Header without a TX= tag (e.g. the query) gets id None.
            taxonomy_id = int(header[i+3:].split(maxsplit=1)[0]) if i >= 0 else None
            taxonomy_ids.append(taxonomy_id)
        tax_seqs.append((taxonomy_ids, sequences))

    # Reduce number paired to reduce out of memory errors.
    lim_paired = max_paired_for_memory(nres, memory_gbytes = 24)
    max_paired = min(lim_paired, max_paired)
    # Also cap the number of unpaired sequences to the number paired.
    max_seqs = min(2*max_paired, max_seqs)

    # A taxonomy id is pairable when it appears in both chains' alignments.
    tax_counts = {}
    for tax_ids, seqs in tax_seqs:
        for tax_id in set(tax_ids):
            tax_counts[tax_id] = tax_counts.get(tax_id, 0) + 1
    paired_tax_ids = [tax_id for tax_id, count in tax_counts.items()
                      if count > 1 and tax_id is not None]
    paired_tax_ids.sort()  # mmseqs pairaln outputs in taxonomy id order.
    # Shared positive keys; key 0 is reserved for the query row.
    pair_key = {tax_id:i+1 for i,tax_id in enumerate(paired_tax_ids)}

    for j, (tax_ids, seqs) in enumerate(tax_seqs):
        # First matching sequence per pairable taxonomy id wins.
        paired_seqs = {}
        for si,(tax_id,seq) in enumerate(zip(tax_ids, seqs)):
            pk = pair_key.get(tax_id) if si > 0 else 0  # si 0 is the query
            if pk and pk not in paired_seqs:
                paired_seqs[pk] = seq
        pseq = list(paired_seqs.items())
        pseq.sort()
        pseq = pseq[:max_paired-1]  # Leave room for the query row.

        unpaired_a3m_path = unpaired_a3m_paths[j]
        with open(unpaired_a3m_path, 'r') as a3m:
            lines = a3m.readlines()
        unpaired_seqs = lines[1::2]
        if pseq:
            unpaired_seqs = unpaired_seqs[1:]  # Drop initial query sequence
        max_unpaired = max_seqs - (len(pseq) + 1)
        unpaired_seqs = unpaired_seqs[:max_unpaired]
        pseq.insert(0, (0, lines[1]))  # Paired seqs start with query

        # Sequences retain their trailing newline, so each write is one row.
        with open(join(msa_dir, f'{name1}_{name2}_{j}.csv'), 'w') as csv:
            csv.write('key,sequence\n')
            for key, seq in pseq:
                csv.write(f'{key},{seq}')
            for seq in unpaired_seqs:
                csv.write(f'-1,{seq}')

def write_unpaired_csv(msa_dir, name, max_seqs = 16384):
    '''
    Write Boltz csv file "<name>.csv" from alignment "<name>.a3m" with all
    sequences marked unpaired (key -1), capped at max_seqs sequences.
    '''
    from os.path import join
    unpaired_a3m_path = join(msa_dir, f'{name}.a3m')
    with open(unpaired_a3m_path, 'r') as a3m:
        lines = a3m.readlines()
    unpaired_seqs = lines[1::2]
    # Cap the alignment depth by the memory estimate; 2*nres because a
    # homomer uses this single MSA for both chains.
    nres = len(unpaired_seqs[0])
    lim_seqs = max_paired_for_memory(2*nres, memory_gbytes = 24)
    unpaired_seqs = unpaired_seqs[:min(lim_seqs,max_seqs)]
    with open(join(msa_dir, f'{name}.csv'), 'w') as csv:
        csv.write('key,sequence\n')
        for seq in unpaired_seqs:
            csv.write(f'-1,{seq}')

def max_paired_for_memory(nres, memory_gbytes = 24, min_paired = 512):
    '''
    Estimate how many paired sequences fit in the given GPU/host memory
    for a complex of nres total residues, never less than min_paired.
    '''
    return 16384	# Unlimit MSA size for take6 test.
    # Unreachable while the override above is in place.
    # Number of pairs estimated to use specified memory in Gbytes;
    # the constants are empirical fit coefficients.
    n_paired = int((memory_gbytes*(1024**3) - 19151*nres*nres) / (2602*nres))
    n_paired = max(min_paired, n_paired)  # Always allow at least 512 pairs
    return n_paired

if __name__ == '__main__':
    # Usage: python boltz_pairing.py <yaml_dir> <msa_dir>
    from sys import argv
    yaml_dir = argv[1]
    msa_dir = argv[2]
    boltz_pairing(yaml_dir, msa_dir)