Mercurial > repos > george-weingart > graphlan_import
comparison hclust2/hclust2.py @ 0:cac6247cb1d3 draft
graphlan_import
author | george-weingart |
---|---|
date | Tue, 26 Aug 2014 14:51:29 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:cac6247cb1d3 |
---|---|
1 #!/usr/bin/env python | |
2 | |
3 import sys | |
4 import numpy as np | |
5 import matplotlib.ticker as ticker | |
6 import scipy.spatial.distance as spd | |
7 import scipy.cluster.hierarchy as sph | |
8 from scipy import stats | |
9 import matplotlib | |
10 #matplotlib.use('Agg') | |
11 import pylab | |
12 import pandas as pd | |
13 from matplotlib.patches import Rectangle | |
14 from mpl_toolkits.axes_grid1 import make_axes_locatable | |
15 import matplotlib.pyplot as plt | |
16 import matplotlib.gridspec as gridspec | |
17 import cPickle as pickle | |
18 | |
# Raise Python's recursion ceiling; presumably needed because scipy's
# dendrogram construction recurses once per tree level, which can exceed the
# default limit for large matrices (assumption - confirm with large inputs).
sys.setrecursionlimit(10 ** 4)

# NOTE(review): an earlier comment here read "samples on rows", but
# DataMatrix exposes samples as table columns (get_snames returns
# table.columns) - confirm which orientation is intended.
22 | |
class SqrtNorm(matplotlib.colors.Normalize):
    """
    Normalize a given value to the 0-1 range on a square root scale.

    A value x is mapped to
    (sqrt(x) - sqrt(vmin)) / (sqrt(vmax) - sqrt(vmin)).
    Non-positive values are masked, and autoscaling ignores them, so
    ``vmin`` must end up strictly positive.
    """

    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip

        result, is_scalar = self.process_value(value)

        # Non-positive entries cannot be square-rooted meaningfully: mask them
        # so they neither drive autoscaling nor get a colour.
        result = np.ma.masked_less_equal(result, 0, copy=False)

        self.autoscale_None(result)
        vmin, vmax = self.vmin, self.vmax
        if vmin > vmax:
            raise ValueError("minvalue must be less than or equal to maxvalue")
        elif vmin <= 0:
            raise ValueError("values must all be positive")
        elif vmin == vmax:
            result.fill(0)
        else:
            if clip:
                mask = np.ma.getmask(result)
                result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
                                     mask=mask)
            # Work on the raw buffer in place. Masked or non-positive cells
            # are set to 1 first so np.sqrt never sees an invalid value; they
            # stay masked in the returned array.
            resdat = result.data
            mask = result.mask
            if mask is np.ma.nomask:
                mask = (resdat <= 0)
            else:
                mask |= resdat <= 0
            # np.putmask replaces matplotlib.cbook._putmask, which newer
            # matplotlib releases no longer provide.
            np.putmask(resdat, mask, 1)
            np.sqrt(resdat, resdat)
            resdat -= np.sqrt(vmin)
            resdat /= (np.sqrt(vmax) - np.sqrt(vmin))
            result = np.ma.array(resdat, mask=mask, copy=False)
        if is_scalar:
            result = result[0]
        return result

    def inverse(self, value):
        """Map normalized value(s) in [0, 1] back to data space.

        Exact inverse of ``__call__``:
        x = (v * (sqrt(vmax) - sqrt(vmin)) + sqrt(vmin)) ** 2.

        Raises:
            ValueError: if vmin/vmax have not been set yet.
        """
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        lo, hi = np.sqrt(self.vmin), np.sqrt(self.vmax)

        if np.iterable(value):
            val = np.ma.asarray(value)
            return np.ma.power(val * (hi - lo) + lo, 2)
        return (value * (hi - lo) + lo) ** 2

    def autoscale(self, A):
        '''
        Set *vmin*, *vmax* to the min and max of the positive values of *A*.
        '''
        A = np.ma.masked_less_equal(A, 0, copy=False)
        self.vmin = np.ma.min(A)
        self.vmax = np.ma.max(A)

    def autoscale_None(self, A):
        '''Autoscale only vmin/vmax that are still None, ignoring values <= 0.'''
        if self.vmin is not None and self.vmax is not None:
            return
        A = np.ma.masked_less_equal(A, 0, copy=False)
        if self.vmin is None:
            self.vmin = np.ma.min(A)
        if self.vmax is None:
            self.vmax = np.ma.max(A)
