view fasta_extract_utils.py @ 2:4dc6890db112 draft

Uploaded
author greg
date Sun, 10 Jan 2016 13:52:44 -0500
parents bc3f2a5c7b53
children
line wrap: on
line source

import cPickle
import numbers
import os
import string
import sys

import numpy

def COMPLEMENT(s):
    """
    Return the base-wise complement of the nucleotide string ``s``.

    A<->T and C<->G are swapped with case preserved; N and X (and any other
    character) are left unchanged.
    """
    table = string.maketrans('ATCGatcgNnXx', 'TAGCtagcNnXx')
    return s.translate(table)


# Marker written into a .flat placeholder once the fasta file itself has been
# flattened in place; ext_is_flat() checks for it.
MAGIC = "@flattened@"


def ext_is_flat(ext):
    """
    Return True if the file at path ``ext`` begins with the MAGIC marker.

    copy_inplace() writes such a placeholder into the .flat file after the
    fasta file itself has been replaced by the flattened data.
    """
    # ``with`` guarantees the handle is closed even if read() raises.
    with open(ext) as fh:
        return fh.read(len(MAGIC)) == MAGIC


def is_up_to_date(a, b):
    """
    Return True when file ``a`` exists and its mtime is not older than ``b``'s.

    ``b`` must exist; os.stat raises OSError otherwise.
    """
    if not os.path.exists(a):
        return False
    derived_mtime = os.stat(a).st_mtime
    source_mtime = os.stat(b).st_mtime
    return derived_mtime >= source_mtime


