view dpmix_plot.py @ 0:73648da53556 default tip

Uploaded
author rico
date Mon, 09 Apr 2012 11:55:36 -0400
parents
children
line wrap: on
line source

#!/usr/bin/env python

import os
import sys
import math
import matplotlib as mpl
mpl.use('PDF')
import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches

################################################################################

def build_chrom_len_dict(dbkey, galaxy_data_index_dir):
    """Load chromosome lengths for a genome build from the Galaxy data index.

    Reads ``<galaxy_data_index_dir>/shared/ucsc/chrom/<dbkey>.len``.  Every
    line that splits into exactly two whitespace-separated fields is taken as
    ``<chromosome> <length>``; lines with any other number of fields are
    skipped.

    Returns a dict mapping chromosome name to integer length.  The lookup is
    best-effort: if the file cannot be read, or a length field is not an
    integer, reading stops and whatever was collected so far (possibly an
    empty dict) is returned.
    """
    chrom_len_path = os.path.join(
        galaxy_data_index_dir, 'shared/ucsc/chrom', '{0}.len'.format(dbkey))

    chrom_len = {}

    try:
        with open(chrom_len_path) as fh:
            for line in fh:
                # split() with no argument already discards the line ending.
                elems = line.split()
                if len(elems) == 2:
                    chrom_len[elems[0]] = int(elems[1])
    except (IOError, OSError, ValueError):
        # A missing/unreadable file or a malformed length is tolerated on
        # purpose.  Only these errors are swallowed, so KeyboardInterrupt,
        # SystemExit and genuine programming errors still propagate.
        pass

    return chrom_len

def parse_input_file(input_file):
    """Parse a whitespace-delimited interval file.

    Each non-blank line supplies, in order: chromosome, begin, end, state and
    individual id.  Begin, end and state are converted to int; any fields
    after the fifth are ignored.

    Returns a 4-tuple ``(chroms, individuals, data, chrom_len)``:

    * ``chroms`` -- chromosome names in order of first appearance.
    * ``individuals`` -- individual ids in order of first appearance.
    * ``data`` -- ``data[chrom][individual]`` is the list of
      ``(begin, end, state)`` tuples in file order.
    * ``chrom_len`` -- for each chromosome, the largest end coordinate seen
      (never less than 0).
    """
    chroms = []
    individuals = []
    data = {}
    chrom_len = {}

    with open(input_file) as fh:
        for raw_line in fh:
            fields = raw_line.split()
            if not fields:
                # Blank or whitespace-only line.
                continue

            chrom = fields[0]
            begin, end, state = [int(field) for field in fields[1:4]]
            individual = fields[4]

            if chrom not in chroms:
                chroms.append(chrom)
            if individual not in individuals:
                individuals.append(individual)

            intervals = data.setdefault(chrom, {}).setdefault(individual, [])
            intervals.append((begin, end, state))

            # Track the right-most coordinate seen for this chromosome.
            chrom_len[chrom] = max(end, chrom_len.get(chrom, 0))

    return chroms, individuals, data, chrom_len

def check_chroms(chroms, chrom_len, dbkey):
    """Abort unless every chromosome in ``chroms`` has a known length.

    One message per chromosome missing from ``chrom_len`` is written to
    stderr (``dbkey`` is only used in the message text).  If any chromosome
    is missing, the process exits with status 1 after all of them have been
    reported; otherwise the function returns None.
    """
    missing = [chrom for chrom in chroms if chrom not in chrom_len]

    for chrom in missing:
        # sys.stderr.write behaves the same on Python 2 and 3, unlike the
        # Python-2-only ``print >> sys.stderr`` statement.
        sys.stderr.write(
            "Can't find length for {0} chromosome {1}\n".format(dbkey, chrom))

    if missing:
        sys.exit(1)

def check_data(data, chrom_len, dbkey):
    """Validate every interval in ``data`` and abort if any is bad.

    ``data[chrom][individual]`` is a list of ``(p1, p2, state)`` tuples and
    ``chrom_len[chrom]`` is the upper bound for that chromosome (the lower
    bound is 0).  An interval is rejected when ``p1 >= p2`` or when it lies
    outside ``[0, chrom_len[chrom]]``.

    Each offending interval is reported on stderr, including the individual
    it belongs to.  If anything was reported the process exits with status 1
    once all intervals have been checked.  Raises KeyError if a chromosome in
    ``data`` has no entry in ``chrom_len``.
    """
    chrom_beg = 0
    found_error = False

    for chrom in data:
        chrom_end = chrom_len[chrom]
        for individual in data[chrom]:
            for p1, p2, state in data[chrom][individual]:
                if p1 >= p2:
                    # The original format string had no placeholder for the
                    # individual, so it was silently dropped from the message.
                    sys.stderr.write(
                        "Bad data line: begin >= end: {0} {1} {2} {3} {4}\n".format(
                            chrom, p1, p2, state, individual))
                    found_error = True
                if p1 < chrom_beg or p2 > chrom_end:
                    sys.stderr.write(
                        "Bad data line: outside {0} boundaries[{1} - {2}]: {3} {4} {5} {6} {7}\n".format(
                            dbkey, chrom_beg, chrom_end, chrom, p1, p2, state, individual))
                    found_error = True

    if found_error:
        sys.exit(1)