92 | |
class DataMatrix:
    """Numeric table read from a delimited text file.

    The table's row index is exposed as feature names (``get_fnames``) and
    its columns as sample names (``get_snames``). Optional filtering keeps
    the top-scoring rows (features) and/or columns (samples).
    """
    datatype = 'data_matrix'

    @staticmethod
    def input_parameters(parser):
        """Register the input-matrix command-line options on *parser*."""
        dm_param = parser.add_argument_group('Input data matrix parameters')
        arg = dm_param.add_argument

        arg('--sep', type=str, default='\t')
        arg('--out_table', type=str, default=None,
            help='Write processed data matrix to file')
        # NOTE(review): fname_row is used as the pandas header row (column
        # labels, i.e. what get_snames returns) and sname_row as index_col
        # (row labels, what get_fnames returns); the help texts below look
        # swapped relative to that - confirm before rewording them.
        arg('--fname_row', type=int, default=0,
            help="row number containing the names of the features "
                 "[default 0, specify -1 if no names are present in the matrix]")
        arg('--sname_row', type=int, default=0,
            help="column number containing the names of the samples "
                 "[default 0, specify -1 if no names are present in the matrix]")
        arg('--metadata_rows', type=str, default=None,
            help="Row numbers to use as metadata (comma separated) "
                 "[default None, meaning no metadata]")
        arg('--skip_rows', type=str, default=None,
            help="Row numbers to skip (0-indexed, comma separated) from the input file "
                 "[default None, meaning no rows skipped]")
        arg('--sperc', type=int, default=90,
            help="Percentile of sample value distribution for sample selection")
        arg('--fperc', type=int, default=90,
            help="Percentile of feature value distribution for feature selection")
        arg('--stop', type=int, default=None,
            help="Number of top samples to select (ordering based on percentile specified by --sperc)")
        arg('--ftop', type=int, default=None,
            help="Number of top features to select (ordering based on percentile specified by --fperc)")
        arg('--def_na', type=float, default=None,
            help="Set the default value for missing values [default None which means no replacement]")

    def __init__(self, input_file, args):
        """Load *input_file* according to the options in *args* and apply
        the optional NA filling and top-feature/top-sample selection."""
        self.args = args
        self.metadata_rows = []
        self.metadata_table = None
        toskip = [int(l) for l in self.args.skip_rows.split(",")] if self.args.skip_rows else []
        if self.args.metadata_rows:
            self.metadata_rows = [int(a) for a in self.args.metadata_rows.split(",")]
            original_rows = self.metadata_rows[:]
            # Every skipped row at or before a metadata row shifts that
            # metadata row's position in the parsed table up by one.
            for t in toskip:
                for i, m in enumerate(original_rows):
                    if t <= m:
                        self.metadata_rows[i] -= 1
        if self.metadata_rows:
            header = ([self.args.fname_row] + self.metadata_rows
                      if self.args.fname_row > -1 else self.metadata_rows)
        else:
            header = self.args.fname_row if self.args.fname_row > -1 else None
        self.table = pd.read_table(
            input_file, sep=self.args.sep,
            skiprows=sorted(toskip),
            header=sorted(header) if isinstance(header, list) else header,
            index_col=self.args.sname_row if self.args.sname_row > -1 else None
        )

        def select(perc, top):
            # Score every row by the given percentile of its values and keep
            # the rows scoring at least as high as the top-th best (ties are
            # kept). Scores live in a separate Series so a real data column
            # that happens to be called 'perc' is never clobbered.
            scores = self.table.apply(
                lambda x: stats.scoreatpercentile(x, perc), axis=1)
            top = min(top, len(scores))  # asking for more than exist keeps all
            if top <= 0:
                return
            threshold = sorted(scores)[-top]
            self.table = self.table[(scores >= threshold).values]

        if self.args.def_na is not None:
            self.table = self.table.fillna(self.args.def_na)

        if self.args.ftop:
            select(self.args.fperc, self.args.ftop)

        if self.args.stop:
            # Sample selection works on columns: transpose, select rows, undo.
            self.table = self.table.T
            select(self.args.sperc, self.args.stop)
            self.table = self.table.T

    def get_numpy_matrix(self):
        """Return the table values as an ``np.matrix`` (DistMatrix expects one)."""
        return np.matrix(self.table)

    def get_snames(self):
        """Return the sample names (the table's column index)."""
        return self.table.columns

    def get_fnames(self):
        """Return the feature names (the table's row index) as a list."""
        return list(self.table.index)

    def get_averages(self, by_row=True):
        """Return per-row means (``by_row=True``) or per-column means."""
        return self.table.mean(axis=1 if by_row else 0)

    def save_matrix(self, output_file):
        """Write the processed table to *output_file* as tab-separated text."""
        self.table.to_csv(output_file, sep='\t')
