Mercurial > repos > bcclaywell > argo_navis
view bin/deme_downsample.py @ 0:d67268158946 draft
planemo upload commit a3f181f5f126803c654b3a66dd4e83a48f7e203b
author | bcclaywell |
---|---|
date | Mon, 12 Oct 2015 17:43:33 -0400 |
parents | |
children |
line wrap: on
line source
#!/usr/bin/env python """Given the clustering results of a run of alnclst, this tool takes those results and find a single representative sequence for each cluster. In particular, it chooses the cluster representative closest to the cluster center. """ import argparse import random import alnclst import csv from Bio import SeqIO settings = dict(consensus_threshold=None, batches=2, max_iters=100) def kmeans_runner(seqrecords, k): "Runs kmeans on seqrecords and picks representatives from each cluster, returning their names in a list." # Define clustering function we'll run batches number of times def clustering(): return alnclst.KMeansClsutering(seqrecords, k, settings['consensus_threshold'], max_iters=settings['max_iters']) # Run the batches, and pick the one with the best convergence _, clusts = min((c.average_distance(), c) for c in (clustering() for i in xrange(settings['batches']))) # Pick the best representative for every cluster, and thow in clust_reps dict clust_reps = dict() for cluster_id, sequence, distance in clusts.mapping_iterator(): current = (distance, sequence) clst_min = clust_reps.get(cluster_id, current) if current <= clst_min: clust_reps[cluster_id] = current return [seqname for (_, (_, seqname)) in clust_reps.iteritems()] def random_runner(seqnames, k): "Randomly samples k seqnames from seqnames" if len(seqnames) < k: return seqnames else: return random.sample(seqnames, k) def make_deme_map(deme_spec, deme_col): "Turns deme metadata into a map of deme -> sequence names" deme_map = dict() for row in deme_spec: try: deme = row[deme_col] except KeyError: raise KeyError, "Make sure to specify a --deme-col that's actually in the deme file" seqname = row['sequence'] try: deme_map[deme].append(seqname) except KeyError: deme_map[deme] = [seqname] return deme_map def get_args(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument('alignment', type=argparse.FileType('r'), help="Alignment FASTA file") parser.add_argument('demes', type=argparse.FileType('r'), help="CSV metadata specifying deme info") parser.add_argument('-c', '--deme-col', default='deme', help="Column specifying 'deme' argument in demes spec") parser.add_argument('-k', help="Maximum number of representatives for each deme", type=int) parser.add_argument('-s', '--seed', help="Random seed for reproducibility") parser.add_argument('-m', '--method', choices=('random', 'kmeans'), help="Which downsampling method should be used") parser.add_argument('out_alignment', type=argparse.FileType('w'), help="Downsampled alignment output") parser.add_argument('out_demes', type=argparse.FileType('w'), help="Downsampled metadata output") return parser.parse_args() def main(): args = get_args() # Set random seed if needed if args.seed: random.seed(args.seed) # Create a lit of seqrecords to make things easier for ourselves seqrecords = SeqIO.to_dict(SeqIO.parse(args.alignment, 'fasta')) demes = list(csv.DictReader(args.demes)) # Turn our metadata into a map of deme -> seqnames deme_map = make_deme_map(demes, args.deme_col) # Run the specified downsampling method for each deme, and gather kept representatives rep_seqnames = [] for deme, seqnames in deme_map.iteritems(): # this makes it safe to have a csv file with "extra" stuff seqnames = [n for n in seqnames if n in seqrecords.keys()] if args.method == 'random': deme_rep_seqnames = random_runner(seqnames, args.k) else: deme_sequences = [seqrecords[n] for n in seqnames] deme_rep_seqnames = kmeans_runner(deme_sequences, args.k) rep_seqnames += deme_rep_seqnames # Filter down the actual data based on representative names deme_rep_seqs = [seqrecords[n] for n in rep_seqnames] deme_rep_meta = [r for r in demes if r['sequence'] in rep_seqnames] out_demes = csv.DictWriter(args.out_demes, deme_rep_meta[1].keys()) out_demes.writeheader() out_demes.writerows(deme_rep_meta) SeqIO.write(deme_rep_seqs, args.out_alignment, 'fasta') for fh in [args.alignment, args.demes, args.out_alignment, args.out_demes]: fh.close() if __name__ == '__main__': main()