diff hclust2/hclust2.py @ 0:cac6247cb1d3 draft

graphlan_import
author george-weingart
date Tue, 26 Aug 2014 14:51:29 -0400
parents
children
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/hclust2/hclust2.py	Tue Aug 26 14:51:29 2014 -0400
@@ -0,0 +1,758 @@
+#!/usr/bin/env python
+
+import sys
+import numpy as np
+import matplotlib.ticker as ticker
+import scipy.spatial.distance as spd 
+import scipy.cluster.hierarchy as sph
+from scipy import stats
+import matplotlib
+#matplotlib.use('Agg')
+import pylab
+import pandas as pd
+from matplotlib.patches import Rectangle
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import cPickle as pickle
+
+sys.setrecursionlimit(10000)
+
+# samples on rows
+
+class SqrtNorm(matplotlib.colors.Normalize):
+    """
+    Normalize a given value to the 0-1 range on a square root scale
+    """
+    def __call__(self, value, clip=None):
+        if clip is None:
+            clip = self.clip
+
+        result, is_scalar = self.process_value(value)
+
+        result = np.ma.masked_less_equal(result, 0, copy=False)
+
+        self.autoscale_None(result)
+        vmin, vmax = self.vmin, self.vmax
+        if vmin > vmax:
+            raise ValueError("minvalue must be less than or equal to maxvalue")
+        elif vmin <= 0:
+            raise ValueError("values must all be positive")
+        elif vmin == vmax:
+            result.fill(0)
+        else:
+            if clip:
+                mask = np.ma.getmask(result)
+                result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
+                                  mask=mask)
+            # in-place equivalent of above can be much faster
+            resdat = result.data
+            mask = result.mask
+            if mask is np.ma.nomask:
+                mask = (resdat <= 0)
+            else:
+                mask |= resdat <= 0
+            matplotlib.cbook._putmask(resdat, mask, 1)
+            np.sqrt(resdat, resdat)
+            resdat -= np.sqrt(vmin)
+            resdat /= (np.sqrt(vmax) - np.sqrt(vmin))
+            result = np.ma.array(resdat, mask=mask, copy=False)
+        if is_scalar:
+            result = result[0]
+        return result
+
+    def inverse(self, value):
+        if not self.scaled():
+            raise ValueError("Not invertible until scaled")
+        vmin, vmax = self.vmin, self.vmax
+
+        # invert the square-root mapping applied in __call__
+        if matplotlib.cbook.iterable(value):
+            val = np.ma.asarray(value)
+            return (val * (np.sqrt(vmax) - np.sqrt(vmin)) + np.sqrt(vmin)) ** 2
+        else:
+            return (value * (np.sqrt(vmax) - np.sqrt(vmin)) + np.sqrt(vmin)) ** 2
+
+    def autoscale(self, A):
+        '''
+        Set *vmin*, *vmax* to min, max of *A*.
+        '''
+        A = np.ma.masked_less_equal(A, 0, copy=False)
+        self.vmin = np.ma.min(A)
+        self.vmax = np.ma.max(A)
+
+    def autoscale_None(self, A):
+        ' autoscale only None-valued vmin or vmax'
+        if self.vmin is not None and self.vmax is not None:
+            return
+        A = np.ma.masked_less_equal(A, 0, copy=False)
+        if self.vmin is None:
+            self.vmin = np.ma.min(A)
+        if self.vmax is None:
+            self.vmax = np.ma.max(A)
+
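+# Quick sanity check for SqrtNorm (illustrative numbers only, not used by the tool):
+#   SqrtNorm(vmin=1.0, vmax=100.0)(25.0) == (sqrt(25)-sqrt(1)) / (sqrt(100)-sqrt(1)) ~= 0.44
+# i.e. values are rescaled along a square-root axis, analogously to LogNorm on a log axis.
+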
+class DataMatrix:
+    datatype = 'data_matrix'
+    
+    @staticmethod
+    def input_parameters( parser ):
+        dm_param = parser.add_argument_group('Input data matrix parameters')
+        arg = dm_param.add_argument
+
+        arg( '--sep', type=str, default='\t' )
+        arg( '--out_table', type=str, default=None,
+             help = 'Write processed data matrix to file' )
+        arg( '--fname_row', type=int, default=0,
+             help = "row number containing the names of the features "
+                    "[default 0, specify -1 if no names are present in the matrix]")
+        arg( '--sname_row', type=int, default=0,
+             help = "column number containing the names of the samples "
+                    "[default 0, specify -1 if no names are present in the matrix]")
+        arg( '--metadata_rows', type=str, default=None,
+             help = "Row numbers to use as metadata "
+                    "[default None, meaning no metadata]")
+        arg( '--skip_rows', type=str, default=None,
+             help = "Row numbers to skip (0-indexed, comma separated) from the input file "
+                    "[default None, meaning no rows skipped]")
+        arg( '--sperc', type=int, default=90,
+             help = "Percentile of sample value distribution for sample selection" )
+        arg( '--fperc', type=int, default=90,
+             help = "Percentile of feature value distribution for feature selection" )
+        arg( '--stop', type=int, default=None, 
+             help = "Number of top samples to select (ordering based on percentile specified by --sperc)" )
+        arg( '--ftop', type=int, default=None, 
+             help = "Number of top features to select (ordering based on percentile specified by --fperc)" )
+        arg( '--def_na', type=float, default=None,
+             help = "Set the default value for missing values [default None which means no replacement]")
+
+    def __init__( self, input_file, args ):
+        self.args = args
+        self.metadata_rows =  [] 
+        self.metadata_table = None
+        toskip = [int(l) for l in self.args.skip_rows.split(",")]  if self.args.skip_rows else [] 
+        if self.args.metadata_rows:
+            self.metadata_rows = list([int(a) for a in self.args.metadata_rows.split(",")])
+            mdr = self.metadata_rows[::]
+            for t in toskip:
+                for i,m in enumerate(mdr):
+                    if t <= m:
+                        self.metadata_rows[i] -= 1
+        if self.metadata_rows:
+            header = [self.args.fname_row]+self.metadata_rows if self.args.fname_row > -1 else self.metadata_rows
+        else:
+            header = self.args.fname_row if self.args.fname_row > -1 else None        
+        self.table = pd.read_table( 
+                input_file, sep = self.args.sep, # skipinitialspace = True, 
+                                  skiprows = sorted(toskip) if isinstance(toskip, list) else toskip,
+                                  header = sorted(header) if isinstance(header, list) else header,
+                                  index_col = self.args.sname_row if self.args.sname_row > -1 else None
+                                    )
+        
+        def select( perc, top  ): 
+            self.table['perc'] = self.table.apply(lambda x: stats.scoreatpercentile(x,perc),axis=1)
+            m = sorted(self.table['perc'])[-top]
+            self.table = self.table[self.table['perc'] >= m ]
+            del self.table['perc'] 
+        
+        if not self.args.def_na is None:
+            self.table = self.table.fillna( self.args.def_na )
+
+        if self.args.ftop:
+            select( self.args.fperc, self.args.ftop )
+        
+        if self.args.stop:
+            self.table = self.table.T 
+            select( self.args.sperc, self.args.stop ) 
+            self.table = self.table.T
+        
+
+        # add missing values
+        
+    def get_numpy_matrix( self ): 
+        return np.matrix(self.table)
+    
+    #def get_metadata_matrix( self ):
+    #    return self.table.columns
+    
+    def get_snames( self ):
+        #return list(self.table.index)
+        return self.table.columns
+    
+    def get_fnames( self ):
+        #print self.table.columns.names
+        #print self.table.columns
+        return list(self.table.index)
+    
+    def get_averages(self, by_row = True) :
+        return self.table.mean(axis = 1 if by_row else 0)
+    
+    def save_matrix( self, output_file ):
+        self.table.to_csv( output_file, sep = '\t' )
+
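+# Note on --ftop/--stop: --ftop keeps the args.ftop features whose args.fperc-th
+# percentile (computed across samples) is highest; --stop does the same for samples
+# after transposing the table (see select() inside DataMatrix.__init__ above).
+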
+class DistMatrix:
+    datatype = 'distance_matrix'
+
+    @staticmethod
+    def input_parameters( parser ):
+        dm_param = parser.add_argument_group('Distance parameters')
+        arg = dm_param.add_argument
+
+        dist_funcs = [  "euclidean","minkowski","cityblock","seuclidean",
+                        "sqeuclidean","cosine","correlation","hamming",
+                        "jaccard","chebyshev","canberra","braycurtis",
+                        "mahalanobis","yule","matching","dice",
+                        "kulsinski","rogerstanimoto","russellrao","sokalmichener",
+                        "sokalsneath","wminkowski","ward" ]
+
+        arg( '--f_dist_f', type=str, default="correlation",
+             help = "Distance function for features [default correlation]")
+        arg( '--s_dist_f', type=str, default="euclidean",
+             help = "Distance function for samples [default euclidean]")
+        arg( '--load_dist_matrix_f', type=str, default=None,
+             help = "Load the distance matrix to be used for features [default None].")
+        arg( '--load_dist_matrix_s', type=str, default=None,
+             help = "Load the distance matrix to be used for samples [default None].")
+        arg( '--save_dist_matrix_f', type=str, default=None,
+             help = "Save the distance matrix for features to file [default None].")
+        arg( '--save_dist_matrix_s', type=str, default=None,
+             help = "Save the distance matrix for samples to file [default None].")
+    
+    def __init__( self, data, args = None ):
+        self.args = args
+        self.sdf = args.s_dist_f
+        self.fdf = args.f_dist_f
+
+        self.s_cdist_matrix, self.f_cdist_matrix = None, None
+
+        self.numpy_full_matrix = (data if 
+                type(data) == np.matrixlib.defmatrix.matrix else None)
+    
+    def compute_f_dists( self ):
+        if self.args.load_dist_matrix_f:
+            with open( self.args.load_dist_matrix_f ) as inp:
+                self.f_cdist_matrix = pickle.load( inp )
+
+        else:
+            dt = self.numpy_full_matrix
+            
+            if self.fdf == "spearman":
+                dt_ranked = np.matrix([stats.rankdata(d) for d in dt])
+                self.f_cdist_matrix = spd.pdist( dt_ranked, "correlation" )
+                return
+        
+            if self.fdf == "pearson":
+                self.fdf = 'correlation'
+
+            self.f_cdist_matrix = spd.pdist( dt, self.fdf )
+
+        if self.args.save_dist_matrix_f:
+            with open( self.args.save_dist_matrix_f, "wb" ) as outf:
+                pickle.dump( self.f_cdist_matrix, outf )
+    
+    def compute_s_dists( self ):
+        if self.args.load_dist_matrix_s:
+            with open( self.args.load_dist_matrix_s ) as inp:
+                self.s_cdist_matrix = pickle.load( inp )
+        else: 
+            dt = self.numpy_full_matrix.transpose()
+            
+            if self.sdf == "spearman":
+                dt_ranked = np.matrix([stats.rankdata(d) for d in dt])
+                self.s_cdist_matrix = spd.pdist( dt_ranked, "correlation" )
+                return
+            
+            if self.sdf == "pearson":
+                self.sdf = 'correlation'
+        
+            self.s_cdist_matrix = spd.pdist( dt, self.sdf )
+        
+        if self.args.save_dist_matrix_s:
+            with open( self.args.save_dist_matrix_s, "wb" ) as outf:
+                pickle.dump( self.s_cdist_matrix, outf )
+
+    def get_s_dm( self ):
+        return self.s_cdist_matrix
+
+    def get_f_dm( self ):
+        return self.f_cdist_matrix
+
+class HClustering:
+    datatype = 'hclustering'
+
+    @staticmethod
+    def input_parameters( parser ):
+        cl_param = parser.add_argument_group('Clustering parameters')
+        arg = cl_param.add_argument
+
+        linkage_method = [ "single","complete","average", 
+                           "weighted","centroid","median",
+                           "ward" ]
+        arg( '--no_fclustering', action='store_true',
+             help = "avoid clustering features" )
+        arg( '--no_sclustering', action='store_true',
+             help = "avoid clustering samples" )
+        arg( '--flinkage', type=str, default="average",
+             help = "Linkage method for feature clustering [default average]")
+        arg( '--slinkage', type=str, default="average",
+             help = "Linkage method for sample clustering [default average]")
+
+    def get_reordered_matrix( self, matrix, sclustering = True, fclustering = True ):
+        if not sclustering and not fclustering:
+            return matrix
+        
+        idx1 = self.sdendrogram['leaves'] if sclustering else None
+        idx2 = self.fdendrogram['leaves'][::-1] if fclustering else None
+
+        if sclustering and fclustering:
+            return matrix[idx2,:][:,idx1]
+        if fclustering:
+            return matrix[idx2,:]
+        if sclustering:
+            return matrix[:,idx1]
+
+    def get_reordered_sample_labels( self, slabels ):
+        return [slabels[i] for i in self.sdendrogram['leaves']]
+
+    def get_reordered_feature_labels( self, flabels ):
+        return [flabels[i] for i in self.fdendrogram['leaves']]
+    
+    def __init__( self, s_dm, f_dm, args = None ):
+        self.s_dm = s_dm
+        self.f_dm = f_dm
+        self.args = args
+        self.shclusters = None
+        self.fhclusters = None
+        self.sdendrogram = None
+        self.fdendrogram = None
+
+    def shcluster( self, dendrogram = True ):
+        self.shclusters = sph.linkage( self.s_dm, self.args.slinkage )
+        if dendrogram:
+            self.sdendrogram = sph.dendrogram( self.shclusters, no_plot=True )
+
+    def fhcluster( self, dendrogram = True ):
+        self.fhclusters = sph.linkage( self.f_dm, self.args.flinkage )
+        if dendrogram:
+            self.fdendrogram = sph.dendrogram( self.fhclusters, no_plot=True )
+    
+    def get_shclusters( self ):
+        return self.shclusters
+    
+    def get_fhclusters( self ):
+        return self.fhclusters
+    
+    def get_sdendrogram( self ):
+        return self.sdendrogram
+    
+    def get_fdendrogram( self ):
+        return self.fdendrogram
+
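+# For reference, the DistMatrix + HClustering steps above boil down to (sketch):
+#   cond = spd.pdist(mat, "correlation")      # condensed 1-D distance vector
+#   Z    = sph.linkage(cond, "average")       # hierarchical clustering
+#   dend = sph.dendrogram(Z, no_plot=True)    # leaf order in dend['leaves']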
+
+class Heatmap:
+    datatype = 'heatmap'
+   
+    bbcyr = {'red':  (  (0.0, 0.0, 0.0),
+                        (0.25, 0.0, 0.0),
+                        (0.50, 0.0, 0.0),
+                        (0.75, 1.0, 1.0),
+                        (1.0, 1.0, 1.0)),
+             'green': ( (0.0, 0.0, 0.0),
+                        (0.25, 0.0, 0.0),
+                        (0.50, 1.0, 1.0),
+                        (0.75, 1.0, 1.0),
+                        (1.0, 0.0, 1.0)),
+             'blue': (  (0.0, 0.0, 0.0),
+                        (0.25, 1.0, 1.0),
+                        (0.50, 1.0, 1.0),
+                        (0.75, 0.0, 0.0),
+                        (1.0, 0.0, 1.0))}
+
+    bbcry = {'red':  (  (0.0, 0.0, 0.0),
+                        (0.25, 0.0, 0.0),
+                        (0.50, 0.0, 0.0),
+                        (0.75, 1.0, 1.0),
+                        (1.0, 1.0, 1.0)),
+             'green': ( (0.0, 0.0, 0.0),
+                        (0.25, 0.0, 0.0),
+                        (0.50, 1.0, 1.0),
+                        (0.75, 0.0, 0.0),
+                        (1.0, 1.0, 1.0)),
+             'blue': (  (0.0, 0.0, 0.0),
+                        (0.25, 1.0, 1.0),
+                        (0.50, 1.0, 1.0),
+                        (0.75, 0.0, 0.0),
+                        (1.0, 0.0, 1.0))}
+
+    bcry = {'red':  (   (0.0, 0.0, 0.0),
+                        (0.33, 0.0, 0.0),
+                        (0.66, 1.0, 1.0),
+                        (1.0, 1.0, 1.0)),
+             'green': ( (0.0, 0.0, 0.0),
+                        (0.33, 1.0, 1.0),
+                        (0.66, 0.0, 0.0),
+                        (1.0, 1.0, 1.0)),
+             'blue': (  (0.0, 1.0, 1.0),
+                        (0.33, 1.0, 1.0),
+                        (0.66, 0.0, 0.0),
+                        (1.0, 0.0, 1.0))}
+    
+
+    my_colormaps = [    ('bbcyr',bbcyr),
+                        ('bbcry',bbcry),
+                        ('bcry',bcry)]
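+
+    # Each channel in the custom colormaps above is a sequence of (x, y0, y1) anchor
+    # points in matplotlib's LinearSegmentedColormap segment format; the maps are
+    # registered by name in input_parameters() so they can be picked via --colormap.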
+   
+    dcols = ['#ca0000','#0087ff','#00ba1d','#cf00ff','#00dbe2','#ffaf00','#0017f4','#006012','#e175ff','#877878','#050505','#b5cf00','#ff8a8a','#aa6400','#50008a','#00ff58']
+
+
+    @staticmethod
+    def input_parameters( parser ):
+        hm_param = parser.add_argument_group('Heatmap options')
+        arg = hm_param.add_argument
+
+        arg( '--dpi', type=int, default=150,
+             help = "Image resolution in dpi [default 150]")
+        arg( '-l', '--log_scale', action='store_true',
+             help = "Log scale" )
+        arg( '-s', '--sqrt_scale', action='store_true',
+             help = "Square root scale" )
+        arg( '--no_slabels', action='store_true',
+             help = "Do not show sample labels" )
+        arg( '--minv', type=float, default=None,
+             help = "Minimum value to display in the color map [default None meaning automatic]" )
+        arg( '--maxv', type=float, default=None,
+             help = "Maximum value to display in the color map [default None meaning automatic]" )
+        arg( '--no_flabels', action='store_true',
+             help = "Do not show feature labels" )
+        arg( '--max_slabel_len', type=int, default=25,
+             help = "Max number of chars to report for sample labels [default 25]" )
+        arg( '--max_flabel_len', type=int, default=25,
+             help = "Max number of chars to report for feature labels [default 25]" )
+        arg( '--flabel_size', type=int, default=10,
+             help = "Feature label font size [default 10]" )
+        arg( '--slabel_size', type=int, default=10,
+             help = "Sample label font size [default 10]" )
+        arg( '--fdend_width', type=float, default=1.0,
+             help = "Width of the feature dendrogram [default 1 meaning 100%% of default heatmap width]")
+        arg( '--sdend_height', type=float, default=1.0,
+             help = "Height of the sample dendrogram [default 1 meaning 100%% of default heatmap height]")
+        arg( '--metadata_height', type=float, default=.05,
+             help = "Height of the metadata panel [default 0.05 meaning 5%% of default heatmap height]")
+        arg( '--metadata_separation', type=float, default=.01,
+             help = "Distance between the metadata and data panels [default 0.01 meaning 1%% of default heatmap height]")
+        arg( '--image_size', type=float, default=8,
+             help = "Size of the largest dimension (width or height) of the image in inches [default 8]")
+        arg( '--cell_aspect_ratio', type=float, default=1.0,
+             help = "Aspect ratio between width and height for the cells of the heatmap [default 1.0]")
+        col_maps = ['Accent', 'Blues', 'BrBG', 'BuGn', 'BuPu', 'Dark2', 'GnBu',
+                    'Greens', 'Greys', 'OrRd', 'Oranges', 'PRGn', 'Paired',
+                    'Pastel1', 'Pastel2', 'PiYG', 'PuBu', 'PuBuGn', 'PuOr',
+                    'PuRd', 'Purples', 'RdBu', 'RdGy', 'RdPu', 'RdYlBu', 'RdYlGn',
+                    'Reds', 'Set1', 'Set2', 'Set3', 'Spectral', 'YlGn', 'YlGnBu',
+                    'YlOrBr', 'YlOrRd', 'afmhot', 'autumn', 'binary', 'bone',
+                    'brg', 'bwr', 'cool', 'copper', 'flag', 'gist_earth',
+                    'gist_gray', 'gist_heat', 'gist_ncar', 'gist_rainbow',
+                    'gist_stern', 'gist_yarg', 'gnuplot', 'gnuplot2', 'gray',
+                    'hot', 'hsv', 'jet', 'ocean', 'pink', 'prism', 'rainbow',
+                    'seismic', 'spectral', 'spring', 'summer', 'terrain', 'winter'] + [n for n,c in Heatmap.my_colormaps]
+        for n,c in Heatmap.my_colormaps:
+            my_cmap = matplotlib.colors.LinearSegmentedColormap(n,c,256)
+            pylab.register_cmap(name=n,cmap=my_cmap)
+        arg( '-c','--colormap', type=str, choices = col_maps, default = 'bbcry' )
+        arg( '--bottom_c', type=str, default = None,
+             help = "Color to use for cells below the minimum value of the scale [default None meaning bottom color of the scale]")
+        arg( '--top_c', type=str, default = None,
+             help = "Color to use for cells above the maximum value of the scale [default None meaning top color of the scale]")
+        arg( '--nan_c', type=str, default = None,
+             help = "Color to use for nan cells [default None]")
+
+        
+
+        """
+        arg( '--', type=str, default="average",
+             help = "Linkage method for feature clustering [default average]")
+        arg( '--slinkage', type=str, default="average",
+             help = "Linkage method for sample clustering [default average]")
+        """
+
+    def __init__( self, numpy_matrix, sdendrogram, fdendrogram, snames, fnames, fnames_meta, args = None ):
+        self.numpy_matrix = numpy_matrix
+        self.sdendrogram = sdendrogram
+        self.fdendrogram = fdendrogram
+        self.snames = snames
+        self.fnames = fnames
+        self.fnames_meta = fnames_meta
+        self.ns,self.nf = self.numpy_matrix.shape
+        self.args = args
+
+    def make_legend( self, dmap, titles, out_fn ): 
+        figlegend = plt.figure(figsize=(1+3*len(titles),2), frameon = False)
+
+        gs = gridspec.GridSpec( 1, len(dmap), wspace = 2.0  )
+
+        for i,(d,title) in enumerate(zip(dmap,titles)):
+            legax = plt.subplot(gs[i],frameon = False)
+            for k,v in sorted(d.items(),key=lambda x:x[1]):
+                rect = Rectangle( [0.0, 0.0], 0.0, 0.0,
+                                  facecolor = self.dcols[v%len(self.dcols)],
+                                  label = k,
+                                  edgecolor='b', lw = 0.0)
+
+                legax.add_patch(rect)
+        #remove_splines( legax )
+            legax.set_xticks([])
+            legax.set_yticks([])
+            legax.legend( loc = 2, frameon = False, title = title) 
+        """
+                      ncol = legend_ncol, bbox_to_anchor=(1.01, 3.),
+                      borderpad = 0.0, labelspacing = 0.0,
+                      handlelength = 0.5, handletextpad = 0.3,
+                      borderaxespad = 0.0, columnspacing = 0.3,
+                      prop = {'size':fontsize}, frameon = False)
+        """
+        if out_fn:
+            figlegend.savefig(out_fn, bbox_inches='tight')
+    
+    def draw( self ):
+
+        rat = float(self.ns)/self.nf
+        rat *= self.args.cell_aspect_ratio
+        x,y = (self.args.image_size,rat*self.args.image_size) if rat < 1 else (self.args.image_size/rat,self.args.image_size)
+        fig = plt.figure( figsize=(x,y), facecolor = 'w'  )
+
+        cm = pylab.get_cmap(self.args.colormap)
+        bottom_col = [  cm._segmentdata['red'][0][1],
+                        cm._segmentdata['green'][0][1],
+                        cm._segmentdata['blue'][0][1]   ]
+        if self.args.bottom_c:
+            bottom_col = self.args.bottom_c
+        cm.set_under( bottom_col )
+        top_col = [  cm._segmentdata['red'][-1][1],
+                     cm._segmentdata['green'][-1][1],
+                     cm._segmentdata['blue'][-1][1]   ]
+        if self.args.top_c:
+            top_col = self.args.top_c
+        cm.set_over( top_col )
+
+        if self.args.nan_c:
+            cm.set_bad( self.args.nan_c  )
+
+        def make_ticklabels_invisible(ax):
+            for tl in ax.get_xticklabels() + ax.get_yticklabels():
+                 tl.set_visible(False)
+            ax.set_xticks([])
+            ax.set_yticks([])
+      
+        def remove_splines( ax ):
+            for v in ['right','left','top','bottom']:
+                ax.spines[v].set_color('none')
+
+        def shrink_labels( labels, n ):
+            shrink = lambda x: x[:n/2]+" [...] "+x[-n/2:]
+            return [(shrink(str(l)) if len(str(l)) > n else l) for l in labels]
+        
+
+        #gs = gridspec.GridSpec( 4, 2, 
+        #                        width_ratios=[1.0-fr_ns,fr_ns], 
+        #                        height_ratios=[.03,0.03,1.0-fr_nf,fr_nf], 
+        #                        wspace = 0.0, hspace = 0.0 )
+        
+        fr_ns = float(self.ns)/max([self.ns,self.nf])
+        fr_nf = float(self.nf)/max([self.ns,self.nf])
+       
+        buf_space = 0.05
+        minv = min( [buf_space*8, 8*rat*buf_space] )
+        if minv < 0.05:
+            buf_space /= minv/0.05
+        metadata_height = self.args.metadata_height if type(self.snames[0]) is tuple and len(self.snames[0]) > 1 else 0.000001
+        gs = gridspec.GridSpec( 6, 4, 
+                                width_ratios=[ buf_space, buf_space*2, .08*self.args.fdend_width,0.9], 
+                                height_ratios=[ buf_space, buf_space*2, .08*self.args.sdend_height, metadata_height, self.args.metadata_separation, 0.9], 
+                                wspace = 0.0, hspace = 0.0 )
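+        # Flat GridSpec(6,4) indices used below: gs[3] colorbar strip, gs[11] sample
+        # dendrogram (top), gs[15] metadata strip, gs[22] feature dendrogram (left of
+        # the heatmap), gs[23] the heatmap itself.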
+
+        ax_hm = plt.subplot(gs[23], axisbg = bottom_col  )
+        ax_metadata = plt.subplot(gs[15], axisbg = bottom_col  )
+        ax_hm_y2 = ax_hm.twinx() 
+
+        norm_f = matplotlib.colors.Normalize
+        if self.args.log_scale:
+            norm_f = matplotlib.colors.LogNorm
+        elif self.args.sqrt_scale:
+            norm_f = SqrtNorm
+        minv, maxv = 0.0, None
+
+        maps, values, ndv = [], [], 0
+        if type(self.snames[0]) is tuple and len(self.snames[0]) > 1:
+            metadata = zip(*[list(s[1:]) for s in self.snames])
+            for m in metadata:
+                mmap = dict([(v[1],ndv+v[0]) for v in enumerate(list(set(m)))])
+                values.append([mmap[v] for v in m])
+                ndv += len(mmap)
+                maps.append(mmap)
+            dcols = [] 
+            mdmat = np.matrix(values)
+            while len(dcols) < ndv:
+                dcols += self.dcols
+            cmap = matplotlib.colors.ListedColormap(dcols[:ndv]) 
+            bounds = [float(f)-0.5 for f in range(ndv+1)]
+            imm = ax_metadata.imshow( mdmat, #origin='lower', 
+                    interpolation = 'nearest',  
+                                    aspect='auto', 
+                                    extent = [0, self.nf, 0, self.ns], 
+                                    cmap=cmap,
+                                    vmin=bounds[0],
+                                    vmax=bounds[-1],
+                                    )
+            remove_splines( ax_metadata )
+            ax_metadata_y2 = ax_metadata.twinx() 
+            ax_metadata_y2.set_ylim(0,len(self.fnames_meta))
+            ax_metadata.set_yticks([])
+            ax_metadata_y2.set_ylim(0,len(self.fnames_meta))
+            ax_metadata_y2.tick_params(length=0)
+            ax_metadata_y2.set_yticks(np.arange(len(self.fnames_meta))+0.5)
+            ax_metadata_y2.set_yticklabels(self.fnames_meta[::-1], va='center',size=self.args.flabel_size)
+        else:
+            ax_metadata.set_yticks([])
+
+        ax_metadata.set_xticks([])
+        
+        im = ax_hm.imshow( self.numpy_matrix, #origin='lower', 
+                                interpolation = 'nearest',  aspect='auto', 
+                                extent = [0, self.nf, 0, self.ns], 
+                                cmap=cm, 
+                                vmin=self.args.minv,
+                                vmax=self.args.maxv, 
+                                norm = norm_f( vmin=minv if minv > 0.0 else None, vmax=maxv)
+                                )
+        
+        #ax_hm.set_ylim([0,800])
+        ax_hm.set_xticks(np.arange(len(list(self.snames)))+0.5)
+        if not self.args.no_slabels:
+            snames_short = shrink_labels( list([s[0] for s in self.snames]) if type(self.snames[0]) is tuple else self.snames, self.args.max_slabel_len )
+            ax_hm.set_xticklabels(snames_short,rotation=90,va='top',ha='center',size=self.args.slabel_size)
+        else:
+            ax_hm.set_xticklabels([])
+        ax_hm_y2.set_ylim([0,self.ns])
+        ax_hm_y2.set_yticks(np.arange(len(self.fnames))+0.5)
+        if not self.args.no_flabels:
+            fnames_short = shrink_labels( self.fnames, self.args.max_flabel_len )
+            ax_hm_y2.set_yticklabels(fnames_short,va='center',size=self.args.flabel_size)
+        else:
+            ax_hm_y2.set_yticklabels( [] )
+        ax_hm.set_yticks([])
+        remove_splines( ax_hm )
+        ax_hm.tick_params(length=0)
+        ax_hm_y2.tick_params(length=0)
+        #ax_hm.set_xlim([0,self.ns])
+        ax_cm = plt.subplot(gs[3], axisbg = 'r', frameon = False)
+        #fig.colorbar(im, ax_cm, orientation = 'horizontal', spacing = 'proportional', format = ticker.LogFormatterMathtext() )
+        fig.colorbar(im, ax_cm, orientation = 'horizontal', spacing='proportional' if self.args.sqrt_scale else 'uniform' ) # , format = ticker.LogFormatterMathtext() )
+
+        if not self.args.no_sclustering:
+            ax_den_top = plt.subplot(gs[11], axisbg = 'r', frameon = False)
+            sph._plot_dendrogram( self.sdendrogram['icoord'], self.sdendrogram['dcoord'], self.sdendrogram['ivl'],
+                                  self.ns + 1, self.nf + 1, 1, 'top', no_labels=True,
+                                  color_list=self.sdendrogram['color_list'] )
+            ymax = max([max(a) for a in self.sdendrogram['dcoord']])
+            ax_den_top.set_ylim([0,ymax])
+            make_ticklabels_invisible( ax_den_top )
+        if not self.args.no_fclustering:
+            ax_den_right = plt.subplot(gs[22], axisbg = 'b', frameon = False)
+            sph._plot_dendrogram(   self.fdendrogram['icoord'], self.fdendrogram['dcoord'], self.fdendrogram['ivl'],
+                                    self.ns + 1, self.nf + 1, 1, 'right', no_labels=True,
+                                    color_list=self.fdendrogram['color_list'] )
+            xmax = max([max(a) for a in self.fdendrogram['dcoord']])
+            ax_den_right.set_xlim([xmax,0])
+            make_ticklabels_invisible( ax_den_right )
+
+        
+        if not self.args.out:
+            plt.show( )
+        else:
+            fig.savefig( self.args.out, bbox_inches='tight', dpi = self.args.dpi )
+            if maps: 
+                self.make_legend( maps, self.fnames_meta, self.args.legend_file )
+
+
+
+class ReadCmd:
+
+    def __init__( self ):
+        import argparse as ap
+        import textwrap
+
+        p = ap.ArgumentParser( description= "TBA" )
+        arg = p.add_argument
+        
+        arg( '-i', '--inp', '--in', metavar='INPUT_FILE', type=str, nargs='?', default=sys.stdin,
+             help= "The input matrix" )
+        arg( '-o', '--out', metavar='OUTPUT_FILE', type=str, nargs='?', default=None,
+             help= "The output image file [image shown on screen if not specified]" )
+        arg( '--legend_file', metavar='LEGEND_FILE', type=str, nargs='?', default=None,
+             help= "The output file for the legend of the provided metadata" )
+
+        input_types = [DataMatrix.datatype,DistMatrix.datatype]
+        arg( '-t', '--input_type', metavar='INPUT_TYPE', type=str, choices = input_types, 
+             default='data_matrix',
+             help= "The input type can be a data matrix or distance matrix [default data_matrix]" )
+
+        DataMatrix.input_parameters( p )
+        DistMatrix.input_parameters( p )
+        HClustering.input_parameters( p )
+        Heatmap.input_parameters( p )
+
+        self.args  = p.parse_args()
+
+    def check_consistency( self ):
+        pass
+
+    def get_args( self ):
+        return self.args
+
+if __name__ == '__main__':
+     
+    read = ReadCmd( )
+    read.check_consistency()
+    args = read.get_args()
+    
+    if args.input_type == DataMatrix.datatype:
+        dm = DataMatrix( args.inp, args ) 
+        if args.out_table:
+            dm.save_matrix( args.out_table )
+        
+        distm = DistMatrix( dm.get_numpy_matrix(), args = args )
+        if not args.no_sclustering:
+            distm.compute_s_dists()
+        if not args.no_fclustering:
+            distm.compute_f_dists()
+    elif args.input_type == DistMatrix.datatype:
+        # distm = read...
+        pass
+    else:
+        pass
+
+    cl = HClustering( distm.get_s_dm(), distm.get_f_dm(), args = args )
+    if not args.no_sclustering:
+        cl.shcluster()
+    if not args.no_fclustering:
+        cl.fhcluster()
+    
+    hmp = dm.get_numpy_matrix()
+    fnames = dm.get_fnames()
+    snames = dm.get_snames()
+    fnames_meta = snames.names[1:]
+    #if not args.no_sclustering or not args.no_fclustering ):
+    
+    hmp = cl.get_reordered_matrix( hmp, sclustering = not args.no_sclustering, fclustering = not args.no_fclustering  )
+    if not args.no_sclustering:
+        snames = cl.get_reordered_sample_labels( snames )
+    if not args.no_fclustering:
+        fnames = cl.get_reordered_feature_labels( fnames )
+    else:
+        fnames = fnames[::-1]
+
+    hm = Heatmap( hmp, cl.sdendrogram, cl.fdendrogram, snames, fnames, fnames_meta, args = args )
+    hm.draw()
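+
+# Example invocation (a sketch; file names are hypothetical, all options are defined above):
+#   python hclust2.py -i abundance_table.txt -o heatmap.png \
+#       --ftop 25 --f_dist_f correlation --s_dist_f braycurtis \
+#       --cell_aspect_ratio 0.5 -l --flabel_size 6 --slabel_size 6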
+
+
+
+
+
+