190 | |
class DistMatrix:
    """Condensed pairwise distance matrices for features and samples.

    Feature distances are computed between the rows of the data matrix and
    sample distances between its columns. Either can instead be loaded from,
    or saved to, a pickle file.
    """
    datatype = 'distance_matrix'

    @staticmethod
    def input_parameters(parser):
        """Register the distance-related command-line options on *parser*."""
        dm_param = parser.add_argument_group('Distance parameters')
        arg = dm_param.add_argument

        arg('--f_dist_f', type=str, default="correlation",
            help="Distance function for features [default correlation]")
        arg('--s_dist_f', type=str, default="euclidean",
            help="Distance function for sample [default euclidean]")
        arg('--load_dist_matrix_f', type=str, default=None,
            help="Load the distance matrix to be used for features [default None].")
        arg('--load_dist_matrix_s', type=str, default=None,
            help="Load the distance matrix to be used for samples [default None].")
        arg('--save_dist_matrix_f', type=str, default=None,
            help="Save the distance matrix for features to file [default None].")
        arg('--save_dist_matrix_s', type=str, default=None,
            help="Save the distance matrix for samples to file [default None].")

    def __init__(self, data, args=None):
        """*data* must be an ``np.matrix``; anything else leaves the full
        matrix unset. *args* supplies the distance and load/save options."""
        self.args = args
        self.sdf = args.s_dist_f
        self.fdf = args.f_dist_f

        self.s_cdist_matrix, self.f_cdist_matrix = None, None

        self.numpy_full_matrix = data if isinstance(data, np.matrix) else None

    @staticmethod
    def _pairwise(dt, dist_f):
        """Return the condensed distance matrix between the rows of *dt*.

        'spearman' is a correlation distance on per-row ranks. 'pearson' is
        an alias for scipy's 'correlation'. Any other name goes to pdist.
        """
        rows = np.asarray(dt)
        if dist_f == "spearman":
            ranked = np.array([stats.rankdata(row) for row in rows])
            return spd.pdist(ranked, "correlation")
        if dist_f == "pearson":
            dist_f = "correlation"
        return spd.pdist(rows, dist_f)

    @staticmethod
    def _load(path):
        # NOTE: unpickling can execute arbitrary code - only load distance
        # files produced by this tool / a trusted source.
        with open(path, "rb") as inp:
            return pickle.load(inp)

    @staticmethod
    def _save(matrix, path):
        with open(path, "wb") as outf:
            pickle.dump(matrix, outf)

    def compute_f_dists(self):
        """Load or compute the feature (row) distances, saving them if asked."""
        if self.args.load_dist_matrix_f:
            self.f_cdist_matrix = self._load(self.args.load_dist_matrix_f)
        else:
            self.f_cdist_matrix = self._pairwise(self.numpy_full_matrix, self.fdf)

        if self.args.save_dist_matrix_f:
            self._save(self.f_cdist_matrix, self.args.save_dist_matrix_f)

    def compute_s_dists(self):
        """Load or compute the sample (column) distances, saving them if asked."""
        if self.args.load_dist_matrix_s:
            self.s_cdist_matrix = self._load(self.args.load_dist_matrix_s)
        else:
            self.s_cdist_matrix = self._pairwise(
                self.numpy_full_matrix.transpose(), self.sdf)

        if self.args.save_dist_matrix_s:
            self._save(self.s_cdist_matrix, self.args.save_dist_matrix_s)

    def get_s_dm(self):
        """Return the condensed sample distance matrix (None if not computed)."""
        return self.s_cdist_matrix

    def get_f_dm(self):
        """Return the condensed feature distance matrix (None if not computed)."""
        return self.f_cdist_matrix
