diff bin/deme_downsample.py @ 0:d67268158946 draft

planemo upload commit a3f181f5f126803c654b3a66dd4e83a48f7e203b
author bcclaywell
date Mon, 12 Oct 2015 17:43:33 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/bin/deme_downsample.py	Mon Oct 12 17:43:33 2015 -0400
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+"""Given the clustering results of a run of alnclst, this tool takes those results and find a single
+representative sequence for each cluster. In particular, it chooses the cluster representative closest to the
+cluster center.
+"""
+
+import argparse
+import random
+import alnclst
+import csv
+from Bio import SeqIO
+
+
+settings = dict(consensus_threshold=None, batches=2, max_iters=100)
+
+
+def kmeans_runner(seqrecords, k):
+    "Runs kmeans on seqrecords and picks representatives from each cluster, returning their names in a list."
+    # Define clustering function we'll run batches number of times
+    def clustering():
+        return alnclst.KMeansClsutering(seqrecords, k, settings['consensus_threshold'], max_iters=settings['max_iters'])
+    # Run the batches, and pick the one with the best convergence
+    _, clusts = min((c.average_distance(), c) for c in (clustering() for i in
+             xrange(settings['batches'])))
+    # Pick the best representative for every cluster, and thow in clust_reps dict
+    clust_reps = dict()
+    for cluster_id, sequence, distance in clusts.mapping_iterator():
+        current = (distance, sequence)
+        clst_min = clust_reps.get(cluster_id, current)
+        if current <= clst_min:
+            clust_reps[cluster_id] = current
+    return [seqname for (_, (_, seqname)) in clust_reps.iteritems()]
+
+
+def random_runner(seqnames, k):
+    "Randomly samples k seqnames from seqnames"
+    if len(seqnames) < k:
+        return seqnames
+    else:
+        return random.sample(seqnames, k)
+
+
+def make_deme_map(deme_spec, deme_col):
+    "Turns deme metadata into a map of deme -> sequence names"
+    deme_map = dict()
+    for row in deme_spec:
+        try:
+            deme = row[deme_col]
+        except KeyError:
+            raise KeyError, "Make sure to specify a --deme-col that's actually in the deme file"
+        seqname = row['sequence']
+        try:
+            deme_map[deme].append(seqname)
+        except KeyError:
+            deme_map[deme] = [seqname]
+    return deme_map
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument('alignment', type=argparse.FileType('r'), help="Alignment FASTA file")
+    parser.add_argument('demes', type=argparse.FileType('r'), help="CSV metadata specifying deme info")
+    parser.add_argument('-c', '--deme-col', default='deme', help="Column specifying 'deme' argument in demes spec")
+    parser.add_argument('-k', help="Maximum number of representatives for each deme", type=int)
+    parser.add_argument('-s', '--seed', help="Random seed for reproducibility")
+    parser.add_argument('-m', '--method', choices=('random', 'kmeans'),
+            help="Which downsampling method should be used")
+    parser.add_argument('out_alignment', type=argparse.FileType('w'), help="Downsampled alignment output")
+    parser.add_argument('out_demes', type=argparse.FileType('w'), help="Downsampled metadata output")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    # Set random seed if needed
+    if args.seed:
+        random.seed(args.seed)
+
+    # Create a lit of seqrecords to make things easier for ourselves
+    seqrecords = SeqIO.to_dict(SeqIO.parse(args.alignment, 'fasta'))
+    demes = list(csv.DictReader(args.demes))
+
+    # Turn our metadata into a map of deme -> seqnames
+    deme_map = make_deme_map(demes, args.deme_col)
+
+    # Run the specified downsampling method for each deme, and gather kept representatives
+    rep_seqnames = []
+    for deme, seqnames in deme_map.iteritems():
+        # this makes it safe to have a csv file with "extra" stuff
+        seqnames = [n for n in seqnames if n in seqrecords.keys()]
+        if args.method == 'random':
+            deme_rep_seqnames = random_runner(seqnames, args.k)
+        else:
+            deme_sequences = [seqrecords[n] for n in seqnames]
+            deme_rep_seqnames = kmeans_runner(deme_sequences, args.k)
+        rep_seqnames += deme_rep_seqnames
+
+    # Filter down the actual data based on representative names
+    deme_rep_seqs = [seqrecords[n] for n in rep_seqnames]
+    deme_rep_meta = [r for r in demes if r['sequence'] in rep_seqnames]
+
+    out_demes = csv.DictWriter(args.out_demes, deme_rep_meta[1].keys())
+    out_demes.writeheader()
+    out_demes.writerows(deme_rep_meta)
+
+    SeqIO.write(deme_rep_seqs, args.out_alignment, 'fasta')
+
+    for fh in [args.alignment, args.demes, args.out_alignment, args.out_demes]:
+        fh.close()
+
+
+if __name__ == '__main__':
+    main()
+
+