class FastaRecord(object):
    """
    Lazy view of one sequence inside a flattened (newline-free) file.

    ``fh`` is an open file object; ``start``/``stop`` are the offsets of the
    sequence inside it, as stored in the index built by ``prepare``. Indexing
    and slicing are relative to ``start`` and read from ``fh`` on demand.
    """
    __slots__ = ('fh', 'start', 'stop')
    # Suffix of the flattened-sequence file kept next to the fasta file.
    ext = ".flat"
    # Suffix of the pickled {seqid: (start, stop)} index file.
    idx = ".gdx"

    def __init__(self, fh, start, stop):
        self.fh = fh
        self.stop = stop
        self.start = start

    @property
    def __array_interface__(self):
        # Expose the sequence to numpy as a 1-d array of single characters.
        # buffer(self) raised TypeError because this class does not implement
        # the buffer protocol, so hand numpy the sequence string itself.
        return {'shape': (len(self), ),
                'typestr': '|S1',
                'version': 3,
                'data': self[:]}

    def __getitem__(self, islice):
        """
        Return one character (integer index) or a substring (slice).

        Integer indexes follow Python semantics and raise IndexError when out
        of range. Slice bounds are clamped to the record, so no data from a
        neighbouring record is returned. A step is applied to the forward
        range (``rec[::-1]`` reverses the record); reversed bounds such as
        ``rec[3:1]`` give an empty string.
        """
        fh = self.fh
        if isinstance(islice, numbers.Integral):
            length = self.stop - self.start
            if islice < 0:
                islice += length
            if not 0 <= islice < length:
                raise IndexError('%s index out of range' % self.__class__.__name__)
            fh.seek(self.start + islice)
            return fh.read(1)
        istart, istop = self._adjust_slice(islice)
        length = istop - istart
        if length <= 0:
            # Empty or reversed range; fh.read() with a negative size would
            # otherwise read to the end of the file.
            return ""
        fh.seek(istart)
        data = fh.read(length)
        if islice.step in (1, None):
            return data
        return data[::islice.step]

    def __len__(self):
        return self.stop - self.start

    def __repr__(self):
        return "%s('%s', %i..%i)" % (self.__class__.__name__,
                                     self.fh.name,
                                     self.start,
                                     self.stop)

    def __str__(self):
        return self[:]

    def _adjust_slice(self, islice):
        """
        Convert a record-relative slice into absolute offsets.

        Returns ``(istart, istop)``, each clamped to ``[self.start, self.stop]``.
        ``istop`` may be smaller than ``istart`` when the slice bounds are
        reversed; the step is not applied here.
        """
        if islice.start is None:
            istart = self.start
        elif islice.start < 0:
            istart = self.stop + islice.start
        else:
            istart = self.start + islice.start
        # Explicit branches: the former ``cond and a or b`` form fell through
        # to ``self.start + None`` when self.stop was 0.
        if islice.stop is None:
            istop = self.stop
        elif islice.stop < 0:
            istop = self.stop + islice.stop
        else:
            istop = self.start + islice.stop
        # Start beyond the end: empty range at the end of the record.
        if istart > self.stop:
            return self.stop, self.stop
        istart = max(istart, self.start)
        istop = min(max(istop, self.start), self.stop)
        return istart, istop

    @classmethod
    def copy_inplace(klass, flat_name, fasta_name):
        """
        Overwrite the .fasta file with the .fasta.flat file and save
        something in the .flat as a place-holder.
        """
        os.rename(flat_name, fasta_name)
        # The placeholder keeps is_current() satisfied, and its MAGIC header
        # tells ext_is_flat() that the data now lives in the fasta file.
        with open(fasta_name + klass.ext, 'wb') as flatfh:
            flatfh.write(MAGIC)

    @classmethod
    def is_current(klass, fasta_name):
        """
        True when both the index (.gdx) and flat (.flat) files exist and are
        not older than ``fasta_name``.
        """
        utd = is_up_to_date(fasta_name + klass.idx, fasta_name)
        if not utd:
            return False
        return is_up_to_date(fasta_name + klass.ext, fasta_name)

    @classmethod
    def modify_flat(klass, flat_file):
        """Open the flattened data for reading (subclasses may return another accessor)."""
        return open(flat_file, 'rb')

    @classmethod
    def prepare(klass, fasta_obj, seqinfo_generator, flatten_inplace):
        """
        Returns the __getitem__'able index. and the thing from which to get the seqs.

        The index maps each sequence id to its ``(start, stop)`` offsets; the
        second value comes from ``modify_flat``. A current on-disk index/flat
        pair is reused; otherwise both are rebuilt from ``seqinfo_generator``,
        which yields ``(seqid, seq)`` pairs. With ``flatten_inplace`` the
        fasta file itself is replaced by the flattened data (headers kept)
        and the .flat file becomes a MAGIC placeholder.
        """
        f = fasta_obj.fasta_name
        if klass.is_current(f):
            ext_flat = ext_is_flat(f + klass.ext)
            # An in-place request can only reuse the cache when the fasta file
            # has already been flattened in place; otherwise rebuild below.
            if ext_flat or not flatten_inplace:
                # NOTE(review): unpickling can execute arbitrary code; the
                # .gdx file must come from a trusted source.
                with open(f + klass.idx, 'rb') as fh:
                    idx = cPickle.load(fh)
                if ext_flat:
                    return idx, klass.modify_flat(f)
                return idx, klass.modify_flat(f + klass.ext)
        idx = {}
        with open(f + klass.ext, 'wb') as flatfh:
            flat_name = flatfh.name
            for i, (seqid, seq) in enumerate(seqinfo_generator):
                if flatten_inplace:
                    # Keep headers so the rewritten fasta file stays FASTA;
                    # the recorded offsets point past them at the sequence.
                    header_fmt = '>%s\n' if i == 0 else '\n>%s\n'
                    flatfh.write(header_fmt % seqid)
                start = flatfh.tell()
                flatfh.write(seq)
                stop = flatfh.tell()
                idx[seqid] = (start, stop)
        if flatten_inplace:
            klass.copy_inplace(flat_name, f)
            flat_path = f
        else:
            flat_path = f + klass.ext
        # The index is written after the flat data (and after the in-place
        # rename) so is_current() sees it as no older than the fasta file.
        with open(f + klass.idx, 'wb') as fh:
            cPickle.dump(idx, fh, -1)
        return idx, klass.modify_flat(flat_path)