276 | |
class HClustering:
    """Hierarchical clustering of samples and features from condensed
    distance matrices, plus helpers to reorder data by dendrogram leaves."""
    datatype = 'hclustering'

    @staticmethod
    def input_parameters(parser):
        """Register the clustering command-line options on *parser*."""
        cl_param = parser.add_argument_group('Clustering parameters')
        arg = cl_param.add_argument

        # The linkage methods accepted by scipy.cluster.hierarchy.linkage.
        linkage_method = ["single", "complete", "average",
                          "weighted", "centroid", "median",
                          "ward"]
        arg('--no_fclustering', action='store_true',
            help="avoid clustering features")
        arg('--no_sclustering', action='store_true',
            help="avoid clustering samples")
        arg('--flinkage', type=str, default="average", choices=linkage_method,
            help="Linkage method for feature clustering [default average]")
        arg('--slinkage', type=str, default="average", choices=linkage_method,
            help="Linkage method for sample clustering [default average]")

    def get_reordered_matrix(self, matrix, sclustering=True, fclustering=True):
        """Return *matrix* with rows/columns permuted by the dendrograms.

        Rows follow the feature dendrogram's leaf order reversed; columns
        follow the sample dendrogram's leaf order. A disabled axis is left
        as-is.
        """
        if not sclustering and not fclustering:
            return matrix
        if fclustering:
            matrix = matrix[self.fdendrogram['leaves'][::-1], :]
        if sclustering:
            matrix = matrix[:, self.sdendrogram['leaves']]
        return matrix

    def get_reordered_sample_labels(self, slabels):
        """Return *slabels* in sample-dendrogram leaf order."""
        return [slabels[i] for i in self.sdendrogram['leaves']]

    def get_reordered_feature_labels(self, flabels):
        """Return *flabels* in feature-dendrogram leaf order (not reversed)."""
        return [flabels[i] for i in self.fdendrogram['leaves']]

    def __init__(self, s_dm, f_dm, args=None):
        self.s_dm = s_dm
        self.f_dm = f_dm
        self.args = args
        self.sclusters = None
        self.fclusters = None
        # Linkage results read by get_shclusters/get_fhclusters; initialised
        # so the getters return None (rather than raising) before clustering.
        self.shclusters = None
        self.fhclusters = None
        self.sdendrogram = None
        self.fdendrogram = None

    def _linkage_method(self, attr):
        # Use the instance's args (not a module-level global); fall back to
        # the CLI default when no args namespace was supplied.
        return getattr(self.args, attr, "average")

    def shcluster(self, dendrogram=True):
        """Cluster samples; optionally compute the (unplotted) dendrogram."""
        self.shclusters = sph.linkage(self.s_dm, self._linkage_method('slinkage'))
        if dendrogram:
            self.sdendrogram = sph.dendrogram(self.shclusters, no_plot=True)

    def fhcluster(self, dendrogram=True):
        """Cluster features; optionally compute the (unplotted) dendrogram."""
        self.fhclusters = sph.linkage(self.f_dm, self._linkage_method('flinkage'))
        if dendrogram:
            self.fdendrogram = sph.dendrogram(self.fhclusters, no_plot=True)

    def get_shclusters(self):
        return self.shclusters

    def get_fhclusters(self):
        return self.fhclusters

    def get_sdendrogram(self):
        return self.sdendrogram

    def get_fdendrogram(self):
        return self.fdendrogram
