comparison kmersvm/scripts/kmersvm_train.py @ 0:66088269713e draft

Uploaded all files tracked by git
author test-svm
date Sun, 05 Aug 2012 15:32:16 -0400
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:66088269713e
1 #!/usr/bin/env python
2 """
3 kmersvm_train.py; train a support vector machine using shogun toolbox
4 Copyright (C) 2011 Dongwon Lee
5
6 This program is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program. If not, see <http://www.gnu.org/licenses/>.
18
19
20 """
21
22
23
24 import sys
25 import optparse
26 import random
27 import numpy
28 from math import log, exp
29
30 from libkmersvm import *
31 try:
32 from shogun.PreProc import SortWordString, SortUlongString
33 except ImportError:
34 from shogun.Preprocessor import SortWordString, SortUlongString
35 from shogun.Kernel import CommWordStringKernel, CommUlongStringKernel, \
36 CombinedKernel
37
38 from shogun.Features import StringWordFeatures, StringUlongFeatures, \
39 StringCharFeatures, CombinedFeatures, DNA, Labels
40 from shogun.Classifier import MSG_INFO, MSG_ERROR
41 try:
42 from shogun.Classifier import SVMLight
43 except ImportError:
44 from shogun.Classifier import LibSVM
45
"""
global variables
"""
# Shared k-mer lookup tables, populated by main() (and by the
# weighted-spectrum code paths) whenever the k-mer length is <= 8.
g_kmers = []  # g_kmers[kmerid] -> k-mer string
g_rcmap = []  # g_rcmap[kmerid] -> canonical id (min of id and reverse-complement id)
51
52
def kmerid2kmer(kmerid, kmerlen):
	"""convert integer kmerid to kmer string

	Arguments:
	kmerid -- integer, id of k-mer
	kmerlen -- integer, length of k-mer

	Return:
	kmer string
	"""

	nts = "ACGT"
	kmernts = []
	kmerid2 = kmerid

	# Peel off two bits (one base) at a time, least-significant base first.
	# divmod replaces the manual (kmerid2-ntid)/4 dance; range replaces the
	# Python-2-only xrange (identical iteration behavior).
	for _ in range(kmerlen):
		kmerid2, ntid = divmod(kmerid2, 4)
		kmernts.append(nts[ntid])

	# bases were produced last-to-first
	return ''.join(reversed(kmernts))
74
75
def kmer2kmerid(kmer, kmerlen):
	"""convert kmer string to integer kmerid

	Arguments:
	kmer -- string, k-mer over the alphabet "ACGT"
	kmerlen -- integer, length of k-mer (unused; kept for interface compatibility)

	Return:
	id of k-mer
	"""

	nt2id = {'A':0, 'C':1, 'G':2, 'T':3}

	# base-4 accumulation; avoids the Python-2-only builtin reduce()
	kmerid = 0
	for nt in kmer:
		kmerid = 4*kmerid + nt2id[nt]

	return kmerid
90
91
def get_rcmap(kmerid, kmerlen):
	"""mapping kmerid to its reverse complement k-mer on-the-fly

	Arguments:
	kmerid -- integer, id of k-mer
	kmerlen -- integer, length of k-mer

	Return:
	integer kmerid after mapping to its reverse complement
	"""

	# id -> k-mer string -> reverse complement -> id; the canonical id of
	# a k-mer/revcomp pair is the smaller of the two.
	rckmerid = kmer2kmerid(revcomp(kmerid2kmer(kmerid, kmerlen)), kmerlen)
	return min(rckmerid, kmerid)
112
113
def non_redundant_word_features(feats, kmerlen):
	"""convert the features from Shogun toolbox to non-redundant word features (handle reverse complements)
	Arguments:
	feats -- StringWordFeatures
	kmerlen -- integer, length of k-mer

	Return:
	StringWordFeatures after converting reverse complement k-mer ids
	"""

	remap = g_rcmap

	# collapse every k-mer id onto its canonical (revcomp-aware) id
	for vec_idx in range(feats.get_num_vectors()):
		mapped = [remap[int(kid)] for kid in feats.get_feature_vector(vec_idx)]
		feats.set_feature_vector(numpy.array(mapped, numpy.dtype('u2')), vec_idx)

	sorter = SortWordString()
	sorter.init(feats)
	# shogun renamed the preprocessor API between versions
	try:
		feats.add_preproc(sorter)
		feats.apply_preproc()
	except AttributeError:
		feats.add_preprocessor(sorter)
		feats.apply_preprocessor()

	return feats
