comparison kmersvm/scripts/kmersvm_train.py @ 0:66088269713e draft
Uploaded all files tracked by git
| author | test-svm |
|---|---|
| date | Sun, 05 Aug 2012 15:32:16 -0400 |
| parents | |
| children | |
| -1:000000000000 | 0:66088269713e |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 """ | |
| 3 kmersvm_train.py; train a support vector machine using shogun toolbox | |
| 4 Copyright (C) 2011 Dongwon Lee | |
| 5 | |
| 6 This program is free software: you can redistribute it and/or modify | |
| 7 it under the terms of the GNU General Public License as published by | |
| 8 the Free Software Foundation, either version 3 of the License, or | |
| 9 (at your option) any later version. | |
| 10 | |
| 11 This program is distributed in the hope that it will be useful, | |
| 12 but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 14 GNU General Public License for more details. | |
| 15 | |
| 16 You should have received a copy of the GNU General Public License | |
| 17 along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| 18 | |
| 19 | |
| 20 """ | |
| 21 | |
| 22 | |
| 23 | |
| 24 import sys | |
| 25 import optparse | |
| 26 import random | |
| 27 import numpy | |
| 28 from math import log, exp | |
| 29 | |
| 30 from libkmersvm import * | |
| 31 try: | |
| 32 from shogun.PreProc import SortWordString, SortUlongString | |
| 33 except ImportError: | |
| 34 from shogun.Preprocessor import SortWordString, SortUlongString | |
| 35 from shogun.Kernel import CommWordStringKernel, CommUlongStringKernel, \ | |
| 36 CombinedKernel | |
| 37 | |
| 38 from shogun.Features import StringWordFeatures, StringUlongFeatures, \ | |
| 39 StringCharFeatures, CombinedFeatures, DNA, Labels | |
| 40 from shogun.Classifier import MSG_INFO, MSG_ERROR | |
| 41 try: | |
| 42 from shogun.Classifier import SVMLight | |
| 43 except ImportError: | |
| 44 from shogun.Classifier import LibSVM | |
| 45 | |
| 46 """ | |
| 47 global variables | |
| 48 """ | |
| 49 g_kmers = [] | |
| 50 g_rcmap = [] | |
| 51 | |
| 52 | |
| 53 def kmerid2kmer(kmerid, kmerlen): | |
| 54 """convert integer kmerid to kmer string | |
| 55 | |
| 56 Arguments: | |
| 57 kmerid -- integer, id of k-mer | |
| 58 kmerlen -- integer, length of k-mer | |
| 59 | |
| 60 Return: | |
| 61 kmer string | |
| 62 """ | |
| 63 | |
| 64 nts = "ACGT" | |
| 65 kmernts = [] | |
| 66 kmerid2 = kmerid | |
| 67 | |
| 68 for i in xrange(kmerlen): | |
| 69 ntid = kmerid2 % 4 | |
| 70 kmernts.append(nts[ntid]) | |
| 71 kmerid2 = int((kmerid2-ntid)/4) | |
| 72 | |
| 73 return ''.join(reversed(kmernts)) | |
| 74 | |
| 75 | |
| 76 def kmer2kmerid(kmer, kmerlen): | |
| 77 """convert kmer string to integer kmerid | |
| 78 | |
| 79 Arguments: | |
| 80 kmer -- string, k-mer | |
| 81 kmerlen -- integer, length of k-mer | |
| 82 | |
| 83 Return: | |
| 84 id of k-mer | |
| 85 """ | |
| 86 | |
| 87 nt2id = {'A':0, 'C':1, 'G':2, 'T':3} | |
| 88 | |
| 89 return reduce(lambda x, y: (4*x+y), [nt2id[x] for x in kmer]) | |
| 90 | |
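
The two helpers above encode a k-mer as a base-4 integer (A=0, C=1, G=2, T=3), reading the string with the leftmost base as the most significant digit. A self-contained sketch of the same round trip, using the hypothetical helper names `encode`/`decode` rather than the functions in this file:

```python
# Illustrative sketch of the base-4 k-mer encoding used by
# kmer2kmerid()/kmerid2kmer(); helper names are hypothetical.
NTS = "ACGT"
NT2ID = {"A": 0, "C": 1, "G": 2, "T": 3}

def encode(kmer):
    kmerid = 0
    for nt in kmer:                 # most significant digit first
        kmerid = 4 * kmerid + NT2ID[nt]
    return kmerid

def decode(kmerid, kmerlen):
    nts = []
    for _ in range(kmerlen):        # peel off the least significant digit
        nts.append(NTS[kmerid % 4])
        kmerid //= 4
    return "".join(reversed(nts))

# "ACGT" -> 0*64 + 1*16 + 2*4 + 3 = 27, and back again
assert encode("ACGT") == 27
assert decode(27, 4) == "ACGT"
```
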
| 91 | |
| 92 def get_rcmap(kmerid, kmerlen): | |
| 93 """mapping kmerid to its reverse complement k-mer on-the-fly | |
| 94 | |
| 95 Arguments: | |
| 96 kmerid -- integer, id of k-mer | |
| 97 kmerlen -- integer, length of k-mer | |
| 98 | |
| 99 Return: | |
| 100 integer, the smaller of kmerid and its reverse-complement kmerid (canonical id) | |
| 101 """ | |
| 102 | |
| 103 #1. get kmer from kmerid | |
| 104 #2. get reverse complement kmer | |
| 105 #3. get kmerid from revcomp kmer | |
| 106 rckmerid = kmer2kmerid(revcomp(kmerid2kmer(kmerid, kmerlen)), kmerlen) | |
| 107 | |
| 108 if rckmerid < kmerid: | |
| 109 return rckmerid | |
| 110 | |
| 111 return kmerid | |
| 112 | |
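
`get_rcmap` collapses a k-mer id and the id of its reverse complement onto the smaller of the two, so a k-mer and its reverse complement are counted as the same feature. A minimal illustration of that convention; the local `revcomp` below is only a stand-in for the one imported from libkmersvm:

```python
# Sketch of reverse-complement canonicalization (illustrative only).
COMP = {"A": "T", "C": "G", "G": "C", "T": "A"}
NT2ID = {"A": 0, "C": 1, "G": 2, "T": 3}

def revcomp(kmer):
    return "".join(COMP[nt] for nt in reversed(kmer))

def kmer_id(kmer):
    kid = 0
    for nt in kmer:
        kid = 4 * kid + NT2ID[nt]
    return kid

def canonical_id(kmer):
    return min(kmer_id(kmer), kmer_id(revcomp(kmer)))

# "AAC" (id 1) and its reverse complement "GTT" (id 47) share the canonical id 1
assert canonical_id("AAC") == canonical_id("GTT") == 1
```
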
| 113 | |
| 114 def non_redundant_word_features(feats, kmerlen): | |
| 115 """convert the features from Shogun toolbox to non-redundant word features (handle reverse complements) | |
| 116 Arguments: | |
| 117 feats -- StringWordFeatures | |
| 118 kmerlen -- integer, length of k-mer | |
| 119 | |
| 120 Return: | |
| 121 StringWordFeatures after converting reverse complement k-mer ids | |
| 122 """ | |
| 123 | |
| 124 rcmap = g_rcmap | |
| 125 | |
| 126 for i in xrange(feats.get_num_vectors()): | |
| 127 nf = [rcmap[int(kmerid)] for kmerid in feats.get_feature_vector(i)] | |
| 128 | |
| 129 feats.set_feature_vector(numpy.array(nf, numpy.dtype('u2')), i) | |
| 130 | |
| 131 preproc = SortWordString() | |
| 132 preproc.init(feats) | |
| 133 try: | |
| 134 feats.add_preproc(preproc) | |
| 135 feats.apply_preproc() | |
| 136 except AttributeError: | |
| 137 feats.add_preprocessor(preproc) | |
| 138 feats.apply_preprocessor() | |
| 139 | |
| 140 return feats | |
| 141 | |
| 142 | |
| 143 def non_redundant_ulong_features(feats, kmerlen): | |
| 144 """convert the features from Shogun toolbox to non-redundant ulong features | |
| 145 Arguments: | |
| 146 feats -- StringUlongFeatures | |
| 147 kmerlen -- integer, length of k-mer | |
| 148 | |
| 149 Return: | |
| 150 StringUlongFeatures after converting reverse complement k-mer ids | |
| 151 """ | |
| 152 | |
| 153 for i in xrange(feats.get_num_vectors()): | |
| 154 nf = [get_rcmap(int(kmerid), kmerlen) \ | |
| 155 for kmerid in feats.get_feature_vector(i)] | |
| 156 | |
| 157 feats.set_feature_vector(numpy.array(nf, numpy.dtype('u8')), i) | |
| 158 | |
| 159 preproc = SortUlongString() | |
| 160 preproc.init(feats) | |
| 161 try: | |
| 162 feats.add_preproc(preproc) | |
| 163 feats.apply_preproc() | |
| 164 except AttributeError: | |
| 165 feats.add_preprocessor(preproc) | |
| 166 feats.apply_preprocessor() | |
| 167 | |
| 168 return feats | |
| 169 | |
| 170 | |
| 171 def svm_learn(kernel, labels, options): | |
| 172 """train SVM using SVMLight or LibSVM | |
| 173 | |
| 174 Arguments: | |
| 175 kernel -- kernel object from Shogun toolbox | |
| 176 labels -- list of labels | |
| 177 options -- object containing option data | |
| 178 | |
| 179 Return: | |
| 180 trained svm object | |
| 181 """ | |
| 182 | |
| 183 try: | |
| 184 svm=SVMLight(options.svmC, kernel, Labels(numpy.array(labels, dtype=numpy.double))) | |
| 185 except NameError: | |
| 186 svm=LibSVM(options.svmC, kernel, Labels(numpy.array(labels, dtype=numpy.double))) | |
| 187 | |
| 188 if options.quiet == False: | |
| 189 svm.io.set_loglevel(MSG_INFO) | |
| 190 svm.io.set_target_to_stderr() | |
| 191 | |
| 192 svm.set_epsilon(options.epsilon) | |
| 193 svm.parallel.set_num_threads(1) | |
| 194 if options.weight != 1.0: | |
| 195 svm.set_C(options.svmC, options.svmC*options.weight) | |
| 196 svm.train() | |
| 197 | |
| 198 if options.quiet == False: | |
| 199 svm.io.set_loglevel(MSG_ERROR) | |
| 200 | |
| 201 return svm | |
| 202 | |
| 203 | |
| 204 def _get_spectrum_features(seqs, kmerlen): | |
| 205 """generate spectrum features (internal) | |
| 206 | |
| 207 Arguments: | |
| 208 seqs -- list of sequences | |
| 209 kmerlen -- integer, length of k-mer | |
| 210 | |
| 211 Return: | |
| 212 StringWord(Ulong)Features after treatment of redundant reverse complement k-mers | |
| 213 """ | |
| 214 | |
| 215 char_feats = StringCharFeatures(seqs, DNA) | |
| 216 | |
| 217 if kmerlen <= 8: | |
| 218 string_features = StringWordFeatures | |
| 219 non_redundant_features = non_redundant_word_features | |
| 220 else: | |
| 221 string_features = StringUlongFeatures | |
| 222 non_redundant_features = non_redundant_ulong_features | |
| 223 | |
| 224 feats = string_features(DNA) | |
| 225 feats.obtain_from_char(char_feats, kmerlen-1, kmerlen, 0, False) | |
| 226 return non_redundant_features(feats, kmerlen) | |
| 227 | |
| 228 | |
| 229 def get_spectrum_features(seqs, options): | |
| 230 """generate spectrum features (wrapper) | |
| 231 """ | |
| 232 return _get_spectrum_features(seqs, options.kmerlen) | |
| 233 | |
| 234 | |
| 235 def get_weighted_spectrum_features(seqs, options): | |
| 236 """generate weighted spectrum features | |
| 237 """ | |
| 238 global g_kmers | |
| 239 global g_rcmap | |
| 240 | |
| 241 subfeats_list = [] | |
| 242 | |
| 243 for k in xrange(options.kmerlen, options.kmerlen2+1): | |
| 244 char_feats = StringCharFeatures(seqs, DNA) | |
| 245 if k <= 8: | |
| 246 g_kmers = generate_kmers(k) | |
| 247 g_rcmap = generate_rcmap_table(k, g_kmers) | |
| 248 | |
| 249 subfeats = _get_spectrum_features(seqs, k) | |
| 250 subfeats_list.append(subfeats) | |
| 251 | |
| 252 return subfeats_list | |
| 253 | |
| 254 | |
| 255 def get_spectrum_kernel(feats, options): | |
| 256 """build spectrum kernel with non-redundant k-mer list (removing reverse complement) | |
| 257 | |
| 258 Arguments: | |
| 259 feats -- feature object | |
| 260 options -- object containing option data | |
| 261 | |
| 262 Return: | |
| 263 StringWord(Ulong)Features, CommWord(Ulong)StringKernel | |
| 264 """ | |
| 265 if options.kmerlen <= 8: | |
| 266 return CommWordStringKernel(feats, feats) | |
| 267 else: | |
| 268 return CommUlongStringKernel(feats, feats) | |
| 269 | |
| 270 | |
| 271 def get_weighted_spectrum_kernel(subfeats_list, options): | |
| 272 """build weighted spectrum kernel with non-redundant k-mer list (removing reverse complement) | |
| 273 | |
| 274 Arguments: | |
| 275 subfeats_list -- list of sub-feature objects | |
| 276 options -- object containing option data | |
| 277 | |
| 278 Return: | |
| 279 CombinedFeatures of StringWord(Ulong)Features, CombinedKernel of CommWord(Ulong)StringKernel | |
| 280 """ | |
| 281 kmerlen = options.kmerlen | |
| 282 kmerlen2 = options.kmerlen2 | |
| 283 | |
| 284 subkernels = 0 | |
| 285 kernel = CombinedKernel() | |
| 286 feats = CombinedFeatures() | |
| 287 | |
| 288 for subfeats in subfeats_list: | |
| 289 feats.append_feature_obj(subfeats) | |
| 290 | |
| 291 for k in xrange(kmerlen, kmerlen2+1): | |
| 292 if k <= 8: | |
| 293 subkernel = CommWordStringKernel(10, False) | |
| 294 else: | |
| 295 subkernel = CommUlongStringKernel(10, False) | |
| 296 | |
| 297 kernel.append_kernel(subkernel) | |
| 298 subkernels+=1 | |
| 299 | |
| 300 kernel.init(feats, feats) | |
| 301 | |
| 302 kernel.set_subkernel_weights(numpy.array([1/float(subkernels)]*subkernels, numpy.dtype('float64'))) | |
| 303 | |
| 304 return kernel | |
| 305 | |
| 306 | |
| 307 def init_spectrum_kernel(kern, feats_lhs, feats_rhs): | |
| 308 """initialize spectrum kernel (wrapper function) | |
| 309 """ | |
| 310 kern.init(feats_lhs, feats_rhs) | |
| 311 | |
| 312 | |
| 313 def init_weighted_spectrum_kernel(kern, subfeats_list_lhs, subfeats_list_rhs): | |
| 314 """initialize weighted spectrum kernel (wrapper function) | |
| 315 """ | |
| 316 feats_lhs = CombinedFeatures() | |
| 317 feats_rhs = CombinedFeatures() | |
| 318 | |
| 319 for subfeats in subfeats_list_lhs: | |
| 320 feats_lhs.append_feature_obj(subfeats) | |
| 321 | |
| 322 for subfeats in subfeats_list_rhs: | |
| 323 feats_rhs.append_feature_obj(subfeats) | |
| 324 | |
| 325 kern.init(feats_lhs, feats_rhs) | |
| 326 | |
| 327 | |
| 328 def get_sksvm_weights(svm, feats, options): | |
| 329 """calculate the SVM weight vector of spectrum kernel | |
| 330 """ | |
| 331 kmerlen = options.kmerlen | |
| 332 alphas = svm.get_alphas() | |
| 333 support_vector_ids = svm.get_support_vectors() | |
| 334 | |
| 335 w = numpy.array([0]*(2**(2*kmerlen)), numpy.double) | |
| 336 | |
| 337 for i in xrange(len(alphas)): | |
| 338 x = [0]*(2**(2*kmerlen)) | |
| 339 for kmerid in feats.get_feature_vector(int(support_vector_ids[i])): | |
| 340 x[int(kmerid)] += 1 | |
| 341 x = numpy.array(x, numpy.double) | |
| 342 w += (alphas[i]*x/numpy.sqrt(numpy.sum(x**2))) | |
| 343 | |
| 344 return w | |
| 345 | |
| 346 | |
| 347 def get_wsksvm_weights(svm, subfeats_list, options): | |
| 348 """calculate the SVM weight vector of weighted spectrum kernel | |
| 349 """ | |
| 350 kmerlen = options.kmerlen | |
| 351 kmerlen2 = options.kmerlen2 | |
| 352 alphas = svm.get_alphas() | |
| 353 support_vector_ids = svm.get_support_vectors() | |
| 354 kmerlens = range(kmerlen, kmerlen2+1) | |
| 355 | |
| 356 weights = [] | |
| 357 for idx in xrange(len(kmerlens)): | |
| 358 subfeats = subfeats_list[idx] | |
| 359 | |
| 360 k = kmerlens[idx] | |
| 361 w = numpy.array([0]*(2**(2*k)), numpy.double) | |
| 362 | |
| 363 for i in xrange(len(alphas)): | |
| 364 x = [0]*(2**(2*k)) | |
| 365 for kmerid in subfeats.get_feature_vector(int(support_vector_ids[i])): | |
| 366 x[int(kmerid)] += 1 | |
| 367 x = numpy.array(x, numpy.double) | |
| 368 w += (alphas[i]*x/numpy.sqrt(numpy.sum(x**2))) | |
| 369 | |
| 370 w /= len(kmerlens) | |
| 371 weights.append(w) | |
| 372 | |
| 373 return weights | |
| 374 | |
| 375 | |
| 376 def save_header(f, bias, A, B, options): | |
| 377 f.write("#parameters:\n") | |
| 378 f.write("#kernel=" + str(options.ktype) + "\n") | |
| 379 f.write("#kmerlen=" + str(options.kmerlen) + "\n") | |
| 380 if options.ktype == 2: | |
| 381 f.write("#kmerlen2=" + str(options.kmerlen2) + "\n") | |
| 382 f.write("#bias=" + str(bias) + "\n") | |
| 383 f.write("#A=" + str(A) + "\n") | |
| 384 f.write("#B=" + str(B) + "\n") | |
| 385 f.write("#NOTE: k-mers with large negative weights are also important. They can be found at the bottom of the list.\n") | |
| 386 f.write("#k-mer\trevcomp\tSVM-weight\n") | |
| 387 | |
| 388 | |
| 389 def save_sksvm_weights(w, bias, A, B, options): | |
| 390 """save the SVM weight vector from spectrum kernel | |
| 391 """ | |
| 392 output = options.outputname + "_weights.out" | |
| 393 kmerlen = options.kmerlen | |
| 394 | |
| 395 f = open(output, 'w') | |
| 396 save_header(f, bias, A, B, options) | |
| 397 | |
| 398 global g_kmers | |
| 399 global g_rcmap | |
| 400 | |
| 401 if options.sort: | |
| 402 w_sorted = sorted(zip(range(len(w)), w), key=lambda x: x[1], reverse=True) | |
| 403 else: | |
| 404 w_sorted = zip(range(len(w)), w) | |
| 405 | |
| 406 if kmerlen <= 8: | |
| 407 for i in map(lambda x: x[0], w_sorted): | |
| 408 if i == g_rcmap[i]: | |
| 409 f.write('\t'.join( [g_kmers[i], revcomp(g_kmers[i]), str(w[i])] ) + '\n') | |
| 410 else: | |
| 411 for i in map(lambda x: x[0], w_sorted): | |
| 412 if i == get_rcmap(i, kmerlen): | |
| 413 kmer = kmerid2kmer(i, kmerlen) | |
| 414 f.write('\t'.join( [kmer, revcomp(kmer), str(w[i])] ) + '\n') | |
| 415 | |
| 416 f.close() | |
| 417 | |
| 418 | |
| 419 def save_wsksvm_weights(w, bias, A, B, options): | |
| 420 """save the SVM weight vector from weighted spectrum kernel | |
| 421 """ | |
| 422 output = options.outputname + "_weights.out" | |
| 423 kmerlen = options.kmerlen | |
| 424 kmerlen2 = options.kmerlen2 | |
| 425 | |
| 426 f = open(output, 'w') | |
| 427 save_header(f, bias, A, B, options) | |
| 428 | |
| 429 global g_kmers | |
| 430 global g_rcmap | |
| 431 | |
| 432 kmerlens = range(kmerlen, kmerlen2+1) | |
| 433 for idx in xrange(len(kmerlens)): | |
| 434 k = kmerlens[idx] | |
| 435 subw = w[idx] | |
| 436 | |
| 437 if options.sort: | |
| 438 subw_sorted = sorted(zip(range(len(subw)), subw), key=lambda x: x[1], reverse=True) | |
| 439 else: | |
| 440 subw_sorted = zip(range(len(subw)), subw) | |
| 441 | |
| 442 if k <= 8: | |
| 443 g_kmers = generate_kmers(k) | |
| 444 g_rcmap = generate_rcmap_table(k, g_kmers) | |
| 445 for i in map(lambda x: x[0], subw_sorted): | |
| 446 if i == g_rcmap[i]: | |
| 447 f.write('\t'.join( [g_kmers[i], revcomp(g_kmers[i]), str(subw[i])] ) + "\n") | |
| 448 else: | |
| 449 for i in map(lambda x: x[0], subw_sorted): | |
| 450 if i == get_rcmap(i, k): | |
| 451 kmer = kmerid2kmer(i, k) | |
| 452 f.write('\t'.join( [kmer, revcomp(kmer), str(subw[i])] ) + "\n") | |
| 453 | |
| 454 f.close() | |
| 455 | |
| 456 | |
| 457 def save_predictions(output, preds, cvs): | |
| 458 """save prediction | |
| 459 """ | |
| 460 f = open(output, 'w') | |
| 461 f.write('\t'.join(["#seq_id", "SVM score", "label", "NCV"]) + "\n") | |
| 462 for i in xrange(len(preds)): | |
| 463 f.write('\t'.join([preds[i][1], str(preds[i][2]), str(preds[i][3]), str(cvs[i])]) + "\n") | |
| 464 f.close() | |
| 465 | |
| 466 | |
| 467 def generate_cv_list(ncv, n1, n2): | |
| 468 """generate the N-fold cross validation list | |
| 469 | |
| 470 Arguments: | |
| 471 ncv -- integer, number of cross-validation folds | |
| 472 n1 -- integer, number of positives | |
| 473 n2 -- integer, number of negatives | |
| 474 | |
| 475 Return: | |
| 476 a list of fold assignments for N-fold cross-validation | |
| 477 """ | |
| 478 | |
| 479 shuffled_idx_list1 = range(n1) | |
| 480 shuffled_idx_list2 = range(n1,n1+n2) | |
| 481 | |
| 482 random.shuffle(shuffled_idx_list1) | |
| 483 random.shuffle(shuffled_idx_list2) | |
| 484 | |
| 485 shuffled_idx_list = shuffled_idx_list1 + shuffled_idx_list2 | |
| 486 | |
| 487 idx = 0 | |
| 488 icv = 0 | |
| 489 cv = [0] * (n1+n2) | |
| 490 while(idx < (n1+n2)): | |
| 491 cv[shuffled_idx_list[idx]] = icv | |
| 492 | |
| 493 idx += 1 | |
| 494 icv += 1 | |
| 495 if icv == ncv: | |
| 496 icv = 0 | |
| 497 | |
| 498 return cv | |
| 499 | |
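
`generate_cv_list` assigns fold labels separately within the positive block (indices 0..n1-1) and the negative block (indices n1..n1+n2-1), so every fold keeps roughly the original class ratio. A usage sketch, assuming the same Python 2 environment this script runs in:

```python
# Stratified fold assignment for 6 positives and 4 negatives, 3 folds.
# Exact labels depend on the random seed; the per-class balance does not.
import random

random.seed(1)
cv = generate_cv_list(3, 6, 4)          # one fold label per example
pos_folds = cv[:6]                      # fold labels of the positive examples
assert all(pos_folds.count(f) == 2 for f in (0, 1, 2))  # positives split evenly
assert sorted(set(cv)) == [0, 1, 2]                      # every fold is used
```
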
| 500 | |
| 501 def split_cv_list(cvlist, icv, data): | |
| 502 """split data into training and test based on cross-validation list | |
| 503 | |
| 504 Arguments: | |
| 505 cvlist -- list, cross-validation list | |
| 506 icv -- integer, cross-validation fold of interest | |
| 507 data -- list, data set to be split | |
| 508 | |
| 509 Return: | |
| 510 a list of training set and a list of test set | |
| 511 """ | |
| 512 | |
| 513 tr_data = [] | |
| 514 te_data = [] | |
| 515 | |
| 516 for i in xrange(len(data)): | |
| 517 if cvlist[i] == icv: | |
| 518 te_data.append(data[i]) | |
| 519 else: | |
| 520 tr_data.append(data[i]) | |
| 521 | |
| 522 return tr_data, te_data | |
| 523 | |
| 524 | |
| 525 def LMAI(svms, labels, prior0, prior1): | |
| 526 """fitting svms to sigmoid function (improved version introduced by Lin 2003) | |
| 527 | |
| 528 Arguments: | |
| 529 svms -- list of svm scores | |
| 530 labels -- list of labels | |
| 531 prior0 -- prior of negative set | |
| 532 prior1 -- prior of positive set | |
| 533 | |
| 534 Return: | |
| 535 A, B parameter of 1/(1+exp(A*SVM+B)) | |
| 536 """ | |
| 537 | |
| 538 #parameter settings | |
| 539 maxiter = 100 | |
| 540 minstep = 1e-10 | |
| 541 sigma = 1e-3 | |
| 542 | |
| 543 hiTarget = (prior1+1.0)/float(prior1+2.0) | |
| 544 loTarget = 1/float(prior0+2.0) | |
| 545 | |
| 546 t = [0]*len(labels) | |
| 547 for i in xrange(len(labels)): | |
| 548 if labels[i] == 1: | |
| 549 t[i] = hiTarget | |
| 550 else: | |
| 551 t[i] = loTarget | |
| 552 | |
| 553 A = 0.0 | |
| 554 B = log((prior0+1.0)/float(prior1+1.0)) | |
| 555 fval = 0.0 | |
| 556 | |
| 557 for i in xrange(len(labels)): | |
| 558 fApB = svms[i]*A+B | |
| 559 if fApB >= 0: | |
| 560 fval += (t[i]*fApB+log(1+exp(-fApB))) | |
| 561 else: | |
| 562 fval += ((t[i]-1)*fApB+log(1+exp(fApB))) | |
| 563 | |
| 564 | |
| 565 for it in xrange(maxiter): | |
| 566 #print "iteration:", it | |
| 567 #Update Gradient and Hessian (use H' = H + sigma I) | |
| 568 h11 = sigma | |
| 569 h22 = sigma | |
| 570 h21 = 0.0 | |
| 571 g1 = 0.0 | |
| 572 g2 = 0.0 | |
| 573 | |
| 574 for i in xrange(len(labels)): | |
| 575 fApB = svms[i]*A+B | |
| 576 if fApB >= 0: | |
| 577 p = exp(-fApB) / float(1.0+exp(-fApB)) | |
| 578 q = 1.0 / float(1.0 + exp(-fApB)) | |
| 579 else: | |
| 580 p = 1.0 / float(1.0 + exp(fApB)) | |
| 581 q = exp(fApB) / float(1.0+exp(fApB)) | |
| 582 d2 = p*q | |
| 583 h11 += (svms[i]*svms[i]*d2) | |
| 584 h22 += d2 | |
| 585 h21 += (svms[i]*d2) | |
| 586 d1 = t[i]-p | |
| 587 g1 += (svms[i]*d1) | |
| 588 g2 += d1 | |
| 589 | |
| 590 #Stopping criteria | |
| 591 if (abs(g1)<1e-5) and (abs(g2)<1e-5): | |
| 592 break | |
| 593 | |
| 594 det = h11*h22-h21*h21 | |
| 595 dA = -(h22*g1-h21*g2)/float(det) | |
| 596 dB = -(-h21*g1+h11*g2)/float(det) | |
| 597 gd = g1*dA+g2*dB | |
| 598 stepsize=1 | |
| 599 while stepsize >= minstep: | |
| 600 newA = A+stepsize*dA | |
| 601 newB = B+stepsize*dB | |
| 602 newf = 0.0 | |
| 603 | |
| 604 for i in xrange(len(labels)): | |
| 605 fApB = svms[i]*newA+newB | |
| 606 if fApB >= 0: | |
| 607 newf += (t[i]*fApB + log(1+exp(-fApB))) | |
| 608 else: | |
| 609 newf += ((t[i]-1)*fApB + log(1+exp(fApB))) | |
| 610 | |
| 611 if newf < (fval+0.0001*stepsize*gd): | |
| 612 A=newA | |
| 613 B=newB | |
| 614 fval=newf | |
| 615 break | |
| 616 else: | |
| 617 stepsize=stepsize/float(2.0) | |
| 618 | |
| 619 #Line search fails | |
| 620 if stepsize < minstep: | |
| 621 #print "Line search fails" | |
| 622 break | |
| 623 | |
| 624 #if it >= maxiter: | |
| 625 # print "Reaching maximum iterations" | |
| 626 | |
| 627 return A, B | |
| 628 | |
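
`LMAI` fits the two sigmoid parameters by a Newton-type iteration; once A and B are known, a raw SVM score s maps to a posterior probability via 1/(1+exp(A*s+B)), the form named in the docstring. A minimal sketch of applying fitted parameters (the A and B values below are placeholders, not outputs of this script):

```python
from math import exp

def svm_score_to_posterior(score, A, B):
    """Map a raw SVM score to a probability via the fitted sigmoid."""
    fApB = A * score + B
    # numerically stable evaluation, mirroring the branches inside LMAI()
    if fApB >= 0:
        return exp(-fApB) / (1.0 + exp(-fApB))
    return 1.0 / (1.0 + exp(fApB))

# placeholder parameters for illustration only
A, B = -1.5, 0.2
assert 0.0 < svm_score_to_posterior(0.8, A, B) < 1.0
```
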
| 629 | |
| 630 def wsksvm_classify(seqs, svm, kern, feats, options): | |
| 631 feats_te = get_weighted_spectrum_features(seqs, options) | |
| 632 init_weighted_spectrum_kernel(kern, feats, feats_te) | |
| 633 | |
| 634 return svm.apply().get_labels().tolist() | |
| 635 | |
| 636 | |
| 637 def score_seq(s, svmw, kmerlen): | |
| 638 """calculate SVM score of given sequence using single set of svm weights | |
| 639 | |
| 640 Arguments: | |
| 641 s -- string, DNA sequence | |
| 642 svmw -- numpy array, SVM weights | |
| 643 kmerlen -- integer, length of k-mer of SVM weight | |
| 644 | |
| 645 Return: | |
| 646 SVM score | |
| 647 """ | |
| 648 | |
| 649 global g_rcmap | |
| 650 kmer2kmerid_func = kmer2kmerid | |
| 651 | |
| 652 x = [0]*(2**(2*kmerlen)) | |
| 653 for j in xrange(len(s)-kmerlen+1): | |
| 654 x[ g_rcmap[kmer2kmerid_func(s[j:j+kmerlen], kmerlen)] ] += 1 | |
| 655 | |
| 656 x = numpy.array(x, numpy.double) | |
| 657 score_norm = numpy.dot(svmw, x)/numpy.sqrt(numpy.sum(x**2)) | |
| 658 | |
| 659 return score_norm | |
| 660 | |
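
`score_seq` builds the canonical k-mer count vector of a sequence (reverse complements collapsed through g_rcmap), L2-normalizes it, and takes its dot product with the SVM weight vector; the caller adds the bias afterwards. A self-contained sketch of the same computation without the global tables (all helper names here are hypothetical):

```python
# Illustrative spectrum scoring: dot(w, x) / ||x||, with reverse complements
# collapsed onto a canonical k-mer id. A sketch, not the code used above.
import numpy

NT2ID = {"A": 0, "C": 1, "G": 2, "T": 3}
COMP = {"A": "T", "C": "G", "G": "C", "T": "A"}

def kmer_id(kmer):
    kid = 0
    for nt in kmer:
        kid = 4 * kid + NT2ID[nt]
    return kid

def canonical_id(kmer):
    rc = "".join(COMP[nt] for nt in reversed(kmer))
    return min(kmer_id(kmer), kmer_id(rc))

def score_sequence(seq, svmw, kmerlen):
    x = numpy.zeros(4 ** kmerlen, numpy.double)
    for j in range(len(seq) - kmerlen + 1):
        x[canonical_id(seq[j:j + kmerlen])] += 1
    return numpy.dot(svmw, x) / numpy.sqrt(numpy.sum(x ** 2))

# toy usage with a uniform weight vector over all 3-mers
print(score_sequence("ACGTACGT", numpy.ones(4 ** 3, numpy.double), 3))
```
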
| 661 | |
| 662 def sksvm_classify(seqs, svm, kern, feats, options): | |
| 663 """classify the given sequences | |
| 664 """ | |
| 665 if options.kmerlen <= 8: | |
| 666 #this is much faster when the length of kmer is short, and SVs are many | |
| 667 svmw = get_sksvm_weights(svm, feats, options) | |
| 668 return [score_seq(s, svmw, options.kmerlen)+svm.get_bias() for s in seqs] | |
| 669 else: | |
| 670 feats_te = get_spectrum_features(seqs, options) | |
| 671 init_spectrum_kernel(kern, feats, feats_te) | |
| 672 | |
| 673 return svm.apply().get_labels().tolist() | |
| 674 | |
| 675 | |
| 676 def main(argv = sys.argv): | |
| 677 usage = "Usage: %prog [options] POSITIVE_SEQ NEGATIVE_SEQ" | |
| 678 desc = "1. take two files(FASTA format) as input, 2. train an SVM and store the trained SVM weights" | |
| 679 parser = optparse.OptionParser(usage=usage, description=desc) | |
| 680 parser.add_option("-t", dest="ktype", type="int", default=1, \ | |
| 681 help="set the type of kernel, 1:Spectrum, 2:Weighted Spectrums (default=1.Spectrum)") | |
| 682 | |
| 683 parser.add_option("-C", dest="svmC", type="float", default=1, \ | |
| 684 help="set the regularization parameter svmC (default=1)") | |
| 685 | |
| 686 parser.add_option("-e", dest="epsilon", type="float", default=0.00001, \ | |
| 687 help="set the precision parameter epsilon (default=0.00001)") | |
| 688 | |
| 689 parser.add_option("-w", dest="weight", type="float", default=0.0, \ | |
| 690 help="set the weight for positive set (default=auto, 1+log(N/P))") | |
| 691 | |
| 692 parser.add_option("-k", dest="kmerlen", type="int",default=6, \ | |
| 693 help="set the (min) length of k-mer for (weighted) spectrum kernel (default = 6)") | |
| 694 | |
| 695 parser.add_option("-K", dest="kmerlen2", type="int",default=8, \ | |
| 696 help="set the max length of k-mer for weighted spectrum kernel (default = 8)") | |
| 697 | |
| 698 parser.add_option("-n", dest="outputname", default="kmersvm_output", \ | |
| 699 help="set the name of output files (default=kmersvm_output)") | |
| 700 | |
| 701 parser.add_option("-v", dest="ncv", type="int", default=0, \ | |
| 702 help="if set, it will perform N-fold cross-validation and generate a prediction file (default = 0)") | |
| 703 | |
| 704 parser.add_option("-p", dest="posteriorp", default=False, action="store_true", \ | |
| 705 help="estimate parameters for posterior probability with N-CV. this option requires -v option to be set (default=false)") | |
| 706 | |
| 707 parser.add_option("-r", dest="rseed", type="int", default=1, \ | |
| 708 help="set the random number seed for cross-validation (-p option) (default=1)") | |
| 709 | |
| 710 parser.add_option("-q", dest="quiet", default=False, action="store_true", \ | |
| 711 help="supress messages (default=false)") | |
| 712 | |
| 713 parser.add_option("-s", dest="sort", default=False, action="store_true", \ | |
| 714 help="sort the kmers by absolute values of SVM weights (default=false)") | |
| 715 | |
| 716 ktype_str = ["", "Spectrum", "Weighted Spectrums"] | |
| 717 | |
| 718 (options, args) = parser.parse_args() | |
| 719 | |
| 720 if len(args) == 0: | |
| 721 parser.print_help() | |
| 722 sys.exit(0) | |
| 723 | |
| 724 if len(args) != 2: | |
| 725 parser.error("incorrect number of arguments") | |
| 726 parser.print_help() | |
| 727 sys.exit(0) | |
| 728 | |
| 729 if options.posteriorp and options.ncv == 0: | |
| 730 parser.error("posterior probability estimation requires N-fold CV process (-v option should be set)") | |
| 731 parser.print_help() | |
| 732 sys.exit(0) | |
| 733 | |
| 734 random.seed(options.rseed) | |
| 735 | |
| 736 """ | |
| 737 set global variable | |
| 738 """ | |
| 739 if (options.ktype == 1) and (options.kmerlen <= 8): | |
| 740 global g_kmers | |
| 741 global g_rcmap | |
| 742 | |
| 743 g_kmers = generate_kmers(options.kmerlen) | |
| 744 g_rcmap = generate_rcmap_table(options.kmerlen, g_kmers) | |
| 745 | |
| 746 posf = args[0] | |
| 747 negf = args[1] | |
| 748 | |
| 749 seqs_pos, sids_pos = read_fastafile(posf) | |
| 750 seqs_neg, sids_neg = read_fastafile(negf) | |
| 751 npos = len(seqs_pos) | |
| 752 nneg = len(seqs_neg) | |
| 753 seqs = seqs_pos + seqs_neg | |
| 754 sids = sids_pos + sids_neg | |
| 755 | |
| 756 if options.weight == 0: | |
| 757 options.weight = 1 + log(float(nneg)/npos) | |
| 758 | |
| 759 if options.quiet == False: | |
| 760 sys.stderr.write('SVM parameters:\n') | |
| 761 sys.stderr.write(' kernel-type: ' + str(options.ktype) + "." + ktype_str[options.ktype] + '\n') | |
| 762 sys.stderr.write(' svm-C: ' + str(options.svmC) + '\n') | |
| 763 sys.stderr.write(' epsilon: ' + str(options.epsilon) + '\n') | |
| 764 sys.stderr.write(' weight: ' + str(options.weight) + '\n') | |
| 765 sys.stderr.write('\n') | |
| 766 | |
| 767 sys.stderr.write('Other options:\n') | |
| 768 sys.stderr.write(' kmerlen: ' + str(options.kmerlen) + '\n') | |
| 769 if options.ktype == 2: | |
| 770 sys.stderr.write(' kmerlen2: ' + str(options.kmerlen2) + '\n') | |
| 771 sys.stderr.write(' outputname: ' + options.outputname + '\n') | |
| 772 sys.stderr.write(' posteriorp: ' + str(options.posteriorp) + '\n') | |
| 773 if options.ncv > 0: | |
| 774 sys.stderr.write(' ncv: ' + str(options.ncv) + '\n') | |
| 775 sys.stderr.write(' rseed: ' + str(options.rseed) + '\n') | |
| 776 sys.stderr.write(' sorted-weight: ' + str(options.sort) + '\n') | |
| 777 | |
| 778 sys.stderr.write('\n') | |
| 779 | |
| 780 sys.stderr.write('Input args:\n') | |
| 781 sys.stderr.write(' positive sequence file: ' + posf + '\n') | |
| 782 sys.stderr.write(' negative sequence file: ' + negf + '\n') | |
| 783 sys.stderr.write('\n') | |
| 784 | |
| 785 sys.stderr.write('number of total positive seqs: ' + str(npos) + '\n') | |
| 786 sys.stderr.write('number of total negative seqs: ' + str(nneg) + '\n') | |
| 787 sys.stderr.write('\n') | |
| 788 | |
| 789 #generate labels | |
| 790 labels = [1]*npos + [-1]*nneg | |
| 791 | |
| 792 if options.ktype == 1: | |
| 793 get_features = get_spectrum_features | |
| 794 get_kernel = get_spectrum_kernel | |
| 795 get_weights = get_sksvm_weights | |
| 796 save_weights = save_sksvm_weights | |
| 797 svm_classify = sksvm_classify | |
| 798 elif options.ktype == 2: | |
| 799 get_features = get_weighted_spectrum_features | |
| 800 get_kernel = get_weighted_spectrum_kernel | |
| 801 get_weights = get_wsksvm_weights | |
| 802 save_weights = save_wsksvm_weights | |
| 803 svm_classify = wsksvm_classify | |
| 804 else: | |
| 805 sys.stderr.write('..unknown kernel..\n') | |
| 806 sys.exit(0) | |
| 807 | |
| 808 A = B = 0 | |
| 809 if options.ncv > 0: | |
| 810 if options.quiet == False: | |
| 811 sys.stderr.write('..Cross-validation\n') | |
| 812 | |
| 813 cvlist = generate_cv_list(options.ncv, npos, nneg) | |
| 814 labels_cv = [] | |
| 815 preds_cv = [] | |
| 816 sids_cv = [] | |
| 817 indices_cv = [] | |
| 818 for icv in xrange(options.ncv): | |
| 819 #split data into training and test set | |
| 820 seqs_tr, seqs_te = split_cv_list(cvlist, icv, seqs) | |
| 821 labs_tr, labs_te = split_cv_list(cvlist, icv, labels) | |
| 822 sids_tr, sids_te = split_cv_list(cvlist, icv, sids) | |
| 823 indices_tr, indices_te = split_cv_list(cvlist, icv, range(len(seqs))) | |
| 824 | |
| 825 #train SVM | |
| 826 feats_tr = get_features(seqs_tr, options) | |
| 827 kernel_tr = get_kernel(feats_tr, options) | |
| 828 svm_cv = svm_learn(kernel_tr, labs_tr, options) | |
| 829 | |
| 830 preds_cv = preds_cv + svm_classify(seqs_te, svm_cv, kernel_tr, feats_tr, options) | |
| 831 | |
| 832 labels_cv = labels_cv + labs_te | |
| 833 sids_cv = sids_cv + sids_te | |
| 834 indices_cv = indices_cv + indices_te | |
| 835 | |
| 836 output_cvpred = options.outputname + "_cvpred.out" | |
| 837 prediction_results = sorted(zip(indices_cv, sids_cv, preds_cv, labels_cv), key=lambda p: p[0]) | |
| 838 save_predictions(output_cvpred, prediction_results, cvlist) | |
| 839 | |
| 840 if options.posteriorp: | |
| 841 A, B = LMAI(preds_cv, labels_cv, labels_cv.count(-1), labels_cv.count(1)) | |
| 842 | |
| 843 if options.quiet == False: | |
| 844 sys.stderr.write('Estimated Parameters:\n') | |
| 845 sys.stderr.write(' A: ' + str(A) + '\n') | |
| 846 sys.stderr.write(' B: ' + str(B) + '\n') | |
| 847 | |
| 848 if options.quiet == False: | |
| 849 sys.stderr.write('..SVM weights\n') | |
| 850 | |
| 851 feats = get_features(seqs, options) | |
| 852 kernel = get_kernel(feats, options) | |
| 853 svm = svm_learn(kernel, labels, options) | |
| 854 w = get_weights(svm, feats, options) | |
| 855 b = svm.get_bias() | |
| 856 | |
| 857 save_weights(w, b, A, B, options) | |
| 858 | |
| 859 if __name__=='__main__': main() | |
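
For reference, a typical invocation of this script with the options defined above might look like the following; the FASTA file names are placeholders:

```
python kmersvm_train.py -t 1 -k 6 -C 1 -v 5 -s positives.fa negatives.fa
```

With the default `-n kmersvm_output`, this writes the trained SVM weights to kmersvm_output_weights.out and, because `-v` is set, the cross-validation predictions to kmersvm_output_cvpred.out.
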
