Mercurial > repos > bcclaywell > argo_navis
diff bin/format_beastfile.py @ 0:d67268158946 draft
planemo upload commit a3f181f5f126803c654b3a66dd4e83a48f7e203b
author | bcclaywell |
---|---|
date | Mon, 12 Oct 2015 17:43:33 -0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/bin/format_beastfile.py Mon Oct 12 17:43:33 2015 -0400 @@ -0,0 +1,212 @@ +#!/usr/bin/env python +""" +Formats a given BEAST XML file (possibly all ready to run) and respecifies the information needed to run the +classic Discrete trait. + +Some things that would be nice: +* specify output files/formats (could let you run from root instead of the dir) +""" + +from Bio import SeqIO +import xml.etree.ElementTree as ET +import argparse +import copy +import csv + + +def clear_children(node): + "Element.remove doesn't seem to work the way it's supposed to, so we're doing this" + node_attrib = copy.copy(node.attrib) + node.clear() + node.attrib = node_attrib + + +def set_alignment(xmldoc, alignment): + """This function replaces the alignment data in xmldoc with that from sequences in alignment.""" + aln_node = xmldoc.find('data') + # First clear out the old alignment sequences + clear_children(aln_node) + print "seqs" + for seq in aln_node: + print seq + # Next, construct and throw in the new sequence nodes + for seq_record in alignment: + seqid = seq_record.name + ET.SubElement(aln_node, 'sequence', + attrib=dict(id="seq_" + seqid, + taxon=seqid, + totalcount="4", + value=str(seq_record.seq))) + + +def get_data_id(xmldoc): + """The data set will have a given name, assigned by BEAUti, typically based on the named of the data file + loaded into it. This name gets referred to in a number of places (presumably so there can be a number of + partitions/datasets in an analysis), and is needed by other bits of code that do their thing.""" + return xmldoc.find(".//data[@name='alignment'][@id]").attrib['id'] + + +def default_deme_getter(metarow): + """A default function for getting the deme data from a given metadata row. Specifically defaults to 'deme' + first, then to 'community' next. Returns none if it doesn't find either.""" + return metarow.get('deme') or metarow.get('community') + + +def set_deme(xmldoc, metadata, deme_getter=default_deme_getter): + """Sets the deme information of the xmldoc based on metadata, and using the deme_getter (by default the + `default_deme_getter` function above.""" + trait_node = xmldoc.iter('traitSet').next() + trait_string = ",\n".join([row['sequence'] + "=" + deme_getter(row) for row in metadata]) + trait_node.text = trait_string + + +def build_date_node(date_spec, data_id): + """Builds a node of date traits, given the date_spec string which is the actual string representation of + the sequence -> date mapping. Has to create a `taxa` subnode, and a `data` subnode of that, which points + to the data set in question via `idref`.""" + date_node = ET.Element('trait', + id='dateTrait.t:' + data_id, + spec='beast.evolution.tree.TraitSet', + traitname='date') + date_node.text = date_spec + taxa_node = ET.SubElement(date_node, 'taxa', + id='TaxonSet.' + data_id, + spec='TaxonSet') + _ = ET.SubElement(taxa_node, 'data', + idref=data_id, + name="alignment") + return date_node + + +def set_date(xmldoc, metadata, date_attr='date'): + """Builds a dateTrait node via `build_date_node` above, and inserts into the `.//state/tree` node. + However, this `tree` node already contains a `taxonset` node which has a `data` node, and this + `taxonset` node has the same id as the `taxa` node in the the date `trait` node. As such, the node that + _was_ present must be removed, so that we don't get a duplicate id error. Instead, we replace the old + taxonset node with one which has an `idref` pointing to the `taxa` node inside the `trait` node. This is + rather convoluted, and I'm not possible that some file with multiple datasets wouldn't break on this, but + this described strategy seems to work for now.""" + # First get our tree node; we'll be adding the date data to this + tree_node = xmldoc.find('.//state/tree') + # Construct our trait string, just as we do for `set_deme` + trait_string = ",\n".join([row['sequence'] + "=" + row[date_attr] for row in metadata]) + # Build the date trait node, and carry out all the weird mucking to get the new `taxonset` node in, as + # described in the docstring + data_id = get_data_id(xmldoc) + date_node = build_date_node(trait_string, data_id) + old_taxonset = tree_node.find("./taxonset") + tree_node.insert(0, date_node) + tree_node.remove(old_taxonset) + new_taxonset = ET.SubElement(tree_node, "taxonset", idref="TaxonSet."+data_id) + + +def get_current_interval(xmldoc): + run_node = xmldoc.find('run') + loggers = run_node.findall('logger') + intervals = list(set([int(l.get('logEvery')) for l in loggers if l.get('id') != 'screenlog'])) + if len(intervals) > 1: + raise "Cannot get an interval for this xml doc; there are multiple such values" + return intervals[0] + + +def set_mcmc(xmldoc, samples, sampling_interval): + "Sets the MCMC chain settings (how often to log, how long to run, etc" + run_node = xmldoc.find('run') + # XXX Should really make it so that you only have to specify _one_, and it will find current value of + # other so that chain length doesn't break. + chain_length = samples * sampling_interval + 1 + run_node.set('chainLength', str(chain_length)) + loggers = run_node.findall('logger') + for logger in loggers: + logevery = sampling_interval * 10 if logger.get('id') == 'screenlog' else sampling_interval + logger.set('logEvery', str(logevery)) + + +def normalize_filenames(xmldoc, logger_filename="posterior.log", treefile_filename="posterior.trait.trees"): + run_node = xmldoc.find('run') + logfile_node = run_node.find('logger[@id="tracelog"]') + treefile_node = run_node.find('logger[@id="treeWithTraitLogger.deme"]') + logfile_node.set('fileName', logger_filename) + treefile_node.set('fileName', treefile_filename) + + +def set_deme_count(xmldoc, metadata, deme_getter=default_deme_getter): + "Updates the model specs based onthe number of demes in the data set." + demes = list(set(map(deme_getter, metadata))) + demes.sort() + deme_count = len(demes) + mig_dim = (deme_count - 1) * deme_count / 2 + for xpath in ['.//parameter[@id="relativeGeoRates.s:deme"]', './/stateNode[@id="rateIndicator.s:deme"]']: + xmldoc.find(xpath).set('dimension', str(mig_dim)) + code_map = map(lambda ix: ix[1] + "=" + str(ix[0]), enumerate(demes)) + code_map = ",".join(code_map) + ",? = " + " ".join(map(str, range(deme_count))) + " " + user_data_type_node = xmldoc.find('.//userDataType') + user_data_type_node.set('codeMap', code_map) + user_data_type_node.set('states', str(deme_count)) + trait_frequencies_param = xmldoc.find('.//frequencies/parameter[@id="traitfrequencies.s:deme"]') + trait_frequencies_param.set('dimension', str(deme_count)) + trait_frequencies_param.text = str(1.0/deme_count) + + + +def get_args(): + def int_or_floatify(string): + return int(float(string)) + parser = argparse.ArgumentParser() + parser.add_argument('template', type=argparse.FileType('r'), + help="""A template BEAST XML (presumably created by Beauti) ready insertion of an alignment and + discrete trait.""") + parser.add_argument('-a', '--alignment', + help="Replace alignment in beast file with this alignment; Fasta format.") + parser.add_argument('-m', '--metadata', type=argparse.FileType('r'), + help="Should contain 'community' column referencing the deme.") + parser.add_argument('-s', '--samples', type=int_or_floatify, + help="Number of samples in output log file(s).") + parser.add_argument('-d', '--deme-col', + help="""Specifies the deme column for metadata; defaults to deme or community (whichever is present) + if not specified.""") + parser.add_argument('-D', '--date-col', + help="If specified, will add a date specification to the output BEAST XML file.") + parser.add_argument('-i', '--sampling-interval', type=int_or_floatify, + help="""Number of chain states to simulate between successive states samples for logfiles. The + total chain length is therefor samples * sampling_interval.""") + parser.add_argument('beastfile', type=argparse.FileType('w'), + help="Output BEAST XML file.") + return parser.parse_args() + + +def main(args): + # Read in old data + xmldoc = ET.parse(args.template) + + # Modify the data sets + if args.alignment: + alignment = SeqIO.parse(args.alignment, 'fasta') + set_alignment(xmldoc, alignment) + if args.metadata: + metadata = list(csv.DictReader(args.metadata)) + # Set the deme getter + deme_getter = lambda row: row[args.deme_col] if args.deme_col else default_deme_getter(row) + set_deme(xmldoc, metadata, deme_getter) + # _could_ do something smart here where we look at which sequences in the XML file traitset that match + # alignment passed in if _only_ alignment is passed in. Probably not worth it though... + set_deme_count(xmldoc, metadata, deme_getter) + if args.date_col: + set_date(xmldoc, metadata, args.date_col) + + if args.samples or args.sampling_interval: + interval = args.sampling_interval if args.sampling_interval else get_current_interval(xmldoc) + set_mcmc(xmldoc, args.samples, interval) + + # Make sure that we always have the same file names out. These are specified as defaults of the function, + # but could be customized here or through the cl args if needed. + normalize_filenames(xmldoc) + + # Write the output + xmldoc.write(args.beastfile) + + +if __name__ == '__main__': + main(get_args()) + +