def make_rectangle(p1, p2, color, bottom=0.0, top=1.0):
    """Return a borderless filled rectangle patch.

    The rectangle spans ``p1``..``p2`` horizontally and ``bottom``..``top``
    vertically, is filled with ``color`` and has a line width of 0.
    """
    # Corners are listed starting bottom-left: up, across, then down.
    corners = [(p1, bottom), (p1, top), (p2, top), (p2, bottom)]

    # CLOSEPOLY needs a vertex slot of its own, but its coordinates are ignored.
    outline = Path(
        corners + [(0.0, 0.0)],
        [Path.MOVETO] + [Path.LINETO] * 3 + [Path.CLOSEPOLY])

    return patches.PathPatch(outline, facecolor=color, lw=0)

def make_split_rectangle(p1, p2, top_color, bottom_color):
    """Return two stacked half-height patches spanning ``p1``..``p2``.

    The returned list holds the lower half (0.0-0.5, ``bottom_color``) first
    and the upper half (0.5-1.0, ``top_color``) second.
    """
    return [
        make_rectangle(p1, p2, bottom_color, top=0.5),
        make_rectangle(p1, p2, top_color, bottom=0.5),
    ]

def make_state_rectangle(p1, p2, state, chrom, individual):
    """Return the list of patches that draw one interval for a given state.

    * state 0 -- one full-height rectangle filled with 'r'
    * state 1 -- two half-height rectangles: 'r' on top, 'g' below
    * state 2 -- one full-height rectangle filled with 'g'

    Any other state is reported on stderr (``chrom`` and ``individual`` are
    only used for that message) and the process exits with status 1.
    """
    if state == 0:
        return [make_rectangle(p1, p2, 'r')]
    if state == 1:
        return make_split_rectangle(p1, p2, 'r', 'g')
    if state == 2:
        return [make_rectangle(p1, p2, 'g')]

    # The original format string had no placeholder for the individual, so it
    # was never printed; it also used the Python-2-only print statement.
    sys.stderr.write(
        "Unknown state: {0}: {1} {2} {3} {4} {5}\n".format(
            state, chrom, p1, p2, state, individual))
    sys.exit(1)

def nicenum(num, round=False):
    """Return a "nice" number (1, 2, 5 or 10 times a power of ten) near ``num``.

    ``num`` is split into a leading fraction and a power of ten.  With
    ``round`` true the fraction is mapped to the nearest nice value (cut-offs
    at 1.5, 3 and 7); otherwise it is mapped to the smallest nice value that
    is not below it.  Returns 0.0 when ``num`` is 0.
    """
    if num == 0:
        return 0.0

    exponent = int(math.floor(math.log10(num)))
    fraction = num / math.pow(10, exponent)

    # Anything past the last cut-off falls through to 10.
    nice_fraction = 10.0
    if round:
        for cutoff, candidate in ((1.5, 1.0), (3.0, 2.0), (7.0, 5.0)):
            if fraction < cutoff:
                nice_fraction = candidate
                break
    else:
        for cutoff, candidate in ((1.0, 1.0), (2.0, 2.0), (5.0, 5.0)):
            if fraction <= cutoff:
                nice_fraction = candidate
                break

    return nice_fraction * 10 ** exponent

def tick_foo(beg, end, loose=False):
    """Compute tick positions and their labels for an axis from beg to end.

    The tick spacing is a "nice" number chosen by ``nicenum``, aiming for
    about 10 ticks.  With ``loose`` true the tick range is widened outward to
    multiples of the spacing; otherwise ticks run from ``beg`` up to ``end``.
    The first generated position is always dropped.

    Returns ``(vals, labels)``: ``vals`` holds the integer tick positions and
    ``labels`` holds each position divided by the spacing's power of ten,
    truncated to an integer and formatted as a string.
    """
    target_ticks = 10

    span = nicenum(end - beg, round=False)
    step = nicenum(span / (target_ticks - 1), round=True)
    magnitude = int(math.floor(math.log10(step)))

    if loose:
        first = math.floor(beg / step) * step
        # Half a step of slack so the final tick is still generated.
        limit = math.ceil(end / step) * step + (0.5 * step)
    else:
        first = beg
        limit = end

    vals = []
    position = first
    while position <= limit:
        vals.append(int(position))
        position += step

    # Discard the tick at the very start of the range.
    vals = vals[1:]

    scale = math.pow(10, magnitude)
    labels = ['{0}'.format(int(val / scale)) for val in vals]

    return vals, labels

################################################################################