class NpyFastaRecord(FastaRecord):
    """
    FastaRecord variant that reads from a numpy memmap of the flattened data.

    ``mm`` is the array returned by ``modify_flat``; ``start``/``stop`` are the
    record's offsets into it. With ``as_string`` true, items are returned as
    byte strings, otherwise as the raw numpy array/scalar.
    """
    # 'start' and 'stop' are inherited from FastaRecord.__slots__; declaring
    # them again would shadow the base-class slot descriptors.
    __slots__ = ('mm', 'as_string')

    def __init__(self, mm, start, stop, as_string=True):
        self.mm = mm
        self.start = start
        self.stop = stop
        self.as_string = as_string

    def __repr__(self):
        return "%s(%i..%i)" % (self.__class__.__name__, self.start, self.stop)

    @classmethod
    def modify_flat(klass, flat_file):
        """Map the flattened file read-only as a 1-byte-per-element array."""
        mm = numpy.memmap(flat_file, dtype="S1", mode="r")
        return mm

    def getdata(self, islice):
        """
        Return the memmap item (integer index) or view (slice) for ``islice``.

        Integer indexes are bounds-checked against this record, not the whole
        memmap, and raise IndexError when out of range. Slices are clamped to
        the record by ``_adjust_slice``.
        """
        if isinstance(islice, numbers.Integral):
            length = self.stop - self.start
            if islice < 0:
                islice += length
            if not 0 <= islice < length:
                raise IndexError('%s index out of range' % self.__class__.__name__)
            return self.mm[self.start + islice]
        start, stop = self._adjust_slice(islice)
        return self.mm[start:stop:islice.step]

    def __getitem__(self, islice):
        d = self.getdata(islice)
        # tobytes() is the non-deprecated spelling of tostring(); same bytes.
        return d.tobytes() if self.as_string else d

    @property
    def __array_interface__(self):
        return {'shape': (len(self), ),
                'typestr': '|S1',
                'version': 3,
                'data': self[:]}


class DuplicateHeaderException(Exception):
    """Raised when the same FASTA header is encountered more than once."""

    def __init__(self, header):
        message = 'headers must be unique: %s is duplicated' % header
        super(DuplicateHeaderException, self).__init__(message)