347 | |
348 | |
class Heatmap:
    """Draw the (optionally reordered) data matrix as a heatmap with
    dendrograms, an optional metadata strip and a colour bar."""
    datatype = 'heatmap'

    bbcyr = {'red': ((0.0, 0.0, 0.0),
                     (0.25, 0.0, 0.0),
                     (0.50, 0.0, 0.0),
                     (0.75, 1.0, 1.0),
                     (1.0, 1.0, 1.0)),
             'green': ((0.0, 0.0, 0.0),
                       (0.25, 0.0, 0.0),
                       (0.50, 1.0, 1.0),
                       (0.75, 1.0, 1.0),
                       (1.0, 0.0, 1.0)),
             'blue': ((0.0, 0.0, 0.0),
                      (0.25, 1.0, 1.0),
                      (0.50, 1.0, 1.0),
                      (0.75, 0.0, 0.0),
                      (1.0, 0.0, 1.0))}

    bbcry = {'red': ((0.0, 0.0, 0.0),
                     (0.25, 0.0, 0.0),
                     (0.50, 0.0, 0.0),
                     (0.75, 1.0, 1.0),
                     (1.0, 1.0, 1.0)),
             'green': ((0.0, 0.0, 0.0),
                       (0.25, 0.0, 0.0),
                       (0.50, 1.0, 1.0),
                       (0.75, 0.0, 0.0),
                       (1.0, 1.0, 1.0)),
             'blue': ((0.0, 0.0, 0.0),
                      (0.25, 1.0, 1.0),
                      (0.50, 1.0, 1.0),
                      (0.75, 0.0, 0.0),
                      (1.0, 0.0, 1.0))}

    bcry = {'red': ((0.0, 0.0, 0.0),
                    (0.33, 0.0, 0.0),
                    (0.66, 1.0, 1.0),
                    (1.0, 1.0, 1.0)),
            'green': ((0.0, 0.0, 0.0),
                      (0.33, 1.0, 1.0),
                      (0.66, 0.0, 0.0),
                      (1.0, 1.0, 1.0)),
            'blue': ((0.0, 1.0, 1.0),
                     (0.33, 1.0, 1.0),
                     (0.66, 0.0, 0.0),
                     (1.0, 0.0, 1.0))}

    # Custom segment-data colormaps, registered by name in input_parameters.
    my_colormaps = [('bbcyr', bbcyr),
                    ('bbcry', bbcry),
                    ('bcry', bcry)]

    # Palette cycled through for categorical metadata values.
    dcols = ['#ca0000', '#0087ff', '#00ba1d', '#cf00ff', '#00dbe2', '#ffaf00',
             '#0017f4', '#006012', '#e175ff', '#877878', '#050505', '#b5cf00',
             '#ff8a8a', '#aa6400', '#50008a', '#00ff58']

    @staticmethod
    def input_parameters(parser):
        """Register heatmap options on *parser* and, as a side effect,
        register the custom colormaps with matplotlib."""
        hm_param = parser.add_argument_group('Heatmap options')
        arg = hm_param.add_argument

        arg('--dpi', type=int, default=150,
            help="Image resolution in dpi [default 150]")
        arg('-l', '--log_scale', action='store_true',
            help="Log scale")
        arg('-s', '--sqrt_scale', action='store_true',
            help="Square root scale")
        arg('--no_slabels', action='store_true',
            help="Do not show sample labels")
        arg('--minv', type=float, default=None,
            help="Minimum value to display in the color map [default None meaning automatic]")
        arg('--maxv', type=float, default=None,
            help="Maximum value to display in the color map [default None meaning automatic]")
        arg('--no_flabels', action='store_true',
            help="Do not show feature labels")
        arg('--max_slabel_len', type=int, default=25,
            help="Max number of chars to report for sample labels [default 25]")
        arg('--max_flabel_len', type=int, default=25,
            help="Max number of chars to report for feature labels [default 25]")
        arg('--flabel_size', type=int, default=10,
            help="Feature label font size [default 10]")
        arg('--slabel_size', type=int, default=10,
            help="Sample label font size [default 10]")
        arg('--fdend_width', type=float, default=1.0,
            help="Width of the feature dendrogram [default 1 meaning 100%% of default heatmap width]")
        arg('--sdend_height', type=float, default=1.0,
            help="Height of the sample dendrogram [default 1 meaning 100%% of default heatmap height]")
        arg('--metadata_height', type=float, default=.05,
            help="Height of the metadata panel [default 0.05 meaning 5%% of default heatmap height]")
        arg('--metadata_separation', type=float, default=.01,
            help="Distance between the metadata and data panels. [default 0.01 meaning 1%% of default heatmap height]")
        arg('--image_size', type=float, default=8,
            help="Size in inches of the larger of the image width and height [default 8]")
        arg('--cell_aspect_ratio', type=float, default=1.0,
            help="Aspect ratio between width and height for the cells of the heatmap [default 1.0]")
        col_maps = ['Accent', 'Blues', 'BrBG', 'BuGn', 'BuPu', 'Dark2', 'GnBu',
                    'Greens', 'Greys', 'OrRd', 'Oranges', 'PRGn', 'Paired',
                    'Pastel1', 'Pastel2', 'PiYG', 'PuBu', 'PuBuGn', 'PuOr',
                    'PuRd', 'Purples', 'RdBu', 'RdGy', 'RdPu', 'RdYlBu', 'RdYlGn',
                    'Reds', 'Set1', 'Set2', 'Set3', 'Spectral', 'YlGn', 'YlGnBu',
                    'YlOrBr', 'YlOrRd', 'afmhot', 'autumn', 'binary', 'bone',
                    'brg', 'bwr', 'cool', 'copper', 'flag', 'gist_earth',
                    'gist_gray', 'gist_heat', 'gist_ncar', 'gist_rainbow',
                    'gist_stern', 'gist_yarg', 'gnuplot', 'gnuplot2', 'gray',
                    'hot', 'hsv', 'jet', 'ocean', 'pink', 'prism', 'rainbow',
                    'seismic', 'spectral', 'spring', 'summer', 'terrain', 'winter'] + [n for n, c in Heatmap.my_colormaps]
        for n, c in Heatmap.my_colormaps:
            my_cmap = matplotlib.colors.LinearSegmentedColormap(n, c, 256)
            pylab.register_cmap(name=n, cmap=my_cmap)
        arg('-c', '--colormap', type=str, choices=col_maps, default='bbcry',
            help="Colormap for the heatmap cells [default bbcry]")
        arg('--bottom_c', type=str, default=None,
            help="Color to use for cells below the minimum value of the scale [default None meaning bottom color of the scale]")
        arg('--top_c', type=str, default=None,
            help="Color to use for cells above the maximum value of the scale [default None meaning top color of the scale]")
        arg('--nan_c', type=str, default=None,
            help="Color to use for nan cells [default None]")

    def __init__(self, numpy_matrix, sdendrogram, fdendrogram, snames, fnames, fnames_meta, args=None):
        self.numpy_matrix = numpy_matrix
        self.sdendrogram = sdendrogram
        self.fdendrogram = fdendrogram
        self.snames = snames
        self.fnames = fnames
        self.fnames_meta = fnames_meta
        self.ns, self.nf = self.numpy_matrix.shape
        self.args = args

    def make_legend(self, dmap, titles, out_fn):
        """Draw one legend per metadata row; save to *out_fn* if given.

        dmap: list of {label: colour index} dicts, one per metadata row.
        titles: legend titles, paired with *dmap* by position.
        """
        figlegend = plt.figure(figsize=(1 + 3 * len(titles), 2), frameon=False)

        gs = gridspec.GridSpec(1, len(dmap), wspace=2.0)

        for i, (d, title) in enumerate(zip(dmap, titles)):
            legax = plt.subplot(gs[i], frameon=False)
            # Zero-sized patches exist only to create coloured legend entries.
            for k, v in sorted(d.items(), key=lambda x: x[1]):
                rect = Rectangle([0.0, 0.0], 0.0, 0.0,
                                 facecolor=self.dcols[v % len(self.dcols)],
                                 label=k,
                                 edgecolor='b', lw=0.0)
                legax.add_patch(rect)
            legax.set_xticks([])
            legax.set_yticks([])
            legax.legend(loc=2, frameon=False, title=title)

        if out_fn:
            figlegend.savefig(out_fn, bbox_inches='tight')

    def draw(self):
        """Render the full figure.

        The figure is shown on screen when ``args.out`` is unset and saved
        to ``args.out`` otherwise. When sample names carry metadata (tuples
        with more than one element), a legend is then drawn via
        ``make_legend`` and written to ``args.legend_file`` if set.
        """
        # Use the data handed to __init__, not module-level globals.
        snames, fnames, fnames_meta = self.snames, self.fnames, self.fnames_meta
        has_metadata = type(snames[0]) is tuple and len(snames[0]) > 1

        rat = float(self.ns) / self.nf
        rat *= self.args.cell_aspect_ratio
        x, y = ((self.args.image_size, rat * self.args.image_size) if rat < 1
                else (self.args.image_size / rat, self.args.image_size))
        fig = plt.figure(figsize=(x, y), facecolor='w')

        def cmap_end_rgb(cmap, idx):
            # First (idx 0) or last (idx -1) anchor of the segment data; this
            # keeps the exact colours of the custom bb*-style maps. Listed or
            # function-based colormaps have no indexable segment data, so
            # they are sampled at 0.0 / 1.0 instead.
            try:
                seg = cmap._segmentdata
                return [seg[c][idx][1] for c in ('red', 'green', 'blue')]
            except (AttributeError, KeyError, IndexError, TypeError):
                return list(cmap(0.0 if idx == 0 else 1.0)[:3])

        cm = pylab.get_cmap(self.args.colormap)
        bottom_col = self.args.bottom_c if self.args.bottom_c else cmap_end_rgb(cm, 0)
        cm.set_under(bottom_col)
        top_col = self.args.top_c if self.args.top_c else cmap_end_rgb(cm, -1)
        cm.set_over(top_col)

        if self.args.nan_c:
            cm.set_bad(self.args.nan_c)

        def make_ticklabels_invisible(ax):
            for tl in ax.get_xticklabels() + ax.get_yticklabels():
                tl.set_visible(False)
            ax.set_xticks([])
            ax.set_yticks([])

        def remove_splines(ax):
            for v in ['right', 'left', 'top', 'bottom']:
                ax.spines[v].set_color('none')

        def shrink_labels(labels, n):
            # Keep the head and tail of labels longer than n characters.
            # Floor division keeps the slice bounds integral on Python 3 and
            # gives the same bounds as Python 2's integer '/'.
            shrink = lambda x: x[:n // 2] + " [...] " + x[-n // 2:]
            return [(shrink(str(l)) if len(str(l)) > n else l) for l in labels]

        # Enlarge the outer buffer cells when the figure is very elongated.
        buf_space = 0.05
        min_buf = min([buf_space * 8, 8 * rat * buf_space])
        if min_buf < 0.05:
            buf_space /= min_buf / 0.05
        # A near-zero (but non-zero) height hides the metadata row.
        metadata_height = self.args.metadata_height if has_metadata else 0.000001
        gs = gridspec.GridSpec(
            6, 4,
            width_ratios=[buf_space, buf_space * 2, .08 * self.args.fdend_width, 0.9],
            height_ratios=[buf_space, buf_space * 2, .08 * self.args.sdend_height,
                           metadata_height, self.args.metadata_separation, 0.9],
            wspace=0.0, hspace=0.0)

        # Integer GridSpec indices are row-major over the 6x4 grid:
        # 3 = colour bar, 11 = sample dendrogram, 15 = metadata strip,
        # 22 = feature dendrogram, 23 = heatmap.
        # NOTE(review): 'axisbg' was renamed 'facecolor' in matplotlib 2.0
        # and later removed; kept for the matplotlib version this targets.
        ax_hm = plt.subplot(gs[23], axisbg=bottom_col)
        ax_metadata = plt.subplot(gs[15], axisbg=bottom_col)
        ax_hm_y2 = ax_hm.twinx()

        norm_f = matplotlib.colors.Normalize
        if self.args.log_scale:
            norm_f = matplotlib.colors.LogNorm
        elif self.args.sqrt_scale:
            norm_f = SqrtNorm

        maps, values, ndv = [], [], 0
        if has_metadata:
            # One row per metadata field; each distinct value gets a global
            # colour index so different fields never share colours.
            metadata = zip(*[list(s[1:]) for s in snames])
            for m in metadata:
                mmap = dict((label, ndv + i) for i, label in enumerate(list(set(m))))
                values.append([mmap[v] for v in m])
                ndv += len(mmap)
                maps.append(mmap)
            dcols = []
            mdmat = np.matrix(values)
            while len(dcols) < ndv:
                dcols += self.dcols
            cmap = matplotlib.colors.ListedColormap(dcols[:ndv])
            bounds = [float(f) - 0.5 for f in range(ndv + 1)]
            ax_metadata.imshow(mdmat,
                               interpolation='nearest',
                               aspect='auto',
                               extent=[0, self.nf, 0, self.ns],
                               cmap=cmap,
                               vmin=bounds[0],
                               vmax=bounds[-1])
            remove_splines(ax_metadata)
            ax_metadata_y2 = ax_metadata.twinx()
            ax_metadata_y2.set_ylim(0, len(fnames_meta))
            ax_metadata.set_yticks([])
            ax_metadata_y2.tick_params(length=0)
            ax_metadata_y2.set_yticks(np.arange(len(fnames_meta)) + 0.5)
            ax_metadata_y2.set_yticklabels(fnames_meta[::-1], va='center', size=self.args.flabel_size)
        else:
            ax_metadata.set_yticks([])

        ax_metadata.set_xticks([])

        im = ax_hm.imshow(self.numpy_matrix,
                          interpolation='nearest', aspect='auto',
                          extent=[0, self.nf, 0, self.ns],
                          cmap=cm,
                          vmin=self.args.minv,
                          vmax=self.args.maxv,
                          norm=norm_f(vmin=None, vmax=None))

        ax_hm.set_xticks(np.arange(len(list(snames))) + 0.5)
        if not self.args.no_slabels:
            labels = [s[0] for s in snames] if type(snames[0]) is tuple else snames
            snames_short = shrink_labels(labels, self.args.max_slabel_len)
            ax_hm.set_xticklabels(snames_short, rotation=90, va='top', ha='center', size=self.args.slabel_size)
        else:
            ax_hm.set_xticklabels([])
        ax_hm_y2.set_ylim([0, self.ns])
        ax_hm_y2.set_yticks(np.arange(len(fnames)) + 0.5)
        if not self.args.no_flabels:
            fnames_short = shrink_labels(fnames, self.args.max_flabel_len)
            ax_hm_y2.set_yticklabels(fnames_short, va='center', size=self.args.flabel_size)
        else:
            ax_hm_y2.set_yticklabels([])
        ax_hm.set_yticks([])
        remove_splines(ax_hm)
        ax_hm.tick_params(length=0)
        ax_hm_y2.tick_params(length=0)
        ax_cm = plt.subplot(gs[3], axisbg='r', frameon=False)
        fig.colorbar(im, ax_cm, orientation='horizontal',
                     spacing='proportional' if self.args.sqrt_scale else 'uniform')

        if not self.args.no_sclustering:
            ax_den_top = plt.subplot(gs[11], axisbg='r', frameon=False)
            # NOTE(review): _plot_dendrogram is private scipy API; its
            # signature may differ across scipy versions.
            sph._plot_dendrogram(self.sdendrogram['icoord'], self.sdendrogram['dcoord'], self.sdendrogram['ivl'],
                                 self.ns + 1, self.nf + 1, 1, 'top', no_labels=True,
                                 color_list=self.sdendrogram['color_list'])
            ymax = max([max(a) for a in self.sdendrogram['dcoord']])
            ax_den_top.set_ylim([0, ymax])
            make_ticklabels_invisible(ax_den_top)
        if not self.args.no_fclustering:
            ax_den_right = plt.subplot(gs[22], axisbg='b', frameon=False)
            sph._plot_dendrogram(self.fdendrogram['icoord'], self.fdendrogram['dcoord'], self.fdendrogram['ivl'],
                                 self.ns + 1, self.nf + 1, 1, 'right', no_labels=True,
                                 color_list=self.fdendrogram['color_list'])
            xmax = max([max(a) for a in self.fdendrogram['dcoord']])
            # Reversed limits so the tree's leaves face the heatmap.
            ax_den_right.set_xlim([xmax, 0])
            make_ticklabels_invisible(ax_den_right)

        if not self.args.out:
            plt.show()
        else:
            fig.savefig(self.args.out, bbox_inches='tight', dpi=self.args.dpi)
        # Done for both output modes, so --legend_file also works on screen.
        if maps:
            self.make_legend(maps, fnames_meta, self.args.legend_file)
672 | |
673 | |
674 | |
class ReadCmd:
    """Build the command-line parser for every pipeline stage and parse sys.argv."""

    def __init__(self):
        import argparse as ap

        p = ap.ArgumentParser(
            description="Hierarchically cluster the features and samples of a "
                        "data matrix and draw it as a heatmap with dendrograms")
        arg = p.add_argument

        arg('-i', '--inp', '--in', metavar='INPUT_FILE', type=str, nargs='?', default=sys.stdin,
            help="The input matrix")
        arg('-o', '--out', metavar='OUTPUT_FILE', type=str, nargs='?', default=None,
            help="The output image file [image on screen if not specified]")
        arg('--legend_file', metavar='LEGEND_FILE', type=str, nargs='?', default=None,
            help="The output file for the legend of the provided metadata")

        input_types = [DataMatrix.datatype, DistMatrix.datatype]
        arg('-t', '--input_type', metavar='INPUT_TYPE', type=str, choices=input_types,
            default='data_matrix',
            help="The input type can be a data matrix or distance matrix [default data_matrix]")

        # Each stage contributes its own argument group.
        DataMatrix.input_parameters(p)
        DistMatrix.input_parameters(p)
        HClustering.input_parameters(p)
        Heatmap.input_parameters(p)

        self.args = p.parse_args()

    def check_consistency(self):
        """Placeholder for cross-option validation; currently performs no checks."""
        pass

    def get_args(self):
        """Return the parsed argparse namespace."""
        return self.args
708 | |
if __name__ == '__main__':

    read = ReadCmd()
    read.check_consistency()
    args = read.get_args()

    if args.input_type == DataMatrix.datatype:
        dm = DataMatrix(args.inp, args)
        if args.out_table:
            dm.save_matrix(args.out_table)

        distm = DistMatrix(dm.get_numpy_matrix(), args=args)
        if not args.no_sclustering:
            distm.compute_s_dists()
        if not args.no_fclustering:
            distm.compute_f_dists()
    else:
        # Reading a precomputed distance matrix is not implemented: stop with
        # a clear message instead of failing later on undefined dm/distm.
        sys.exit("Input type '%s' is not supported yet; use '%s'."
                 % (args.input_type, DataMatrix.datatype))

    cl = HClustering(distm.get_s_dm(), distm.get_f_dm(), args=args)
    if not args.no_sclustering:
        cl.shcluster()
    if not args.no_fclustering:
        cl.fhcluster()

    hmp = dm.get_numpy_matrix()
    fnames = dm.get_fnames()
    snames = dm.get_snames()
    # Extra column-header levels (beyond the first) are metadata field names.
    fnames_meta = snames.names[1:]

    hmp = cl.get_reordered_matrix(hmp, sclustering=not args.no_sclustering,
                                  fclustering=not args.no_fclustering)
    if not args.no_sclustering:
        snames = cl.get_reordered_sample_labels(snames)
    if not args.no_fclustering:
        fnames = cl.get_reordered_feature_labels(fnames)
    else:
        fnames = fnames[::-1]

    hm = Heatmap(hmp, cl.sdendrogram, cl.fdendrogram, snames, fnames, fnames_meta, args=args)
    hm.draw()
753 | |
754 | |
755 | |
756 | |
757 | |
758 |