Mercurial > repos > test-svm > kmersvm_test
comparison kmersvm/scripts/kmersvm_classify.py @ 5:f99b5099ea55 draft
Uploaded
| author | test-svm |
|---|---|
| date | Sun, 05 Aug 2012 16:50:57 -0400 |
| parents | 66088269713e |
| children |
comparison
equal
deleted
inserted
replaced
| 4:f2130156fd5d | 5:f99b5099ea55 |
|---|---|
| 1 #!/usr/bin/python | |
| 2 """ | |
| 3 kmersvm_classify.py; classify sequences using SVM | |
| 4 Copyright (C) 2011 Dongwon Lee | |
| 5 | |
| 6 This program is free software: you can redistribute it and/or modify | |
| 7 it under the terms of the GNU General Public License as published by | |
| 8 the Free Software Foundation, either version 3 of the License, or | |
| 9 (at your option) any later version. | |
| 10 | |
| 11 This program is distributed in the hope that it will be useful, | |
| 12 but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 14 GNU General Public License for more details. | |
| 15 | |
| 16 You should have received a copy of the GNU General Public License | |
| 17 along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| 18 """ | |
| 19 | |
| 20 import sys | |
| 21 import numpy | |
| 22 import optparse | |
| 23 | |
| 24 from libkmersvm import * | |
| 25 | |
# Global variables.
# g_kmer2id maps a k-mer string to its index in the per-length weight and
# count vectors. It starts empty; read_svmwfile_wsk fills it and score_seq
# reads it.
g_kmer2id = {}
| 30 | |
| 31 | |
class Parameters:
    """Model parameters of a k-mer SVM, as stored in an SVM weight file.

    Every field defaults to None. read_svmwfile_wsk sets them from the
    '#name=value' header lines of the weight file and converts them to
    numbers.
    """

    def __init__(self, kernel=None, kmerlen=None, kmerlen2=None, bias=None, A=None, B=None):
        # kernel   -- kernel type (main labels 1 "Spectrum", 2 "Weighted Spectrums")
        # kmerlen  -- shortest k-mer length that is scored
        # kmerlen2 -- longest k-mer length that is scored (== kmerlen for kernel 1)
        # bias     -- constant added to every SVM score
        # A, B     -- coefficients of the posterior 1 / (1 + exp(A*score + B))
        (self.kernel, self.kmerlen, self.kmerlen2,
         self.bias, self.A, self.B) = (kernel, kmerlen, kmerlen2, bias, A, B)
| 40 | |
| 41 | |
def read_svmwfile_wsk(filename):
    """read SVM weight file generated by kmersvm_train.py

    Header lines start with '#'. A header of the form '#name=value' sets the
    attribute `name` (surrounding whitespace stripped) of the returned
    Parameters object to the string `value`. Afterwards kernel, kmerlen,
    kmerlen2, bias, A and B are converted to numbers. Every other non-blank
    line is split on whitespace: the first field is a k-mer and the third
    field is its SVM weight. Blank lines are ignored.

    Side effect: adds to the module-level g_kmer2id table an entry for every
    k-mer of each length from params.kmerlen to params.kmerlen2. Each entry
    maps the k-mer to the id produced by generate_rcmap_table.

    Arguments:
    filename -- string, name of the SVM weight file

    Return:
    list of SVM weights -- one numpy double array of length 4**k per k-mer
        length k, in increasing order of k. K-mers absent from the file get
        weight 0.
    an object of Parameters class

    Exits the program with status 1 (message on stderr) if the file cannot
    be read.
    """
    global g_kmer2id

    try:
        with open(filename, 'r') as f:
            lines = f.readlines()
    except IOError as e:
        # report on stderr and exit non-zero so callers/wrappers see a failure
        sys.stderr.write("I/O error(%s): %s\n" % (e.errno, e.strerror))
        sys.exit(1)

    kmer_svmw_dict = {}  # k-mer length -> {k-mer: weight}
    params = Parameters()

    for line in lines:
        if not line.strip():
            # blank line (e.g. a trailing newline at EOF): nothing to parse
            continue
        if line[0] == '#':
            # header line; only '#name=value' lines carry a parameter.
            # split once so a value may itself contain '='
            if '=' in line:
                name, value = line[1:].split('=', 1)
                setattr(params, name.strip(), value.strip())
            continue
        s = line.split()
        kmer_svmw_dict.setdefault(len(s[0]), {})[s[0]] = float(s[2])

    # header values were read as strings; convert them to numbers
    params.kernel = int(params.kernel)
    params.kmerlen = int(params.kmerlen)
    if params.kernel == 1:
        # kernel 1 scores a single k-mer length
        params.kmerlen2 = params.kmerlen
    else:
        params.kmerlen2 = int(params.kmerlen2)
    params.bias = float(params.bias)
    params.A = float(params.A)
    params.B = float(params.B)

    kmerlens = range(params.kmerlen, params.kmerlen2 + 1)

    # set global variable: k-mer string -> index used by score_seq
    for k in kmerlens:
        kmers = generate_kmers(k)
        rcmap = generate_rcmap_table(k, kmers)
        for i, kmer in enumerate(kmers):
            g_kmer2id[kmer] = rcmap[i]

    # create numpy arrays of svm weights, one per k-mer length
    svmw_list = []
    for k in kmerlens:
        svmw = numpy.zeros(4 ** k, numpy.double)
        # a length with no weights in the file contributes all-zero weights,
        # just like an individual k-mer missing from the file
        for kmer, weight in kmer_svmw_dict.get(k, {}).items():
            svmw[g_kmer2id[kmer]] = weight
        svmw_list.append(svmw)

    return svmw_list, params