141
142
def non_redundant_ulong_features(feats, kmerlen):
	"""convert the features from Shogun toolbox to non-redundant ulong features
	Arguments:
	feats -- StringUlongFeatures
	kmerlen -- integer, length of k-mer

	Return:
	StringUlongFeatures after converting reverse complement k-mer ids
	"""

	# for long k-mers there is no precomputed table, so map each id on the fly
	for vec_idx in range(feats.get_num_vectors()):
		mapped = [get_rcmap(int(kid), kmerlen)
			for kid in feats.get_feature_vector(vec_idx)]
		feats.set_feature_vector(numpy.array(mapped, numpy.dtype('u8')), vec_idx)

	sorter = SortUlongString()
	sorter.init(feats)
	# shogun renamed the preprocessor API between versions
	try:
		feats.add_preproc(sorter)
		feats.apply_preproc()
	except AttributeError:
		feats.add_preprocessor(sorter)
		feats.apply_preprocessor()

	return feats
169
170
def svm_learn(kernel, labels, options):
	"""train SVM using SVMLight or LibSVM

	Arguments:
	kernel -- kernel object from Shogun toolbox
	labels -- list of labels
	options -- object containing option data

	Return:
	trained svm object
	"""

	# SVMLight is only present in some shogun builds; if its import failed
	# at module load, the name is undefined and we fall back to LibSVM.
	try:
		svm=SVMLight(options.svmC, kernel, Labels(numpy.array(labels, dtype=numpy.double)))
	except NameError:
		svm=LibSVM(options.svmC, kernel, Labels(numpy.array(labels, dtype=numpy.double)))

	# verbose training output goes to stderr unless -q was given
	if options.quiet == False:
		svm.io.set_loglevel(MSG_INFO)
		svm.io.set_target_to_stderr()

	svm.set_epsilon(options.epsilon)
	svm.parallel.set_num_threads(1)
	# asymmetric regularization: scale C for the positive class by -w
	if options.weight != 1.0:
		svm.set_C(options.svmC, options.svmC*options.weight)
	svm.train()

	# restore quieter logging once training is done
	if options.quiet == False:
		svm.io.set_loglevel(MSG_ERROR)

	return svm
202
203
def _get_spectrum_features(seqs, kmerlen):
	"""generate spectrum features (internal)

	Arguments:
	seqs -- list of sequences
	kmerlen -- integer, length of k-mer

	Return:
	StringWord(Ulong)Features after treatment of redundant reverse complement k-mers
	"""

	char_feats = StringCharFeatures(seqs, DNA)

	# 16-bit word ids suffice up to 8-mers (4**8 values); beyond that,
	# 64-bit ulong ids are required
	if kmerlen <= 8:
		feats = StringWordFeatures(DNA)
		feats.obtain_from_char(char_feats, kmerlen-1, kmerlen, 0, False)
		return non_redundant_word_features(feats, kmerlen)

	feats = StringUlongFeatures(DNA)
	feats.obtain_from_char(char_feats, kmerlen-1, kmerlen, 0, False)
	return non_redundant_ulong_features(feats, kmerlen)
227
228
def get_spectrum_features(seqs, options):
	"""generate spectrum features (wrapper)

	Delegates to _get_spectrum_features with the configured k-mer length.
	"""
	kmerlen = options.kmerlen
	return _get_spectrum_features(seqs, kmerlen)
233
234
def get_weighted_spectrum_features(seqs, options):
	"""generate weighted spectrum features

	Builds one spectrum feature object per k in [kmerlen, kmerlen2].

	Arguments:
	seqs -- list of sequences
	options -- object with kmerlen and kmerlen2 attributes

	Return:
	list of StringWord(Ulong)Features, one per k-mer length
	"""
	global g_kmers
	global g_rcmap

	subfeats_list = []

	for k in range(options.kmerlen, options.kmerlen2+1):
		# word-feature path (k <= 8) relies on the precomputed rcmap table
		if k <= 8:
			g_kmers = generate_kmers(k)
			g_rcmap = generate_rcmap_table(k, g_kmers)

		# FIX: removed an unused StringCharFeatures(seqs, DNA) object that
		# was rebuilt on every iteration; _get_spectrum_features creates
		# its own char features internally.
		subfeats_list.append(_get_spectrum_features(seqs, k))

	return subfeats_list
