Mercurial > repos > test-svm > kmersvm_test
diff kmersvm/scripts/kmersvm_classify.py @ 5:f99b5099ea55 draft
Uploaded
| author | test-svm |
|---|---|
| date | Sun, 05 Aug 2012 16:50:57 -0400 |
| parents | 66088269713e |
| children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/kmersvm/scripts/kmersvm_classify.py Sun Aug 05 16:50:57 2012 -0400 @@ -0,0 +1,221 @@ +#!/usr/bin/python +""" + kmersvm_classify.py; classify sequences using SVM + Copyright (C) 2011 Dongwon Lee + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +""" + +import sys +import numpy +import optparse + +from libkmersvm import * + +""" +global variables +""" +g_kmer2id = {} + + +class Parameters: + def __init__(self, kernel=None, kmerlen=None, kmerlen2=None, bias=None, A=None, B=None): + self.kernel = kernel + self.kmerlen = kmerlen + self.kmerlen2 = kmerlen2 + self.bias = bias + self.A = A + self.B = B + + +def read_svmwfile_wsk(filename): + """read SVM weight file generated by kmersvm_train.py + + Arguments: + filename -- string, name of the SVM weight file + + Return: + list of SVM weights + an object of Parameters class + """ + + try: + f = open(filename, 'r') + lines = f.readlines() + f.close() + + except IOError, (errno, strerror): + print "I/O error(%d): %s" % (errno, strerror) + sys.exit(0) + + kmer_svmw_dict = {} + params = Parameters() + + for line in lines: + #header lines + if line[0] == '#': + #if this line contains '=', that should be evaluated as a parameter + if line.find('=') > 0: + name, value = line[1:].split('=') + vars(params)[name] = value + else: + s = line.split() + kmerlen = len(s[0]) + if kmerlen not in kmer_svmw_dict: + kmer_svmw_dict[kmerlen] = {} + + kmer_svmw_dict[kmerlen][s[0]] = float(s[2]) + + #type casting of parameters + params.kernel = int(params.kernel) + params.kmerlen = int(params.kmerlen) + if params.kernel == 1: + params.kmerlen2 = params.kmerlen + else: + params.kmerlen2 = int(params.kmerlen2) + params.bias = float(params.bias) + params.A = float(params.A) + params.B = float(params.B) + + #set global variable + global g_kmer2id + for k in range(params.kmerlen, params.kmerlen2+1): + kmers = generate_kmers(k) + rcmap = generate_rcmap_table(k, kmers) + for i in xrange(len(kmers)): + g_kmer2id[kmers[i]] = rcmap[i] + + #create numpy arrays of svm weights + svmw_list = [] + for k in range(params.kmerlen, params.kmerlen2+1): + svmw = [0]*(2**(2*k)) + + for kmer in kmer_svmw_dict[k].keys(): + svmw[g_kmer2id[kmer]] = kmer_svmw_dict[k][kmer] + + svmw_list.append(numpy.array(svmw, numpy.double)) + + return svmw_list, params + + +def score_seq(s, svmw, kmerlen): + """calculate SVM score of given sequence using single set of svm weights + + Arguments: + s -- string, DNA sequence + svmw -- numpy array, SVM weights + kmerlen -- integer, length of k-mer of SVM weight + + Return: + SVM score + """ + kmer2id = g_kmer2id + x = [0]*(2**(2*kmerlen)) + for j in xrange(len(s)-kmerlen+1): + x[ kmer2id[s[j:j+kmerlen]] ] += 1 + + x = numpy.array(x, numpy.double) + score_norm = numpy.dot(svmw, x)/numpy.sqrt(numpy.sum(x**2)) + + return score_norm + + +def score_seq_wsk(s, svmwlist, kmerlen_start, kmerlen_end): + """calculate svm score of given sequence with multiple sets of svm weights + + Arguments: + svmwlist -- list, SVM weights + kmerlen_start -- integer, minimum length of k-mer in the list of svm weights + kmerlen_end -- integer, maximum length of k-mer in the list of sv weights + + Return: + SVM score + """ + kmerlens = range(kmerlen_start, kmerlen_end+1) + nkmerlens = len(kmerlens) + + score_norm_sum = 0 + + for i in range(nkmerlens): + score_norm = score_seq(s, svmwlist[i], kmerlens[i]) + score_norm_sum += score_norm + + return score_norm_sum + + +def main(argv = sys.argv): + usage = "Usage: %prog [options] SVM_WEIGHTS TEST_SEQ" + desc = "1. take two files(one is in FASTA format to score, the other is SVM weight file generated from kmersvm_train.py) as input, 2. score each sequence in the given file" + parser = optparse.OptionParser(usage=usage, description=desc) + parser.add_option("-o", dest="output", default="kmersvm_scores.out", \ + help="set the name of output score file (default=kmersvm_scores.out)") + + parser.add_option("-q", dest="quiet", default=False, action="store_true", \ + help="supress messages (default=false)") + + (options, args) = parser.parse_args() + + if len(args) == 0: + parser.print_help() + sys.exit(0) + + if len(args) != 2: + parser.error("incorrect number of arguments") + sys.exit(0) + + ktype_str = ["", "Spectrum", "Weighted Spectrums"] + + svmwf = args[0] + seqf = args[1] + + seqs, sids = read_fastafile(seqf) + svmwlist, params = read_svmwfile_wsk(svmwf) + + if options.quiet == False: + sys.stderr.write('Options:\n') + sys.stderr.write(' kernel-type: ' + str(params.kernel) + "." + ktype_str[params.kernel] + '\n') + sys.stderr.write(' kmerlen: ' + str(params.kmerlen) + '\n') + if params.kernel == 2: + sys.stderr.write(' kmerlen2: ' + str(params.kmerlen2) + '\n') + sys.stderr.write(' output: ' + options.output + '\n') + sys.stderr.write('\n') + + sys.stderr.write('Input args:\n') + sys.stderr.write(' SVM weights file: ' + svmwf + '\n') + sys.stderr.write(' sequence file: ' + seqf + '\n') + sys.stderr.write('\n') + + sys.stderr.write('numer of sequences to score: ' + str(len(seqs)) + '\n') + sys.stderr.write('posteriorp A: ' + str(params.A) + '\n') + sys.stderr.write('posteriorp B: ' + str(params.B) + '\n') + sys.stderr.write('\n') + + f = open(options.output, 'w') + f.write("\t".join(["#seq_id", "posterior_prob", "svm_score\n"])) + + kmerlen = params.kmerlen + kmerlen2 = params.kmerlen2 + bias = params.bias + A = params.A + B = params.B + for sidx in xrange(len(seqs)): + s = seqs[sidx] + score = score_seq_wsk(s, svmwlist, kmerlen, kmerlen2) + bias + pp = 1/(1+numpy.exp(score*A+B)) + + f.write("\t".join([ sids[sidx], str(pp), str(score)]) + "\n") + + f.close() + +if __name__=='__main__': main()
