comparison hclust2/hclust2.py @ 0:cac6247cb1d3 draft

summary: graphlan_import
author:  george-weingart
date:    Tue, 26 Aug 2014 14:51:29 -0400
#!/usr/bin/env python

import sys
import numpy as np
import matplotlib.ticker as ticker
import scipy.spatial.distance as spd
import scipy.cluster.hierarchy as sph
from scipy import stats
import matplotlib
#matplotlib.use('Agg')
import pylab
import pandas as pd
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import cPickle as pickle

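# Added note (assumption, not stated in the original): scipy's hierarchical
# clustering / dendrogram routines used below recurse deeply on large inputs,
# which is presumably why the default recursion limit (1000) is raised here.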
sys.setrecursionlimit(10000)

# samples on rows

class SqrtNorm(matplotlib.colors.Normalize):
    """
    Normalize a given value to the 0-1 range on a square root scale
    """
    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip

        result, is_scalar = self.process_value(value)

        result = np.ma.masked_less_equal(result, 0, copy=False)

        self.autoscale_None(result)
        vmin, vmax = self.vmin, self.vmax
        if vmin > vmax:
            raise ValueError("minvalue must be less than or equal to maxvalue")
        elif vmin <= 0:
            raise ValueError("values must all be positive")
        elif vmin == vmax:
            result.fill(0)
        else:
            if clip:
                mask = np.ma.getmask(result)
                result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
                                     mask=mask)
            # in-place equivalent of above can be much faster
            resdat = result.data
            mask = result.mask
            if mask is np.ma.nomask:
                mask = (resdat <= 0)
            else:
                mask |= resdat <= 0
            matplotlib.cbook._putmask(resdat, mask, 1)
            np.sqrt(resdat, resdat)
            resdat -= np.sqrt(vmin)
            resdat /= (np.sqrt(vmax) - np.sqrt(vmin))
            result = np.ma.array(resdat, mask=mask, copy=False)
        if is_scalar:
            result = result[0]
        return result

    def inverse(self, value):
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        vmin, vmax = self.vmin, self.vmax

        # invert the forward mapping (sqrt(x) - sqrt(vmin)) / (sqrt(vmax) - sqrt(vmin))
        if matplotlib.cbook.iterable(value):
            val = np.ma.asarray(value)
            return (val * (np.sqrt(vmax) - np.sqrt(vmin)) + np.sqrt(vmin)) ** 2
        else:
            return (value * (np.sqrt(vmax) - np.sqrt(vmin)) + np.sqrt(vmin)) ** 2

    def autoscale(self, A):
        '''
        Set *vmin*, *vmax* to min, max of *A*.
        '''
        A = np.ma.masked_less_equal(A, 0, copy=False)
        self.vmin = np.ma.min(A)
        self.vmax = np.ma.max(A)

    def autoscale_None(self, A):
        ' autoscale only None-valued vmin or vmax'
        if self.vmin is not None and self.vmax is not None:
            return
        A = np.ma.masked_less_equal(A, 0, copy=False)
        if self.vmin is None:
            self.vmin = np.ma.min(A)
        if self.vmax is None:
            self.vmax = np.ma.max(A)

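
# Illustrative sketch (not part of the original tool; names below are made up):
# a Normalize subclass such as SqrtNorm is passed to imshow() through the
# `norm` keyword, which is exactly how Heatmap.draw() applies it when
# --sqrt_scale is given.
def _sqrtnorm_demo():
    demo = np.abs(np.random.randn(10, 10)) + 0.1   # strictly positive values
    plt.imshow(demo, norm=SqrtNorm(vmin=demo.min(), vmax=demo.max()),
               interpolation='nearest')
    plt.colorbar()
    plt.show()
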
class DataMatrix:
    datatype = 'data_matrix'

    @staticmethod
    def input_parameters( parser ):
        dm_param = parser.add_argument_group('Input data matrix parameters')
        arg = dm_param.add_argument

        arg( '--sep', type=str, default='\t' )
        arg( '--out_table', type=str, default=None,
             help = 'Write processed data matrix to file' )
        arg( '--fname_row', type=int, default=0,
             help = "row number containing the names of the features "
                    "[default 0, specify -1 if no names are present in the matrix]")
        arg( '--sname_row', type=int, default=0,
             help = "column number containing the names of the samples "
                    "[default 0, specify -1 if no names are present in the matrix]")
        arg( '--metadata_rows', type=str, default=None,
             help = "Row numbers to use as metadata "
                    "[default None, meaning no metadata]")
        arg( '--skip_rows', type=str, default=None,
             help = "Row numbers to skip (0-indexed, comma separated) from the input file "
                    "[default None, meaning no rows skipped]")
        arg( '--sperc', type=int, default=90,
             help = "Percentile of sample value distribution for sample selection" )
        arg( '--fperc', type=int, default=90,
             help = "Percentile of feature value distribution for feature selection" )
        arg( '--stop', type=int, default=None,
             help = "Number of top samples to select (ordering based on percentile specified by --sperc)" )
        arg( '--ftop', type=int, default=None,
             help = "Number of top features to select (ordering based on percentile specified by --fperc)" )
        arg( '--def_na', type=float, default=None,
             help = "Set the default value for missing values [default None which means no replacement]")

    def __init__( self, input_file, args ):
        self.args = args
        self.metadata_rows = []
        self.metadata_table = None
        toskip = [int(l) for l in self.args.skip_rows.split(",")] if self.args.skip_rows else []
        if self.args.metadata_rows:
            self.metadata_rows = list([int(a) for a in self.args.metadata_rows.split(",")])
            mdr = self.metadata_rows[::]
            for t in toskip:
                for i,m in enumerate(mdr):
                    if t <= m:
                        self.metadata_rows[i] -= 1
        if self.metadata_rows:
            header = [self.args.fname_row]+self.metadata_rows if self.args.fname_row > -1 else self.metadata_rows
        else:
            header = self.args.fname_row if self.args.fname_row > -1 else None
        self.table = pd.read_table(
                input_file, sep = self.args.sep, # skipinitialspace = True,
                skiprows = sorted(toskip) if isinstance(toskip, list) else toskip,
                header = sorted(header) if isinstance(header, list) else header,
                index_col = self.args.sname_row if self.args.sname_row > -1 else None
                )

        def select( perc, top ):
            self.table['perc'] = self.table.apply(lambda x: stats.scoreatpercentile(x,perc),axis=1)
            m = sorted(self.table['perc'])[-top]
            self.table = self.table[self.table['perc'] >= m ]
            del self.table['perc']

        if self.args.def_na is not None:
            self.table = self.table.fillna( self.args.def_na )

        if self.args.ftop:
            select( self.args.fperc, self.args.ftop )

        if self.args.stop:
            self.table = self.table.T
            select( self.args.sperc, self.args.stop )
            self.table = self.table.T

        # add missing values

    def get_numpy_matrix( self ):
        return np.matrix(self.table)

    #def get_metadata_matrix( self ):
    #    return self.table.columns

    def get_snames( self ):
        #return list(self.table.index)
        return self.table.columns

    def get_fnames( self ):
        #print self.table.columns.names
        #print self.table.columns
        return list(self.table.index)

    def get_averages(self, by_row = True) :
        return self.table.mean(axis = 1 if by_row else 0)

    def save_matrix( self, output_file ):
        self.table.to_csv( output_file, sep = '\t' )

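
# Added note (illustrative, inferred from how the table is read above; values
# are made up): the expected input is a tab-separated table whose first row
# holds the sample names, whose first column holds the feature names, and
# whose optional extra header rows (selected with --metadata_rows) carry
# per-sample metadata, e.g.:
#
#   sample_id   s1      s2      s3
#   group       ctrl    case    case     <- metadata row (--metadata_rows 1)
#   feat_A      0.12    0.43    0.01
#   feat_B      3.20    0.00    1.75
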
class DistMatrix:
    datatype = 'distance_matrix'

    @staticmethod
    def input_parameters( parser ):
        dm_param = parser.add_argument_group('Distance parameters')
        arg = dm_param.add_argument

        dist_funcs = [ "euclidean","minkowski","cityblock","seuclidean",
                       "sqeuclidean","cosine","correlation","hamming",
                       "jaccard","chebyshev","canberra","braycurtis",
                       "mahalanobis","yule","matching","dice",
                       "kulsinski","rogerstanimoto","russellrao","sokalmichener",
                       "sokalsneath","wminkowski","ward" ]

        arg( '--f_dist_f', type=str, default="correlation",
             help = "Distance function for features [default correlation]")
        arg( '--s_dist_f', type=str, default="euclidean",
             help = "Distance function for samples [default euclidean]")
        arg( '--load_dist_matrix_f', type=str, default=None,
             help = "Load the distance matrix to be used for features [default None].")
        arg( '--load_dist_matrix_s', type=str, default=None,
             help = "Load the distance matrix to be used for samples [default None].")
        arg( '--save_dist_matrix_f', type=str, default=None,
             help = "Save the distance matrix for features to file [default None].")
        arg( '--save_dist_matrix_s', type=str, default=None,
             help = "Save the distance matrix for samples to file [default None].")

    def __init__( self, data, args = None ):
        self.args = args
        self.sdf = args.s_dist_f
        self.fdf = args.f_dist_f

        self.s_cdist_matrix, self.f_cdist_matrix = None, None

        self.numpy_full_matrix = (data if
                type(data) == np.matrixlib.defmatrix.matrix else None)

    def compute_f_dists( self ):
        if self.args.load_dist_matrix_f:
            with open( self.args.load_dist_matrix_f ) as inp:
                self.f_cdist_matrix = pickle.load( inp )

        else:
            dt = self.numpy_full_matrix

            if self.fdf == "spearman":
                dt_ranked = np.matrix([stats.rankdata(d) for d in dt])
                self.f_cdist_matrix = spd.pdist( dt_ranked, "correlation" )
                return

            if self.fdf == "pearson":
                self.fdf = 'correlation'

            self.f_cdist_matrix = spd.pdist( dt, self.fdf )

        if self.args.save_dist_matrix_f:
            with open( self.args.save_dist_matrix_f, "wb" ) as outf:
                pickle.dump( self.f_cdist_matrix, outf )

    def compute_s_dists( self ):
        if self.args.load_dist_matrix_s:
            with open( self.args.load_dist_matrix_s ) as inp:
                self.s_cdist_matrix = pickle.load( inp )
        else:
            dt = self.numpy_full_matrix.transpose()

            if self.sdf == "spearman":
                dt_ranked = np.matrix([stats.rankdata(d) for d in dt])
                self.s_cdist_matrix = spd.pdist( dt_ranked, "correlation" )
                return

            if self.sdf == "pearson":
                self.sdf = 'correlation'

            self.s_cdist_matrix = spd.pdist( dt, self.sdf )

        if self.args.save_dist_matrix_s:
            with open( self.args.save_dist_matrix_s, "wb" ) as outf:
                pickle.dump( self.s_cdist_matrix, outf )

    def get_s_dm( self ):
        return self.s_cdist_matrix

    def get_f_dm( self ):
        return self.f_cdist_matrix

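
# Illustrative sketch (not part of the original tool; data below are made up):
# the "spearman" option above is handled by ranking each row with
# stats.rankdata and then applying scipy's "correlation" metric, since Pearson
# correlation computed on ranks equals Spearman's rho.
def _spearman_as_ranked_correlation_demo():
    a = np.array([1.0, 4.0, 2.0, 8.0, 5.0])
    b = np.array([2.0, 3.0, 1.0, 9.0, 4.0])
    rho = stats.spearmanr(a, b)[0]
    # "correlation" distance is 1 - Pearson r, so 1 - distance == Spearman rho
    dist = spd.pdist(np.vstack([stats.rankdata(a), stats.rankdata(b)]),
                     "correlation")[0]
    print 1.0 - dist, rho
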
class HClustering:
    datatype = 'hclustering'

    @staticmethod
    def input_parameters( parser ):
        cl_param = parser.add_argument_group('Clustering parameters')
        arg = cl_param.add_argument

        linkage_method = [ "single","complete","average",
                           "weighted","centroid","median",
                           "ward" ]
        arg( '--no_fclustering', action='store_true',
             help = "avoid clustering features" )
        arg( '--no_sclustering', action='store_true',
             help = "avoid clustering samples" )
        arg( '--flinkage', type=str, default="average",
             help = "Linkage method for feature clustering [default average]")
        arg( '--slinkage', type=str, default="average",
             help = "Linkage method for sample clustering [default average]")

    def get_reordered_matrix( self, matrix, sclustering = True, fclustering = True ):
        if not sclustering and not fclustering:
            return matrix

        idx1 = self.sdendrogram['leaves'] if sclustering else None # !!!!!!!!!!!
        idx2 = self.fdendrogram['leaves'][::-1] if fclustering else None

        if sclustering and fclustering:
            return matrix[idx2,:][:,idx1]
        if fclustering:
            return matrix[idx2,:][:]
        if sclustering: # !!!!!!!!!!!!
            return matrix[:][:,idx1]

    def get_reordered_sample_labels( self, slabels ):
        return [slabels[i] for i in self.sdendrogram['leaves']]

    def get_reordered_feature_labels( self, flabels ):
        return [flabels[i] for i in self.fdendrogram['leaves']]

    def __init__( self, s_dm, f_dm, args = None ):
        self.s_dm = s_dm
        self.f_dm = f_dm
        self.args = args
        self.shclusters = None
        self.fhclusters = None
        self.sdendrogram = None
        self.fdendrogram = None

    def shcluster( self, dendrogram = True ):
        self.shclusters = sph.linkage( self.s_dm, self.args.slinkage )
        if dendrogram:
            self.sdendrogram = sph.dendrogram( self.shclusters, no_plot=True )

    def fhcluster( self, dendrogram = True ):
        self.fhclusters = sph.linkage( self.f_dm, self.args.flinkage )
        if dendrogram:
            self.fdendrogram = sph.dendrogram( self.fhclusters, no_plot=True )

    def get_shclusters( self ):
        return self.shclusters

    def get_fhclusters( self ):
        return self.fhclusters

    def get_sdendrogram( self ):
        return self.sdendrogram

    def get_fdendrogram( self ):
        return self.fdendrogram

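
# Illustrative sketch (not part of the original tool; data and names are made
# up): clustering boils down to scipy's linkage() on a condensed distance
# matrix followed by dendrogram(..., no_plot=True), whose 'leaves' field gives
# the row/column order used by get_reordered_matrix() above.
def _linkage_reorder_demo():
    demo = np.random.rand(5, 4)                    # 5 observations, 4 variables
    dm = spd.pdist(demo, "euclidean")              # condensed distance matrix
    dendro = sph.dendrogram(sph.linkage(dm, "average"), no_plot=True)
    return demo[dendro['leaves'], :]               # rows in dendrogram order
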
class Heatmap:
    datatype = 'heatmap'

    bbcyr = {'red':   ( (0.0, 0.0, 0.0),
                        (0.25, 0.0, 0.0),
                        (0.50, 0.0, 0.0),
                        (0.75, 1.0, 1.0),
                        (1.0, 1.0, 1.0)),
             'green': ( (0.0, 0.0, 0.0),
                        (0.25, 0.0, 0.0),
                        (0.50, 1.0, 1.0),
                        (0.75, 1.0, 1.0),
                        (1.0, 0.0, 1.0)),
             'blue':  ( (0.0, 0.0, 0.0),
                        (0.25, 1.0, 1.0),
                        (0.50, 1.0, 1.0),
                        (0.75, 0.0, 0.0),
                        (1.0, 0.0, 1.0))}

    bbcry = {'red':   ( (0.0, 0.0, 0.0),
                        (0.25, 0.0, 0.0),
                        (0.50, 0.0, 0.0),
                        (0.75, 1.0, 1.0),
                        (1.0, 1.0, 1.0)),
             'green': ( (0.0, 0.0, 0.0),
                        (0.25, 0.0, 0.0),
                        (0.50, 1.0, 1.0),
                        (0.75, 0.0, 0.0),
                        (1.0, 1.0, 1.0)),
             'blue':  ( (0.0, 0.0, 0.0),
                        (0.25, 1.0, 1.0),
                        (0.50, 1.0, 1.0),
                        (0.75, 0.0, 0.0),
                        (1.0, 0.0, 1.0))}

    bcry = {'red':   ( (0.0, 0.0, 0.0),
                       (0.33, 0.0, 0.0),
                       (0.66, 1.0, 1.0),
                       (1.0, 1.0, 1.0)),
            'green': ( (0.0, 0.0, 0.0),
                       (0.33, 1.0, 1.0),
                       (0.66, 0.0, 0.0),
                       (1.0, 1.0, 1.0)),
            'blue':  ( (0.0, 1.0, 1.0),
                       (0.33, 1.0, 1.0),
                       (0.66, 0.0, 0.0),
                       (1.0, 0.0, 1.0))}


    my_colormaps = [ ('bbcyr',bbcyr),
                     ('bbcry',bbcry),
                     ('bcry',bcry)]

    dcols = ['#ca0000','#0087ff','#00ba1d','#cf00ff','#00dbe2','#ffaf00','#0017f4','#006012','#e175ff','#877878','#050505','#b5cf00','#ff8a8a','#aa6400','#50008a','#00ff58']

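    # Added note (for clarity): in the segment-data dictionaries above each
    # tuple is (position, value_below, value_above) for its channel, i.e. the
    # standard matplotlib LinearSegmentedColormap format; the maps are
    # registered under the names listed in my_colormaps inside
    # input_parameters() below.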

    @staticmethod
    def input_parameters( parser ):
        hm_param = parser.add_argument_group('Heatmap options')
        arg = hm_param.add_argument

        arg( '--dpi', type=int, default=150,
             help = "Image resolution in dpi [default 150]")
        arg( '-l', '--log_scale', action='store_true',
             help = "Log scale" )
        arg( '-s', '--sqrt_scale', action='store_true',
             help = "Square root scale" )
        arg( '--no_slabels', action='store_true',
             help = "Do not show sample labels" )
        arg( '--minv', type=float, default=None,
             help = "Minimum value to display in the color map [default None meaning automatic]" )
        arg( '--maxv', type=float, default=None,
             help = "Maximum value to display in the color map [default None meaning automatic]" )
        arg( '--no_flabels', action='store_true',
             help = "Do not show feature labels" )
        arg( '--max_slabel_len', type=int, default=25,
             help = "Max number of chars to report for sample labels [default 25]" )
        arg( '--max_flabel_len', type=int, default=25,
             help = "Max number of chars to report for feature labels [default 25]" )
        arg( '--flabel_size', type=int, default=10,
             help = "Feature label font size [default 10]" )
        arg( '--slabel_size', type=int, default=10,
             help = "Sample label font size [default 10]" )
        arg( '--fdend_width', type=float, default=1.0,
             help = "Width of the feature dendrogram [default 1 meaning 100%% of default heatmap width]")
        arg( '--sdend_height', type=float, default=1.0,
             help = "Height of the sample dendrogram [default 1 meaning 100%% of default heatmap height]")
        arg( '--metadata_height', type=float, default=.05,
             help = "Height of the metadata panel [default 0.05 meaning 5%% of default heatmap height]")
        arg( '--metadata_separation', type=float, default=.01,
             help = "Distance between the metadata and data panels [default 0.01 meaning 1%% of default heatmap height]")
        arg( '--image_size', type=float, default=8,
             help = "Size in inches of the larger of the image width and height [default 8]")
        arg( '--cell_aspect_ratio', type=float, default=1.0,
             help = "Aspect ratio between width and height for the cells of the heatmap [default 1.0]")
        col_maps = ['Accent', 'Blues', 'BrBG', 'BuGn', 'BuPu', 'Dark2', 'GnBu',
                    'Greens', 'Greys', 'OrRd', 'Oranges', 'PRGn', 'Paired',
                    'Pastel1', 'Pastel2', 'PiYG', 'PuBu', 'PuBuGn', 'PuOr',
                    'PuRd', 'Purples', 'RdBu', 'RdGy', 'RdPu', 'RdYlBu', 'RdYlGn',
                    'Reds', 'Set1', 'Set2', 'Set3', 'Spectral', 'YlGn', 'YlGnBu',
                    'YlOrBr', 'YlOrRd', 'afmhot', 'autumn', 'binary', 'bone',
                    'brg', 'bwr', 'cool', 'copper', 'flag', 'gist_earth',
                    'gist_gray', 'gist_heat', 'gist_ncar', 'gist_rainbow',
                    'gist_stern', 'gist_yarg', 'gnuplot', 'gnuplot2', 'gray',
                    'hot', 'hsv', 'jet', 'ocean', 'pink', 'prism', 'rainbow',
                    'seismic', 'spectral', 'spring', 'summer', 'terrain', 'winter'] + [n for n,c in Heatmap.my_colormaps]
        for n,c in Heatmap.my_colormaps:
            my_cmap = matplotlib.colors.LinearSegmentedColormap(n,c,256)
            pylab.register_cmap(name=n,cmap=my_cmap)
        arg( '-c','--colormap', type=str, choices = col_maps, default = 'bbcry' )
        arg( '--bottom_c', type=str, default = None,
             help = "Color to use for cells below the minimum value of the scale [default None meaning bottom color of the scale]")
        arg( '--top_c', type=str, default = None,
             help = "Color to use for cells above the maximum value of the scale [default None meaning top color of the scale]")
        arg( '--nan_c', type=str, default = None,
             help = "Color to use for nan cells [default None]")


        """
        arg( '--', type=str, default="average",
             help = "Linkage method for feature clustering [default average]")
        arg( '--slinkage', type=str, default="average",
             help = "Linkage method for sample clustering [default average]")
        """

    def __init__( self, numpy_matrix, sdendrogram, fdendrogram, snames, fnames, fnames_meta, args = None ):
        self.numpy_matrix = numpy_matrix
        self.sdendrogram = sdendrogram
        self.fdendrogram = fdendrogram
        self.snames = snames
        self.fnames = fnames
        self.fnames_meta = fnames_meta
        self.ns,self.nf = self.numpy_matrix.shape
        self.args = args

    def make_legend( self, dmap, titles, out_fn ):
        figlegend = plt.figure(figsize=(1+3*len(titles),2), frameon = False)

        gs = gridspec.GridSpec( 1, len(dmap), wspace = 2.0 )

        for i,(d,title) in enumerate(zip(dmap,titles)):
            legax = plt.subplot(gs[i],frameon = False)
            for k,v in sorted(d.items(),key=lambda x:x[1]):
                rect = Rectangle( [0.0, 0.0], 0.0, 0.0,
                                  facecolor = self.dcols[v%len(self.dcols)],
                                  label = k,
                                  edgecolor='b', lw = 0.0)

                legax.add_patch(rect)
            #remove_splines( legax )
            legax.set_xticks([])
            legax.set_yticks([])
            legax.legend( loc = 2, frameon = False, title = title)
            """
            ncol = legend_ncol, bbox_to_anchor=(1.01, 3.),
            borderpad = 0.0, labelspacing = 0.0,
            handlelength = 0.5, handletextpad = 0.3,
            borderaxespad = 0.0, columnspacing = 0.3,
            prop = {'size':fontsize}, frameon = False)
            """
        if out_fn:
            figlegend.savefig(out_fn, bbox_inches='tight')

    def draw( self ):

        rat = float(self.ns)/self.nf
        rat *= self.args.cell_aspect_ratio
        x,y = (self.args.image_size,rat*self.args.image_size) if rat < 1 else (self.args.image_size/rat,self.args.image_size)
        fig = plt.figure( figsize=(x,y), facecolor = 'w' )

        cm = pylab.get_cmap(self.args.colormap)
        bottom_col = [ cm._segmentdata['red'][0][1],
                       cm._segmentdata['green'][0][1],
                       cm._segmentdata['blue'][0][1] ]
        if self.args.bottom_c:
            bottom_col = self.args.bottom_c
        cm.set_under( bottom_col )
        top_col = [ cm._segmentdata['red'][-1][1],
                    cm._segmentdata['green'][-1][1],
                    cm._segmentdata['blue'][-1][1] ]
        if self.args.top_c:
            top_col = self.args.top_c
        cm.set_over( top_col )

        if self.args.nan_c:
            cm.set_bad( self.args.nan_c )

        def make_ticklabels_invisible(ax):
            for tl in ax.get_xticklabels() + ax.get_yticklabels():
                tl.set_visible(False)
            ax.set_xticks([])
            ax.set_yticks([])

        def remove_splines( ax ):
            for v in ['right','left','top','bottom']:
                ax.spines[v].set_color('none')

        def shrink_labels( labels, n ):
            shrink = lambda x: x[:n/2]+" [...] "+x[-n/2:]
            return [(shrink(str(l)) if len(str(l)) > n else l) for l in labels]


        #gs = gridspec.GridSpec( 4, 2,
        #                        width_ratios=[1.0-fr_ns,fr_ns],
        #                        height_ratios=[.03,0.03,1.0-fr_nf,fr_nf],
        #                        wspace = 0.0, hspace = 0.0 )

        fr_ns = float(self.ns)/max([self.ns,self.nf])
        fr_nf = float(self.nf)/max([self.ns,self.nf])

        buf_space = 0.05
        minv = min( [buf_space*8, 8*rat*buf_space] )
        if minv < 0.05:
            buf_space /= minv/0.05
        metadata_height = self.args.metadata_height if type(self.snames[0]) is tuple and len(self.snames[0]) > 1 else 0.000001
        gs = gridspec.GridSpec( 6, 4,
                                width_ratios=[ buf_space, buf_space*2, .08*self.args.fdend_width, 0.9],
                                height_ratios=[ buf_space, buf_space*2, .08*self.args.sdend_height, metadata_height, self.args.metadata_separation, 0.9],
                                wspace = 0.0, hspace = 0.0 )

        ax_hm = plt.subplot(gs[23], axisbg = bottom_col )
        ax_metadata = plt.subplot(gs[15], axisbg = bottom_col )
        ax_hm_y2 = ax_hm.twinx()

        norm_f = matplotlib.colors.Normalize
        if self.args.log_scale:
            norm_f = matplotlib.colors.LogNorm
        elif self.args.sqrt_scale:
            norm_f = SqrtNorm
        minv, maxv = 0.0, None

        maps, values, ndv = [], [], 0
        if type(self.snames[0]) is tuple and len(self.snames[0]) > 1:
            metadata = zip(*[list(s[1:]) for s in self.snames])
            for m in metadata:
                mmap = dict([(v[1],ndv+v[0]) for v in enumerate(list(set(m)))])
                values.append([mmap[v] for v in m])
                ndv += len(mmap)
                maps.append(mmap)
            dcols = []
            mdmat = np.matrix(values)
            while len(dcols) < ndv:
                dcols += self.dcols
            cmap = matplotlib.colors.ListedColormap(dcols[:ndv])
            bounds = [float(f)-0.5 for f in range(ndv+1)]
            imm = ax_metadata.imshow( mdmat, #origin='lower',
                                      interpolation = 'nearest',
                                      aspect='auto',
                                      extent = [0, self.nf, 0, self.ns],
                                      cmap=cmap,
                                      vmin=bounds[0],
                                      vmax=bounds[-1],
                                      )
            remove_splines( ax_metadata )
            ax_metadata_y2 = ax_metadata.twinx()
            ax_metadata_y2.set_ylim(0,len(self.fnames_meta))
            ax_metadata.set_yticks([])
            ax_metadata_y2.set_ylim(0,len(self.fnames_meta))
            ax_metadata_y2.tick_params(length=0)
            ax_metadata_y2.set_yticks(np.arange(len(self.fnames_meta))+0.5)
            ax_metadata_y2.set_yticklabels(self.fnames_meta[::-1], va='center',size=self.args.flabel_size)
        else:
            ax_metadata.set_yticks([])

        ax_metadata.set_xticks([])

        im = ax_hm.imshow( self.numpy_matrix, #origin='lower',
                           interpolation = 'nearest', aspect='auto',
                           extent = [0, self.nf, 0, self.ns],
                           cmap=cm,
                           vmin=self.args.minv,
                           vmax=self.args.maxv,
                           norm = norm_f( vmin=minv if minv > 0.0 else None, vmax=maxv)
                           )

        #ax_hm.set_ylim([0,800])
        ax_hm.set_xticks(np.arange(len(list(self.snames)))+0.5)
        if not self.args.no_slabels:
            snames_short = shrink_labels( list([s[0] for s in self.snames]) if type(self.snames[0]) is tuple else self.snames, self.args.max_slabel_len )
            ax_hm.set_xticklabels(snames_short,rotation=90,va='top',ha='center',size=self.args.slabel_size)
        else:
            ax_hm.set_xticklabels([])
        ax_hm_y2.set_ylim([0,self.ns])
        ax_hm_y2.set_yticks(np.arange(len(self.fnames))+0.5)
        if not self.args.no_flabels:
            fnames_short = shrink_labels( self.fnames, self.args.max_flabel_len )
            ax_hm_y2.set_yticklabels(fnames_short,va='center',size=self.args.flabel_size)
        else:
            ax_hm_y2.set_yticklabels( [] )
        ax_hm.set_yticks([])
        remove_splines( ax_hm )
        ax_hm.tick_params(length=0)
        ax_hm_y2.tick_params(length=0)
        #ax_hm.set_xlim([0,self.ns])
        ax_cm = plt.subplot(gs[3], axisbg = 'r', frameon = False)
        #fig.colorbar(im, ax_cm, orientation = 'horizontal', spacing = 'proportional', format = ticker.LogFormatterMathtext() )
        fig.colorbar(im, ax_cm, orientation = 'horizontal', spacing='proportional' if self.args.sqrt_scale else 'uniform' ) # , format = ticker.LogFormatterMathtext() )

        if not self.args.no_sclustering:
            ax_den_top = plt.subplot(gs[11], axisbg = 'r', frameon = False)
            sph._plot_dendrogram( self.sdendrogram['icoord'], self.sdendrogram['dcoord'], self.sdendrogram['ivl'],
                                  self.ns + 1, self.nf + 1, 1, 'top', no_labels=True,
                                  color_list=self.sdendrogram['color_list'] )
            ymax = max([max(a) for a in self.sdendrogram['dcoord']])
            ax_den_top.set_ylim([0,ymax])
            make_ticklabels_invisible( ax_den_top )
        if not self.args.no_fclustering:
            ax_den_right = plt.subplot(gs[22], axisbg = 'b', frameon = False)
            sph._plot_dendrogram( self.fdendrogram['icoord'], self.fdendrogram['dcoord'], self.fdendrogram['ivl'],
                                  self.ns + 1, self.nf + 1, 1, 'right', no_labels=True,
                                  color_list=self.fdendrogram['color_list'] )
            xmax = max([max(a) for a in self.fdendrogram['dcoord']])
            ax_den_right.set_xlim([xmax,0])
            make_ticklabels_invisible( ax_den_right )

        if not self.args.out:
            plt.show( )
        else:
            fig.savefig( self.args.out, bbox_inches='tight', dpi = self.args.dpi )
            if maps:
                self.make_legend( maps, self.fnames_meta, self.args.legend_file )


class ReadCmd:

    def __init__( self ):
        import argparse as ap
        import textwrap

        p = ap.ArgumentParser( description= "TBA" )
        arg = p.add_argument

        arg( '-i', '--inp', '--in', metavar='INPUT_FILE', type=str, nargs='?', default=sys.stdin,
             help= "The input matrix" )
        arg( '-o', '--out', metavar='OUTPUT_FILE', type=str, nargs='?', default=None,
             help= "The output image file [image shown on screen if not specified]" )
        arg( '--legend_file', metavar='LEGEND_FILE', type=str, nargs='?', default=None,
             help= "The output file for the legend of the provided metadata" )

        input_types = [DataMatrix.datatype,DistMatrix.datatype]
        arg( '-t', '--input_type', metavar='INPUT_TYPE', type=str, choices = input_types,
             default='data_matrix',
             help= "The input type can be a data matrix or distance matrix [default data_matrix]" )

        DataMatrix.input_parameters( p )
        DistMatrix.input_parameters( p )
        HClustering.input_parameters( p )
        Heatmap.input_parameters( p )

        self.args = p.parse_args()

    def check_consistency( self ):
        pass

    def get_args( self ):
        return self.args

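# Example invocation (file names are hypothetical, shown for illustration
# only; without -o the figure is displayed on screen instead):
#
#   python hclust2.py -i abundance_table.txt -o heatmap.png \
#       --ftop 50 --f_dist_f correlation --s_dist_f braycurtis \
#       --log_scale --flabel_size 6 --slabel_size 6
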
if __name__ == '__main__':

    read = ReadCmd( )
    read.check_consistency()
    args = read.get_args()

    if args.input_type == DataMatrix.datatype:
        dm = DataMatrix( args.inp, args )
        if args.out_table:
            dm.save_matrix( args.out_table )

        distm = DistMatrix( dm.get_numpy_matrix(), args = args )
        if not args.no_sclustering:
            distm.compute_s_dists()
        if not args.no_fclustering:
            distm.compute_f_dists()
    elif args.input_type == DistMatrix.datatype:
        # distm = read...
        pass
    else:
        pass

    cl = HClustering( distm.get_s_dm(), distm.get_f_dm(), args = args )
    if not args.no_sclustering:
        cl.shcluster()
    if not args.no_fclustering:
        cl.fhcluster()

    hmp = dm.get_numpy_matrix()
    fnames = dm.get_fnames()
    snames = dm.get_snames()
    fnames_meta = snames.names[1:]
    #if not args.no_sclustering or not args.no_fclustering ):

    hmp = cl.get_reordered_matrix( hmp, sclustering = not args.no_sclustering, fclustering = not args.no_fclustering )
    if not args.no_sclustering:
        snames = cl.get_reordered_sample_labels( snames )
    if not args.no_fclustering:
        fnames = cl.get_reordered_feature_labels( fnames )
    else:
        fnames = fnames[::-1]

    hm = Heatmap( hmp, cl.sdendrogram, cl.fdendrogram, snames, fnames, fnames_meta, args = args )
    hm.draw()