253
254
def get_spectrum_kernel(feats, options):
	"""build spectrum kernel with non-redundant k-mer list (removing reverse complement)

	Arguments:
	feats -- feature object
	options -- object containing option data

	Return:
	CommWord(Ulong)StringKernel over feats
	"""
	# ulong kernel only needed for k-mers longer than 8
	if options.kmerlen > 8:
		return CommUlongStringKernel(feats, feats)
	return CommWordStringKernel(feats, feats)
269
270
def get_weighted_spectrum_kernel(subfeats_list, options):
	"""build weighted spectrum kernel with non-redundant k-mer list (removing reverse complement)

	Arguments:
	subfeats_list -- list of sub-feature objects
	options -- object containing option data

	Return:
	CombinedKernel of CommWord(Ulong)StringKernel, one per k-mer length
	"""
	feats = CombinedFeatures()
	for subfeats in subfeats_list:
		feats.append_feature_obj(subfeats)

	kernel = CombinedKernel()
	nsub = 0
	for k in range(options.kmerlen, options.kmerlen2+1):
		# constructor args (10, False) mirror the original call — presumably
		# cache size and a normalization flag; confirm against shogun API
		if k <= 8:
			kernel.append_kernel(CommWordStringKernel(10, False))
		else:
			kernel.append_kernel(CommUlongStringKernel(10, False))
		nsub += 1

	kernel.init(feats, feats)

	# give every subkernel an equal share of the combined kernel
	kernel.set_subkernel_weights(numpy.array([1/float(nsub)]*nsub, numpy.dtype('float64')))

	return kernel
305
306
def init_spectrum_kernel(kern, feats_lhs, feats_rhs):
	"""initialize spectrum kernel (wrapper function)

	Arguments:
	kern -- CommWord(Ulong)StringKernel from get_spectrum_kernel
	feats_lhs -- left-hand-side features (training set)
	feats_rhs -- right-hand-side features (e.g. test sequences)
	"""
	kern.init(feats_lhs, feats_rhs)
311
312
def init_weighted_spectrum_kernel(kern, subfeats_list_lhs, subfeats_list_rhs):
	"""initialize weighted spectrum kernel (wrapper function)

	Wraps both sub-feature lists in CombinedFeatures and initializes kern.
	"""
	lhs = CombinedFeatures()
	rhs = CombinedFeatures()

	# populate both combined feature objects the same way
	for combined, subfeats_list in ((lhs, subfeats_list_lhs), (rhs, subfeats_list_rhs)):
		for subfeats in subfeats_list:
			combined.append_feature_obj(subfeats)

	kern.init(lhs, rhs)
326
327
def get_sksvm_weights(svm, feats, options):
	"""calculate the SVM weight vector of spectrum kernel

	w = sum_i alpha_i * x_i / ||x_i||, where x_i is the k-mer count
	vector of support vector i.

	Arguments:
	svm -- trained svm object (provides get_alphas/get_support_vectors)
	feats -- feature object (provides get_feature_vector)
	options -- object with kmerlen attribute

	Return:
	numpy array of per-kmerid SVM weights (length 4**kmerlen)
	"""
	kmerlen = options.kmerlen
	alphas = svm.get_alphas()
	support_vector_ids = svm.get_support_vectors()

	# hoisted loop invariant: was recomputed as 2**(2*kmerlen) per SV
	nkmers = 4**kmerlen
	w = numpy.zeros(nkmers, numpy.double)

	for i in range(len(alphas)):
		x = numpy.zeros(nkmers, numpy.double)
		for kmerid in feats.get_feature_vector(int(support_vector_ids[i])):
			x[int(kmerid)] += 1
		# L2-normalize each support vector's count profile before weighting
		w += (alphas[i]*x/numpy.sqrt(numpy.sum(x**2)))

	return w