| 110 | |
| 111 | |
def score_seq(s, svmw, kmerlen, kmer2id=None):
    """calculate SVM score of given sequence using single set of svm weights

    The sequence is turned into a k-mer count vector x. Each k-mer is mapped
    to an index through kmer2id. The score is dot(svmw, x) / ||x||.

    Arguments:
    s -- string, DNA sequence
    svmw -- numpy array, SVM weights indexed by k-mer id
    kmerlen -- integer, length of k-mer of SVM weight
    kmer2id -- optional dict mapping k-mer strings to indices into svmw;
               defaults to the module-level g_kmer2id table

    Return:
    SVM score; 0.0 when s is shorter than kmerlen (it contains no k-mers)

    Raises KeyError if s contains a k-mer that is not in kmer2id.
    """
    if kmer2id is None:
        kmer2id = g_kmer2id

    counts = [0] * (4 ** kmerlen)
    for j in range(len(s) - kmerlen + 1):
        counts[kmer2id[s[j:j + kmerlen]]] += 1

    x = numpy.array(counts, numpy.double)
    norm = numpy.sqrt(numpy.dot(x, x))
    if norm == 0:
        # no k-mers were counted: avoid 0/0 (NaN), which would otherwise
        # poison the sum over k-mer lengths in score_seq_wsk
        return 0.0

    return numpy.dot(svmw, x) / norm
| 132 | |
| 133 | |
def score_seq_wsk(s, svmwlist, kmerlen_start, kmerlen_end):
    """calculate svm score of given sequence with multiple sets of svm weights

    The result is the sum of score_seq(s, svmwlist[i], k) over the k-mer
    lengths k = kmerlen_start .. kmerlen_end (inclusive). svmwlist[i] holds
    the weights for the i-th length of that range.

    Arguments:
    s -- string, DNA sequence
    svmwlist -- list, SVM weights (one entry per k-mer length)
    kmerlen_start -- integer, minimum length of k-mer in the list of svm weights
    kmerlen_end -- integer, maximum length of k-mer in the list of svm weights

    Return:
    SVM score (sum of the per-length scores; 0 for an empty length range)
    """
    lengths = range(kmerlen_start, kmerlen_end + 1)
    return sum(score_seq(s, svmwlist[idx], k) for idx, k in enumerate(lengths))
| 155 | |
| 156 | |
def main(argv=None):
    """Score every sequence of a FASTA file with a trained k-mer SVM.

    Reads the SVM weight file (first positional argument) and the FASTA file
    (second positional argument). Writes one tab-separated line per sequence
    (id, posterior probability, SVM score) to the file given by -o
    (default kmersvm_scores.out). Unless -q is given, it also prints the
    settings to stderr.

    Arguments:
    argv -- list of strings, full command line with the program name in
            argv[0]; defaults to sys.argv

    Exits with status 0 after printing help when no positional arguments are
    given. Exits through OptionParser.error (status 2) when their number is
    not two.
    """
    if argv is None:
        argv = sys.argv

    usage = "Usage: %prog [options] SVM_WEIGHTS TEST_SEQ"
    desc = "1. take two files(one is in FASTA format to score, the other is SVM weight file generated from kmersvm_train.py) as input, 2. score each sequence in the given file"
    parser = optparse.OptionParser(usage=usage, description=desc)
    parser.add_option("-o", dest="output", default="kmersvm_scores.out",
                      help="set the name of output score file (default=kmersvm_scores.out)")

    parser.add_option("-q", dest="quiet", default=False, action="store_true",
                      help="suppress messages (default=false)")

    # parse the supplied command line (argv[0] is the program name)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) == 0:
        parser.print_help()
        sys.exit(0)

    if len(args) != 2:
        # OptionParser.error prints usage plus the message and exits
        parser.error("incorrect number of arguments")

    # index = kernel type stored in the weight file
    ktype_str = ["", "Spectrum", "Weighted Spectrums"]

    svmwf, seqf = args

    seqs, sids = read_fastafile(seqf)
    svmwlist, params = read_svmwfile_wsk(svmwf)

    if not options.quiet:
        sys.stderr.write('Options:\n')
        sys.stderr.write(' kernel-type: ' + str(params.kernel) + "." + ktype_str[params.kernel] + '\n')
        sys.stderr.write(' kmerlen: ' + str(params.kmerlen) + '\n')
        if params.kernel == 2:
            sys.stderr.write(' kmerlen2: ' + str(params.kmerlen2) + '\n')
        sys.stderr.write(' output: ' + options.output + '\n')
        sys.stderr.write('\n')

        sys.stderr.write('Input args:\n')
        sys.stderr.write(' SVM weights file: ' + svmwf + '\n')
        sys.stderr.write(' sequence file: ' + seqf + '\n')
        sys.stderr.write('\n')

        sys.stderr.write('number of sequences to score: ' + str(len(seqs)) + '\n')
        sys.stderr.write('posteriorp A: ' + str(params.A) + '\n')
        sys.stderr.write('posteriorp B: ' + str(params.B) + '\n')
        sys.stderr.write('\n')

    kmerlen = params.kmerlen
    kmerlen2 = params.kmerlen2
    bias = params.bias
    A = params.A
    B = params.B

    # 'with' guarantees the output file is closed even if scoring fails
    with open(options.output, 'w') as f:
        f.write("\t".join(["#seq_id", "posterior_prob", "svm_score\n"]))

        for sidx, s in enumerate(seqs):
            score = score_seq_wsk(s, svmwlist, kmerlen, kmerlen2) + bias
            # posterior probability 1 / (1 + exp(A*score + B)) using the
            # A/B parameters from the weight file
            pp = 1/(1+numpy.exp(score*A+B))

            f.write("\t".join([sids[sidx], str(pp), str(score)]) + "\n")


if __name__ == '__main__':
    main()