def make_dpmix_plot(input_dbkey, input_file, output_file, galaxy_data_index_dir):
    """Draw one horizontal track per (chromosome, individual) and save the figure.

    input_dbkey           -- genome build key; selects the ``.len`` file and is
                             used in error messages.
    input_file            -- interval file read by parse_input_file().
    output_file           -- path handed to plt.savefig().
    galaxy_data_index_dir -- directory searched by build_chrom_len_dict().

    The figure is laid out top to bottom: for each chromosome a title row,
    then one track per individual that has data on that chromosome.  Only the
    last track of each chromosome gets x-axis ticks.  Exits the process (via
    check_data / make_state_rectangle) if the input contains bad intervals or
    unknown states.
    """
    fs_chrom_len = build_chrom_len_dict(input_dbkey, galaxy_data_index_dir)
    chroms, individuals, data, chrom_len = parse_input_file(input_file)

    # Lengths derived from the input (largest end coordinate) are replaced by
    # the lengths from the .len file wherever the file knows the chromosome.
    for chrom in chrom_len.keys():
        if chrom in fs_chrom_len:
            chrom_len[chrom] = fs_chrom_len[chrom]

    # check_chroms is disabled: chrom_len always has an entry for every
    # chromosome in the input, because parse_input_file creates one.
    #check_chroms(chroms, chrom_len, input_dbkey)
    check_data(data, chrom_len, input_dbkey)

    ## units below are inches
    top_space = 0.10
    chrom_space = 0.25
    chrom_height = 0.25
    ind_space = 0.10
    ind_height = 0.25

    # First pass: add up the height needed for every title row and track.
    total_height = 0.0
    at_top = True
    for chrom in chroms:
        if at_top:
            # The first chromosome has no extra gap above its title.
            total_height += (top_space + chrom_height)
            at_top = False
        else:
            total_height += (top_space + chrom_space + chrom_height)
    
        # Only individuals that actually have data on this chromosome get a track.
        individual_count = 0
        for individual in individuals:
            if individual in data[chrom]:
                individual_count += 1
        total_height += individual_count * (ind_space + ind_height)

    width = 7.5
    # Round the figure height up to a whole number of inches.
    height = math.ceil(total_height)

    # ``bottom`` is the lower edge of the next axes as a fraction of the figure
    # height; it starts at the top (1.0) and is decremented as rows are placed.
    bottom = 1.0

    fig = plt.figure(figsize=(width, height))

    at_top = True
    # Hard-coded switch: when True, chromosome titles, individual labels and
    # tick marks are all suppressed.
    for_webb = False

    # Second pass: place the axes, mirroring the spacing used in the first pass.
    for chrom in chroms:
        length = chrom_len[chrom]
        vals, labels = tick_foo(0, length)

        if at_top:
            bottom -= (top_space + chrom_height)/height
            at_top = False
        else:
            bottom -= (top_space + chrom_space + chrom_height)/height

        if not for_webb:
            # Full-width, frameless axes holding just the chromosome name.
            # The pyplot calls below act on the axes that was just added.
            ax = fig.add_axes([0.0, bottom, 1.0, chrom_height/height])
            plt.axis('off')
            plt.text(0.5, 0.5, chrom, fontsize=14, ha='center')

        # Number of tracks for this chromosome, needed to recognise the last one.
        individual_count = 0
        for individual in individuals:
            if individual in data[chrom]:
                individual_count += 1

        i = 0
        for individual in individuals:
            if individual in data[chrom]:
                i += 1

                bottom -= (ind_space + ind_height)/height
                if not for_webb:
                    # Narrow axes on the left holding the individual's label.
                    # [left, bottom, width, height]
                    ax1 = fig.add_axes([0.0, bottom, 0.09, ind_height/height])
                    plt.axis('off')
                    plt.text(1.0, 0.5, individual, fontsize=10, ha='right', va='center')
                # Track axes: x runs over the chromosome, y over 0..1.
                # [left, bottom, width, height]
                ax2 = fig.add_axes([0.10, bottom, 0.88, ind_height/height], frame_on=False)
                ax2.set_xlim(0, length)
                ax2.set_ylim(0, 1)
                if i != individual_count:
                    plt.axis('off')
                else:
                    # Last track of this chromosome: show the x-axis ticks
                    # computed by tick_foo (unless for_webb is set).
                    if not for_webb:
                        ax2.tick_params(top=False, left=False, right=False, labelleft=False)
                        ax2.set_xticks(vals)
                        ax2.set_xticklabels(labels)
                    else:
                        plt.axis('off')
                # Draw the intervals in sorted order, one or two patches each.
                for p1, p2, state in sorted(data[chrom][individual]):
                    for patch in make_state_rectangle(p1, p2, state, chrom, individual):
                        ax2.add_patch(patch)

    plt.savefig(output_file)

################################################################################

if __name__ == '__main__':
    # Usage: <dbkey> <input_file> <output_file> <galaxy_data_index_dir>
    # Exactly four arguments are required; anything after the fourth is ignored.
    cli_args = sys.argv[1:5]
    input_dbkey, input_file, output_file, galaxy_data_index_dir = cli_args
    make_dpmix_plot(
        input_dbkey,
        input_file,
        output_file,
        galaxy_data_index_dir,
    )
    sys.exit(0)