345
346
def get_wsksvm_weights(svm, subfeats_list, options):
	"""calculate the SVM weight vector of weighted spectrum kernel

	Computes one per-k weight vector for each k in [kmerlen, kmerlen2],
	each scaled by 1/(number of k-mer lengths) to match the equal
	subkernel weights of the combined kernel.

	Arguments:
	svm -- trained svm object (provides get_alphas/get_support_vectors)
	subfeats_list -- list of sub-feature objects, one per k
	options -- object with kmerlen and kmerlen2 attributes

	Return:
	list of numpy weight arrays, one per k-mer length
	"""
	alphas = svm.get_alphas()
	support_vector_ids = svm.get_support_vectors()
	kmerlens = range(options.kmerlen, options.kmerlen2+1)
	nk = len(kmerlens)

	weights = []
	for idx, k in enumerate(kmerlens):
		subfeats = subfeats_list[idx]

		# hoisted loop invariant: was recomputed as 2**(2*k) per SV
		nkmers = 4**k
		w = numpy.zeros(nkmers, numpy.double)

		for i in range(len(alphas)):
			x = numpy.zeros(nkmers, numpy.double)
			for kmerid in subfeats.get_feature_vector(int(support_vector_ids[i])):
				x[int(kmerid)] += 1
			w += (alphas[i]*x/numpy.sqrt(numpy.sum(x**2)))

		# each subkernel carries weight 1/nk in the combined kernel
		w /= nk
		weights.append(w)

	return weights
374
375
def save_header(f, bias, A, B, options):
	"""write the parameter header of the weights output file

	Arguments:
	f -- writable file-like object
	bias, A, B -- SVM bias and sigmoid parameters
	options -- object containing option data
	"""
	lines = ["#parameters:",
		"#kernel=" + str(options.ktype),
		"#kmerlen=" + str(options.kmerlen)]
	# kmerlen2 only applies to the weighted spectrum kernel
	if options.ktype == 2:
		lines.append("#kmerlen2=" + str(options.kmerlen2))
	lines.append("#bias=" + str(bias))
	lines.append("#A=" + str(A))
	lines.append("#B=" + str(B))
	lines.append("#NOTE: k-mers with large negative weights are also important. They can be found at the bottom of the list.")
	lines.append("#k-mer\trevcomp\tSVM-weight")
	f.write("\n".join(lines) + "\n")
387
388
def save_sksvm_weights(w, bias, A, B, options):
	"""save the SVM weight vector from spectrum kernel

	Writes one line per canonical k-mer (each reverse-complement pair is
	reported once) to <outputname>_weights.out.

	Arguments:
	w -- numpy array of per-kmerid weights
	bias, A, B -- SVM bias and sigmoid parameters (written in the header)
	options -- object containing option data
	"""
	output = options.outputname + "_weights.out"
	kmerlen = options.kmerlen

	global g_kmers
	global g_rcmap

	# FIX: use a context manager so the file is closed even if an
	# exception occurs while writing
	with open(output, 'w') as f:
		save_header(f, bias, A, B, options)

		if options.sort:
			w_sorted = sorted(zip(range(len(w)), w), key=lambda x: x[1], reverse=True)
		else:
			w_sorted = zip(range(len(w)), w)

		if kmerlen <= 8:
			# short k: use the precomputed kmer/rcmap tables
			for i in map(lambda x: x[0], w_sorted):
				if i == g_rcmap[i]:
					f.write('\t'.join( [g_kmers[i], revcomp(g_kmers[i]), str(w[i])] ) + '\n')
		else:
			# long k: derive the k-mer string and rc mapping on the fly
			for i in map(lambda x: x[0], w_sorted):
				if i == get_rcmap(i, kmerlen):
					kmer = kmerid2kmer(i, kmerlen)
					f.write('\t'.join( [kmer, revcomp(kmer), str(w[i])] ) + '\n')
417
418
def save_wsksvm_weights(w, bias, A, B, options):
	"""save the SVM weight vector from weighted spectrum kernel

	w is a list of per-k weight vectors (k in [kmerlen, kmerlen2]); one
	line is written per canonical k-mer of each length.

	Arguments:
	w -- list of numpy weight arrays
	bias, A, B -- SVM bias and sigmoid parameters (written in the header)
	options -- object containing option data
	"""
	output = options.outputname + "_weights.out"
	kmerlen = options.kmerlen
	kmerlen2 = options.kmerlen2

	global g_kmers
	global g_rcmap

	# FIX: context manager guarantees the file is closed on exceptions
	with open(output, 'w') as f:
		save_header(f, bias, A, B, options)

		kmerlens = range(kmerlen, kmerlen2+1)
		for idx in range(len(kmerlens)):
			k = kmerlens[idx]
			subw = w[idx]

			if options.sort:
				subw_sorted = sorted(zip(range(len(subw)), subw), key=lambda x: x[1], reverse=True)
			else:
				subw_sorted = zip(range(len(subw)), subw)

			if k <= 8:
				g_kmers = generate_kmers(k)
				g_rcmap = generate_rcmap_table(k, g_kmers)
				for i in map(lambda x: x[0], subw_sorted):
					if i == g_rcmap[i]:
						f.write('\t'.join( [g_kmers[i], revcomp(g_kmers[i]), str(subw[i])] ) + "\n")
			else:
				for i in map(lambda x: x[0], subw_sorted):
					if i == get_rcmap(i, k):
						kmer = kmerid2kmer(i, k)
						# FIX: was 'kmers' (undefined) -> NameError on the
						# k > 8 path; should be the 'kmer' computed above
						f.write('\t'.join( [kmer, revcomp(kmer), str(subw[i])] ) + "\n")