class Fasta(dict):
    """
    Dict-like, lazily loaded view of the sequences in a FASTA file.

    On construction the file is indexed (and flattened) by
    ``record_class.prepare``. Records are created on first access and cached
    in ``self.chr``. Keys are the sequence headers, passed through ``key_fn``
    when one is given. The inherited dict storage stays empty; all lookups go
    through ``self.index``.
    """

    def __init__(self, fasta_name, record_class=NpyFastaRecord, flatten_inplace=False, key_fn=None):
        self.fasta_name = fasta_name
        self.record_class = record_class
        self.key_fn = key_fn
        self.index, self.prepared = self.record_class.prepare(self,
                                                              self.gen_seqs_with_headers(key_fn),
                                                              flatten_inplace)
        self.chr = {}

    def __contains__(self, key):
        # Normalise the key exactly as __getitem__ does, so ``key in fasta``
        # agrees with whether ``fasta[key]`` succeeds.
        if self.key_fn is not None:
            key = self.key_fn(key)
        return key in self.index

    def __getitem__(self, i):
        # Implements the lazy loading: build the record on first access and
        # cache it in self.chr.
        if self.key_fn is not None:
            i = self.key_fn(i)
        if i in self.chr:
            return self.chr[i]
        c = self.index[i]
        self.chr[i] = self.record_class(self.prepared, c[0], c[1])
        return self.chr[i]

    def __iter__(self):
        # dict.__iter__ would walk the (always empty) inherited dict storage.
        return iter(self.index)

    def __len__(self):
        return len(self.index)

    @classmethod
    def as_kmers(klass, seq, k, overlap=0):
        """
        Yield ``(offset, seq[offset:offset + k])`` windows of length ``k``.

        Consecutive windows share ``overlap`` characters; the final windows
        may be shorter than ``k``.
        """
        kmax = len(seq)
        assert overlap < k, ('overlap must be < kmer length')
        i = 0
        while i < kmax:
            yield i, seq[i:i + k]
            i += k - overlap

    def gen_seqs_with_headers(self, key_fn=None):
        """
        Remove all newlines from the sequence in a fasta file
        and generate starts, stops to be used by the record class.

        Yields ``(header, sequence)`` per record. Blank lines are skipped. The
        header is the text after '>', stripped and passed through ``key_fn``
        when given. A header with no sequence lines yields an empty sequence,
        whether or not it is the last record. A file with no header at all
        yields nothing.

        Raises DuplicateHeaderException when a header repeats, and ValueError
        when sequence data appears before the first header.
        """
        seen_headers = set()
        header = None
        seqs = None
        with open(self.fasta_name, 'r') as fh:
            for line in fh:
                line = line.rstrip()
                if not line:
                    continue
                if line[0] == ">":
                    if seqs is not None:
                        if header in seen_headers:
                            raise DuplicateHeaderException(header)
                        seen_headers.add(header)
                        yield header, "".join(seqs)
                    header = line[1:].strip()
                    if key_fn is not None:
                        header = key_fn(header)
                    seqs = []
                elif seqs is None:
                    raise ValueError('%s: sequence data found before the first ">" header'
                                     % self.fasta_name)
                else:
                    seqs.append(line)
        # seqs is still None when the file contained no header (e.g. empty file).
        if seqs is not None:
            if header in seen_headers:
                raise DuplicateHeaderException(header)
            yield header, "".join(seqs)

    def iterkeys(self):
        for k in self.index.iterkeys():
            yield k

    def iteritems(self):
        for k in self.keys():
            yield k, self[k]

    def keys(self):
        return self.index.keys()

    def sequence(self, f, asstring=True, auto_rc=True, exon_keys=None, one_based=True):
        """
        Return the sequence for feature ``f`` from the associated fasta file.

        By default the sequence between ``f['start']`` and ``f['stop']`` on
        ``f['chr']`` is returned. If ``exon_keys`` is an iterable of keys, the
        first key present in ``f['locations']`` (or in ``f`` itself when it has
        no 'locations' entry) is used. Its (start, stop) pairs are sliced and
        concatenated, so bases lying between the intervals (e.g. introns) are
        excluded. Keys are tried in the order given: for example,
        ``exon_keys=('rnas', 'exons')`` uses 'rnas' when present and otherwise
        'exons'. If no key is found, the feature's start/stop is used.

        f: a feature dict with 'chr', 'start', 'stop', optionally 'strand' and
           exon location lists.
        asstring: if true, return the sequence as a string, if false, return as a numpy array
        auto_rc: if True and the strand of the feature is -1, '-1' or '-', return the
                 reverse complement of the sequence
        exon_keys: optional iterable of keys naming lists of (start, stop) pairs
        one_based: if true, query is using 1 based closed intervals, if false semi-open zero based intervals
        """
        # f.get() keeps the failure message itself from raising KeyError when
        # 'chr' is missing.
        assert 'chr' in f and f['chr'] in self, (f, f.get('chr'), self.keys())
        fasta = self[f['chr']]
        sequence = None
        if exon_keys is not None:
            sequence = self._seq_from_keys(f, fasta, exon_keys, one_based=one_based)
        if sequence is None:
            start = f['start'] - int(one_based)
            sequence = fasta[start: f['stop']]
        if auto_rc and f.get('strand') in (-1, '-1', '-'):
            sequence = COMPLEMENT(sequence)[::-1]
        if asstring:
            return sequence
        return numpy.array(sequence, dtype='c')

    def _seq_from_keys(self, f, fasta, exon_keys, base='locations', one_based=True):
        """
        f: a feature dict
        fasta: the record for the feature's chromosome (as returned by self[chr])
        exon_keys: an iterable of keys, to look for start/stop arrays to get sequence.
        base: if base ('locations') exists, look there for the exon_keys, not in the base
        of the object.

        Returns the concatenated sequence for the first key found, or None when
        no key is present.
        """
        fbase = f.get(base, f)
        for ek in exon_keys:
            if ek in fbase:
                return "".join(fasta[start - int(one_based):stop]
                               for start, stop in fbase[ek])
        return None