455
456
def save_predictions(output, preds, cvs):
	"""save prediction

	Arguments:
	output -- string, output file name
	preds -- list of (index, seq_id, SVM score, label) tuples
	cvs -- list, cross-validation fold id per sequence

	Return: None
	"""
	# FIX: context manager closes the file even if a write raises
	with open(output, 'w') as f:
		f.write('\t'.join(["#seq_id", "SVM score", "label", "NCV"]) + "\n")
		for i in range(len(preds)):
			f.write('\t'.join([preds[i][1], str(preds[i][2]), str(preds[i][3]), str(cvs[i])]) + "\n")
465
466
def generate_cv_list(ncv, n1, n2):
	"""generate the N-fold cross validation list

	Positives (indices 0..n1-1) and negatives (n1..n1+n2-1) are shuffled
	independently, then fold ids 0..ncv-1 are dealt out round-robin, so
	every fold receives a balanced share of both classes.

	Arguments:
	ncv -- integer, number of cross-validation folds
	n1 -- integer, number of positives
	n2 -- integer, number of negatives

	Return:
	a list of N-fold cross validation fold ids, one per sequence
	"""

	# FIX: materialize as lists -- random.shuffle requires a mutable
	# sequence (shuffling a bare range object fails on Python 3)
	shuffled_idx_list1 = list(range(n1))
	shuffled_idx_list2 = list(range(n1, n1+n2))

	random.shuffle(shuffled_idx_list1)
	random.shuffle(shuffled_idx_list2)

	shuffled_idx_list = shuffled_idx_list1 + shuffled_idx_list2

	cv = [0] * (n1+n2)
	icv = 0
	for idx in range(n1+n2):
		# deal the next fold id to the next shuffled position
		cv[shuffled_idx_list[idx]] = icv
		icv += 1
		if icv == ncv:
			icv = 0

	return cv
499
500
def split_cv_list(cvlist, icv, data):
	"""split data into training and test based on cross-validation list

	Arguments:
	cvlist -- list, cross-validation fold id per element
	icv -- integer, cross-validation fold held out as the test set
	data -- list, data set to be split

	Return:
	a list of training set and a list of test set
	"""

	tr_data = []
	te_data = []

	# pair each element with its fold id instead of indexing by position
	for fold, item in zip(cvlist, data):
		if fold == icv:
			te_data.append(item)
		else:
			tr_data.append(item)

	return tr_data, te_data
523
524
def LMAI(svms, labels, prior0, prior1):
	"""fitting svms to sigmoid function (improved version introduced by Lin 2003)

	Fits Platt-style posterior P(y=1|s) = 1/(1+exp(A*s+B)) to the SVM
	scores by minimizing the cross-entropy, using Newton's method with a
	backtracking line search (Lin, Lin & Weng 2003).

	Arguments:
	svms -- list of svm scores
	labels -- list of labels
	prior0 -- prior of negative set
	prior1 -- prior of positive set

	Return:
	A, B parameter of 1/(1+exp(A*SVM+B))
	"""

	#parameter settings
	maxiter = 100    # maximal number of Newton iterations
	minstep = 1e-10  # minimal step size in the line search
	sigma = 1e-3     # Hessian regularizer: H' = H + sigma*I

	# soft targets instead of hard 0/1 labels (Platt's correction)
	hiTarget = (prior1+1.0)/float(prior1+2.0)
	loTarget = 1/float(prior0+2.0)

	t = [0]*len(labels)
	for i in xrange(len(labels)):
		if labels[i] == 1:
			t[i] = hiTarget
		else:
			t[i] = loTarget

	# initial point and initial objective value
	A = 0.0
	B = log((prior0+1.0)/float(prior1+1.0))
	fval = 0.0

	# cross-entropy, written to avoid overflow for large |fApB|
	for i in xrange(len(labels)):
		fApB = svms[i]*A+B
		if fApB >= 0:
			fval += (t[i]*fApB+log(1+exp(-fApB)))
		else:
			fval += ((t[i]-1)*fApB+log(1+exp(fApB)))


	for it in xrange(maxiter):
		#print "iteration:", it
		#Update Gradient and Hessian (use H'= H + sigma I)
		h11 = sigma
		h22 = sigma
		h21 = 0.0
		g1 = 0.0
		g2 = 0.0

		for i in xrange(len(labels)):
			fApB = svms[i]*A+B
			# p = P(y=1), q = 1-p; branch keeps exp() argument negative
			if fApB >= 0:
				p = exp(-fApB) / float(1.0+exp(-fApB))
				q = 1.0 / float(1.0 + exp(-fApB))
			else:
				p = 1.0 / float(1.0 + exp(fApB))
				q = exp(fApB) / float(1.0+exp(fApB))
			d2 = p*q
			h11 += (svms[i]*svms[i]*d2)
			h22 += d2
			h21 += (svms[i]*d2)
			d1 = t[i]-p
			g1 += (svms[i]*d1)
			g2 += d1

		#Stopping criteria
		if (abs(g1)<1e-5) and (abs(g2)<1e-5):
			break

		# Newton direction: -H^{-1} g, via the closed-form 2x2 inverse
		det = h11*h22-h21*h21
		dA = -(h22*g1-h21*g2)/float(det)
		dB = -(-h21*g1+h11*g2)/float(det)
		gd = g1*dA+g2*dB
		stepsize=1
		# backtracking line search with sufficient-decrease condition
		while stepsize >= minstep:
			newA = A+stepsize*dA
			newB = B+stepsize*dB
			newf = 0.0

			for i in xrange(len(labels)):
				fApB = svms[i]*newA+newB
				if fApB >= 0:
					newf += (t[i]*fApB + log(1+exp(-fApB)))
				else:
					newf += ((t[i]-1)*fApB + log(1+exp(fApB)))

			if newf < (fval+0.0001*stepsize*gd):
				A=newA
				B=newB
				fval=newf
				break
			else:
				stepsize=stepsize/float(2.0)

		#Line search fails
		if stepsize < minstep:
			#print "Line search fails"
			break

	#if it >= maxiter:
	#	print "Reaching maximum iterations"

	return A, B
628
629
def wsksvm_classify(seqs, svm, kern, feats, options):
	"""classify sequences with the weighted spectrum kernel SVM

	Builds test features for seqs, re-initializes the kernel against the
	training features, and returns the SVM outputs as a list.
	"""
	test_feats = get_weighted_spectrum_features(seqs, options)
	init_weighted_spectrum_kernel(kern, feats, test_feats)
	return svm.apply().get_labels().tolist()
635
636
def score_seq(s, svmw, kmerlen):
	"""calculate SVM score of given sequence using single set of svm weights

	Arguments:
	s -- string, DNA sequence
	svmw -- numpy array, SVM weights
	kmerlen -- integer, length of k-mer of SVM weight

	Return:
	SVM score (dot product of svmw with the L2-normalized k-mer counts)
	"""

	counts = [0]*(2**(2*kmerlen))

	# slide a window of length kmerlen over s, counting canonical
	# (reverse-complement-collapsed) k-mer ids via the global rcmap table
	for start in range(len(s)-kmerlen+1):
		counts[ g_rcmap[kmer2kmerid(s[start:start+kmerlen], kmerlen)] ] += 1

	counts = numpy.array(counts, numpy.double)
	return numpy.dot(svmw, counts)/numpy.sqrt(numpy.sum(counts**2))
660
661
def sksvm_classify(seqs, svm, kern, feats, options):
	"""classify the given sequences
	"""
	if options.kmerlen > 8:
		feats_te = get_spectrum_features(seqs, options)
		init_spectrum_kernel(kern, feats, feats_te)
		return svm.apply().get_labels().tolist()

	# short k-mers: scoring through the explicit weight vector is much
	# faster when the length of kmer is short and SVs are many
	svmw = get_sksvm_weights(svm, feats, options)
	bias = svm.get_bias()
	return [score_seq(s, svmw, options.kmerlen)+bias for s in seqs]
674
675
def main(argv = sys.argv):
	"""parse options, train the SVM, optionally cross-validate, save weights"""
	usage = "Usage: %prog [options] POSITIVE_SEQ NEGATIVE_SEQ"
	desc = "1. take two files(FASTA format) as input, 2. train an SVM and store the trained SVM weights"
	parser = optparse.OptionParser(usage=usage, description=desc)
	parser.add_option("-t", dest="ktype", type="int", default=1, \
		help="set the type of kernel, 1:Spectrum, 2:Weighted Spectrums (default=1.Spectrum)")

	parser.add_option("-C", dest="svmC", type="float", default=1, \
		help="set the regularization parameter svmC (default=1)")

	parser.add_option("-e", dest="epsilon", type="float", default=0.00001, \
		help="set the precision parameter epsilon (default=0.00001)")

	parser.add_option("-w", dest="weight", type="float", default=0.0, \
		help="set the weight for positive set (default=auto, 1+log(N/P))")

	parser.add_option("-k", dest="kmerlen", type="int",default=6, \
		help="set the (min) length of k-mer for (weighted) spectrum kernel (default = 6)")

	parser.add_option("-K", dest="kmerlen2", type="int",default=8, \
		help="set the max length of k-mer for weighted spectrum kernel (default = 8)")

	parser.add_option("-n", dest="outputname", default="kmersvm_output", \
		help="set the name of output files (default=kmersvm_output)")

	parser.add_option("-v", dest="ncv", type="int", default=0, \
		help="if set, it will perform N-fold cross-validation and generate a prediction file (default = 0)")

	parser.add_option("-p", dest="posteriorp", default=False, action="store_true", \
		help="estimate parameters for posterior probability with N-CV. this option requires -v option to be set (default=false)")

	parser.add_option("-r", dest="rseed", type="int", default=1, \
		help="set the random number seed for cross-validation (-p option) (default=1)")

	parser.add_option("-q", dest="quiet", default=False, action="store_true", \
		help="supress messages (default=false)")

	parser.add_option("-s", dest="sort", default=False, action="store_true", \
		help="sort the kmers by absolute values of SVM weights (default=false)")

	# index 0 unused; indices line up with the -t option values
	ktype_str = ["", "Spectrum", "Weighted Spectrums"]

	(options, args) = parser.parse_args()

	if len(args) == 0:
		parser.print_help()
		sys.exit(0)

	if len(args) != 2:
		parser.error("incorrect number of arguments")
		parser.print_help()
		sys.exit(0)

	# -p needs CV predictions to fit the sigmoid, so -v must be set too
	if options.posteriorp and options.ncv == 0:
		parser.error("posterior probability estimation requires N-fold CV process (-v option should be set)")
		parser.print_help()
		sys.exit(0)

	random.seed(options.rseed)

	"""
	set global variable
	"""
	# short-k spectrum kernel uses the precomputed k-mer/rcmap tables
	if (options.ktype == 1) and (options.kmerlen <= 8):
		global g_kmers
		global g_rcmap

		g_kmers = generate_kmers(options.kmerlen)
		g_rcmap = generate_rcmap_table(options.kmerlen, g_kmers)

	posf = args[0]
	negf = args[1]

	# read_fastafile comes from libkmersvm; presumably returns
	# (sequences, sequence ids) -- confirm against that module
	seqs_pos, sids_pos = read_fastafile(posf)
	seqs_neg, sids_neg = read_fastafile(negf)
	npos = len(seqs_pos)
	nneg = len(seqs_neg)
	seqs = seqs_pos + seqs_neg
	sids = sids_pos + sids_neg

	# auto positive-class weight when -w was left at its 0.0 default
	# NOTE(review): under Python 2, nneg/npos is integer division, so this
	# computes 1+log(floor(N/P)) -- confirm whether float division was intended
	if options.weight == 0:
		options.weight = 1 + log(nneg/npos)

	if options.quiet == False:
		sys.stderr.write('SVM parameters:\n')
		sys.stderr.write(' kernel-type: ' + str(options.ktype) + "." + ktype_str[options.ktype] + '\n')
		sys.stderr.write(' svm-C: ' + str(options.svmC) + '\n')
		sys.stderr.write(' epsilon: ' + str(options.epsilon) + '\n')
		sys.stderr.write(' weight: ' + str(options.weight) + '\n')
		sys.stderr.write('\n')

		sys.stderr.write('Other options:\n')
		sys.stderr.write(' kmerlen: ' + str(options.kmerlen) + '\n')
		if options.ktype == 2:
			sys.stderr.write(' kmerlen2: ' + str(options.kmerlen2) + '\n')
		sys.stderr.write(' outputname: ' + options.outputname + '\n')
		sys.stderr.write(' posteriorp: ' + str(options.posteriorp) + '\n')
		if options.ncv > 0:
			sys.stderr.write(' ncv: ' + str(options.ncv) + '\n')
			sys.stderr.write(' rseed: ' + str(options.rseed) + '\n')
		sys.stderr.write(' sorted-weight: ' + str(options.sort) + '\n')

		sys.stderr.write('\n')

		sys.stderr.write('Input args:\n')
		sys.stderr.write(' positive sequence file: ' + posf + '\n')
		sys.stderr.write(' negative sequence file: ' + negf + '\n')
		sys.stderr.write('\n')

		sys.stderr.write('numer of total positive seqs: ' + str(npos) + '\n')
		sys.stderr.write('numer of total negative seqs: ' + str(nneg) + '\n')
		sys.stderr.write('\n')

	#generate labels
	labels = [1]*npos + [-1]*nneg

	# select the kernel-specific function family
	if options.ktype == 1:
		get_features = get_spectrum_features
		get_kernel = get_spectrum_kernel
		get_weights = get_sksvm_weights
		save_weights = save_sksvm_weights
		svm_classify = sksvm_classify
	elif options.ktype == 2:
		get_features = get_weighted_spectrum_features
		get_kernel = get_weighted_spectrum_kernel
		get_weights = get_wsksvm_weights
		save_weights = save_wsksvm_weights
		svm_classify = wsksvm_classify
	else:
		sys.stderr.write('..unknown kernel..\n')
		sys.exit(0)

	# A, B stay 0 unless estimated via -p below
	A = B = 0
	if options.ncv > 0:
		if options.quiet == False:
			sys.stderr.write('..Cross-validation\n')

		cvlist = generate_cv_list(options.ncv, npos, nneg)
		labels_cv = []
		preds_cv = []
		sids_cv = []
		indices_cv = []
		for icv in xrange(options.ncv):
			#split data into training and test set
			seqs_tr, seqs_te = split_cv_list(cvlist, icv, seqs)
			labs_tr, labs_te = split_cv_list(cvlist, icv, labels)
			sids_tr, sids_te = split_cv_list(cvlist, icv, sids)
			indices_tr, indices_te = split_cv_list(cvlist, icv, range(len(seqs)))

			#train SVM
			feats_tr = get_features(seqs_tr, options)
			kernel_tr = get_kernel(feats_tr, options)
			svm_cv = svm_learn(kernel_tr, labs_tr, options)

			preds_cv = preds_cv + svm_classify(seqs_te, svm_cv, kernel_tr, feats_tr, options)

			labels_cv = labels_cv + labs_te
			sids_cv = sids_cv + sids_te
			indices_cv = indices_cv + indices_te

		# restore original sequence order before writing predictions so
		# cvlist (indexed by original position) lines up
		output_cvpred = options.outputname + "_cvpred.out"
		prediction_results = sorted(zip(indices_cv, sids_cv, preds_cv, labels_cv), key=lambda p: p[0])
		save_predictions(output_cvpred, prediction_results, cvlist)

		if options.posteriorp:
			# fit sigmoid posterior on the CV predictions
			A, B = LMAI(preds_cv, labels_cv, labels_cv.count(-1), labels_cv.count(1))

			if options.quiet == False:
				sys.stderr.write('Estimated Parameters:\n')
				sys.stderr.write(' A: ' + str(A) + '\n')
				sys.stderr.write(' B: ' + str(B) + '\n')

	if options.quiet == False:
		sys.stderr.write('..SVM weights\n')

	# final training on the full data set, then persist the weights
	feats = get_features(seqs, options)
	kernel = get_kernel(feats, options)
	svm = svm_learn(kernel, labels, options)
	w = get_weights(svm, feats, options)
	b = svm.get_bias()

	save_weights(w, b, A, B, options)
858
if __name__ == '__main__':
	main()