| 
12
 | 
     1 import pysam, re, string
 | 
| 
 | 
     2 import matplotlib.pyplot as plt
 | 
| 
 | 
     3 import pandas as pd
 | 
| 
 | 
     4 from collections import defaultdict
 | 
| 
 | 
     5 from collections import OrderedDict
 | 
| 
 | 
     6 import argparse
 | 
| 
 | 
     7 
 | 
| 
 | 
     8 class MismatchFrequencies:
 | 
| 
 | 
     9     '''Iterate over a SAM/BAM alignment file, collecting reads with mismatches. One
 | 
| 
 | 
    10     class instance per alignment file. The result_dict attribute will contain a
 | 
| 
 | 
    11     nested dictionary with name, readlength and mismatch count.'''
 | 
| 
 | 
    12     def __init__(self, result_dict={}, alignment_file=None, name="name", minimal_readlength=21, maximal_readlength=21,
 | 
| 
 | 
    13                  number_of_allowed_mismatches=1, ignore_5p_nucleotides=0, ignore_3p_nucleotides=0):
 | 
| 
 | 
    14     
 | 
| 
 | 
    15         self.result_dict = result_dict
 | 
| 
 | 
    16         self.name = name
 | 
| 
 | 
    17         self.minimal_readlength = minimal_readlength
 | 
| 
 | 
    18         self.maximal_readlength = maximal_readlength
 | 
| 
 | 
    19         self.number_of_allowed_mismatches = number_of_allowed_mismatches
 | 
| 
 | 
    20         self.ignore_5p_nucleotides = ignore_5p_nucleotides
 | 
| 
 | 
    21         self.ignore_3p_nucleotides = ignore_3p_nucleotides
 | 
| 
 | 
    22         
 | 
| 
 | 
    23         if alignment_file:
 | 
| 
 | 
    24             self.pysam_alignment = pysam.Samfile(alignment_file)
 | 
| 
 | 
    25             result_dict[name]=self.get_mismatches(self.pysam_alignment, minimal_readlength, maximal_readlength)
 | 
| 
 | 
    26     
 | 
| 
 | 
    27     def get_mismatches(self, pysam_alignment, minimal_readlength, maximal_readlength):
 | 
| 
 | 
    28         mismatch_dict = defaultdict(int)
 | 
| 
 | 
    29         len_dict={}
 | 
| 
 | 
    30         for i in range(minimal_readlength, maximal_readlength+1):
 | 
| 
 | 
    31             len_dict[i]=mismatch_dict.copy()
 | 
| 
 | 
    32         for alignedread in pysam_alignment:
 | 
| 
 | 
    33             if self.read_is_valid(alignedread, minimal_readlength, maximal_readlength):
 | 
| 
 | 
    34                 len_dict[int(alignedread.rlen)]['total valid reads'] += 1
 | 
| 
 | 
    35                 MD=alignedread.opt('MD')
 | 
| 
 | 
    36                 if self.read_has_mismatch(alignedread, self.number_of_allowed_mismatches):
 | 
| 
 | 
    37                     (ref_base, mismatch_base)=self.read_to_reference_mismatch(MD, alignedread.seq, alignedread.is_reverse)
 | 
| 
 | 
    38                     if ref_base == None:
 | 
| 
 | 
    39                             continue
 | 
| 
 | 
    40                     else:
 | 
| 
 | 
    41                         for i, base in enumerate(ref_base):
 | 
| 
 | 
    42                             len_dict[int(alignedread.rlen)][ref_base[i]+' to '+mismatch_base[i]] += 1
 | 
| 
 | 
    43         return len_dict
 | 
| 
 | 
    44     
 | 
| 
 | 
    45     def read_is_valid(self, read, min_readlength, max_readlength):
 | 
| 
 | 
    46         '''Filter out reads that are unmatched, too short or
 | 
| 
 | 
    47         too long or that contian insertions'''
 | 
| 
 | 
    48         if read.is_unmapped:
 | 
| 
 | 
    49             return False
 | 
| 
 | 
    50         if read.rlen < min_readlength:
 | 
| 
 | 
    51             return False
 | 
| 
 | 
    52         if read.rlen > max_readlength:
 | 
| 
 | 
    53             return False
 | 
| 
 | 
    54         else:
 | 
| 
 | 
    55             return True
 | 
| 
 | 
    56     
 | 
| 
 | 
    57     def read_has_mismatch(self, read, number_of_allowed_mismatches=1):
 | 
| 
 | 
    58         '''keep only reads with one mismatch. Could be simplified'''
 | 
| 
 | 
    59         NM=read.opt('NM')
 | 
| 
 | 
    60         if NM <1: #filter out reads with no mismatch
 | 
| 
 | 
    61             return False
 | 
| 
 | 
    62         if NM >number_of_allowed_mismatches: #filter out reads with more than 1 mismtach
 | 
| 
 | 
    63             return False
 | 
| 
 | 
    64         else:
 | 
| 
 | 
    65             return True
 | 
| 
 | 
    66         
 | 
| 
 | 
    67     def mismatch_in_allowed_region(self, readseq, mismatch_position):
 | 
| 
 | 
    68         '''
 | 
| 
 | 
    69         >>> M = MismatchFrequencies()
 | 
| 
 | 
    70         >>> readseq = 'AAAAAA'
 | 
| 
 | 
    71         >>> mismatch_position = 2
 | 
| 
 | 
    72         >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
 | 
| 
 | 
    73         True
 | 
| 
 | 
    74         >>> M = MismatchFrequencies(ignore_3p_nucleotides=2, ignore_5p_nucleotides=2)
 | 
| 
 | 
    75         >>> readseq = 'AAAAAA'
 | 
| 
 | 
    76         >>> mismatch_position = 1
 | 
| 
 | 
    77         >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
 | 
| 
 | 
    78         False
 | 
| 
 | 
    79         >>> readseq = 'AAAAAA'
 | 
| 
 | 
    80         >>> mismatch_position = 4
 | 
| 
 | 
    81         >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
 | 
| 
 | 
    82         False
 | 
| 
 | 
    83         '''
 | 
| 
 | 
    84         mismatch_position+=1 # To compensate for starting the count at 0
 | 
| 
 | 
    85         five_p = self.ignore_5p_nucleotides
 | 
| 
 | 
    86         three_p = self.ignore_3p_nucleotides
 | 
| 
 | 
    87         if any([five_p > 0, three_p > 0]):
 | 
| 
 | 
    88             if any([mismatch_position <= five_p, 
 | 
| 
 | 
    89                     mismatch_position >= (len(readseq)+1-three_p)]): #Again compensate for starting the count at 0
 | 
| 
 | 
    90                 return False
 | 
| 
 | 
    91             else:
 | 
| 
 | 
    92                 return True
 | 
| 
 | 
    93         else:
 | 
| 
 | 
    94             return True
 | 
| 
 | 
    95             
 | 
| 
 | 
    96     def read_to_reference_mismatch(self, MD, readseq, is_reverse):
 | 
| 
 | 
    97         '''
 | 
| 
 | 
    98         This is where the magic happens. The MD tag contains SNP and indel information,
 | 
| 
 | 
    99         without looking to the genome sequence. This is a typical MD tag: 3C0G2A6.
 | 
| 
 | 
   100         3 bases of the read align to the reference, followed by a mismatch, where the
 | 
| 
 | 
   101         reference base is C, followed by 10 bases aligned to the reference. 
 | 
| 
 | 
   102         suppose a reference 'CTTCGATAATCCTT'
 | 
| 
 | 
   103                              |||  || ||||||
 | 
| 
 | 
   104                  and a read 'CTTATATTATCCTT'. 
 | 
| 
 | 
   105         This situation is represented by the above MD tag. 
 | 
| 
 | 
   106         Given MD tag and read sequence this function returns the reference base C, G and A, 
 | 
| 
 | 
   107         and the mismatched base A, T, T.
 | 
| 
 | 
   108         >>> M = MismatchFrequencies()
 | 
| 
 | 
   109         >>> MD='3C0G2A7'
 | 
| 
 | 
   110         >>> seq='CTTATATTATCCTT'
 | 
| 
 | 
   111         >>> result=M.read_to_reference_mismatch(MD, seq, is_reverse=False)
 | 
| 
 | 
   112         >>> result[0]=="CGA"
 | 
| 
 | 
   113         True
 | 
| 
 | 
   114         >>> result[1]=="ATT"
 | 
| 
 | 
   115         True
 | 
| 
 | 
   116         >>> 
 | 
| 
 | 
   117         '''
 | 
| 
 | 
   118         search=re.finditer('[ATGC]',MD)
 | 
| 
 | 
   119         if '^' in MD:
 | 
| 
 | 
   120             print 'WARNING insertion detected, mismatch calling skipped for this read!!!'
 | 
| 
 | 
   121             return (None, None)
 | 
| 
 | 
   122         start_index=0 # refers to the leading integer of the MD string before an edited base
 | 
| 
 | 
   123         current_position=0 # position of the mismatched nucleotide in the MD tag string
 | 
| 
 | 
   124         mismatch_position=0 # position of edited base in current read 
 | 
| 
 | 
   125         reference_base=""
 | 
| 
 | 
   126         mismatched_base=""
 | 
| 
 | 
   127         for result in search:
 | 
| 
 | 
   128             current_position=result.start()
 | 
| 
 | 
   129             mismatch_position=mismatch_position+1+int(MD[start_index:current_position]) #converts the leading characters before an edited base into integers
 | 
| 
 | 
   130             start_index=result.end()
 | 
| 
 | 
   131             reference_base+=MD[result.end()-1]
 | 
| 
 | 
   132             mismatched_base+=readseq[mismatch_position-1]
 | 
| 
 | 
   133         if is_reverse:
 | 
| 
 | 
   134             reference_base=reverseComplement(reference_base)
 | 
| 
 | 
   135             mismatched_base=reverseComplement(mismatched_base)
 | 
| 
 | 
   136             mismatch_position=len(readseq)-mismatch_position-1
 | 
| 
 | 
   137         if mismatched_base=='N':
 | 
| 
 | 
   138             return (None, None)
 | 
| 
 | 
   139         if self.mismatch_in_allowed_region(readseq, mismatch_position):
 | 
| 
 | 
   140             return (reference_base, mismatched_base)
 | 
| 
 | 
   141         else:
 | 
| 
 | 
   142             return (None, None)
 | 
| 
 | 
   143 
 | 
| 
 | 
   144 def reverseComplement(sequence):
 | 
| 
 | 
   145     '''do a reverse complement of DNA base.
 | 
| 
 | 
   146     >>> reverseComplement('ATGC')=='GCAT'
 | 
| 
 | 
   147     True
 | 
| 
 | 
   148     >>> 
 | 
| 
 | 
   149     '''
 | 
| 
 | 
   150     sequence=sequence.upper()
 | 
| 
 | 
   151     complement = string.maketrans('ATCGN', 'TAGCN')
 | 
| 
 | 
   152     return sequence.upper().translate(complement)[::-1]
 | 
| 
 | 
   153 
 | 
| 
 | 
   154 def barplot(df, library, axes):
 | 
| 
 | 
   155     df.plot(kind='bar', ax=axes, subplots=False,\
 | 
| 
 | 
   156             stacked=False, legend='test',\
 | 
| 
 | 
   157             title='Mismatch frequencies for {0}'.format(library))
 | 
| 
 | 
   158   
 | 
| 
 | 
   159 def result_dict_to_df(result_dict):
 | 
| 
 | 
   160     mismatches = []
 | 
| 
 | 
   161     libraries = []
 | 
| 
 | 
   162     for mismatch, library in result_dict.iteritems():
 | 
| 
 | 
   163         mismatches.append(mismatch)
 | 
| 
 | 
   164         libraries.append(pd.DataFrame.from_dict(library, orient='index'))
 | 
| 
 | 
   165     df=pd.concat(libraries, keys=mismatches)
 | 
| 
 | 
   166     df.index.names = ['library', 'readsize']
 | 
| 
 | 
   167     return df
 | 
| 
 | 
   168 
 | 
| 
 | 
   169 def df_to_tab(df, output):
 | 
| 
 | 
   170     df.to_csv(output, sep='\t')
 | 
| 
 | 
   171 
 | 
| 
 | 
   172 def plot_result(result_dict, args):
 | 
| 
 | 
   173     names=args.name
 | 
| 
 | 
   174     nrows=len(names)/2+1
 | 
| 
 | 
   175     fig = plt.figure(figsize=(16,32))
 | 
| 
 | 
   176     for i,library in enumerate (names):
 | 
| 
 | 
   177         axes=fig.add_subplot(nrows,2,i+1)
 | 
| 
 | 
   178         library_dict=result_dict[library]
 | 
| 
 | 
   179         for length in library_dict.keys():
 | 
| 
 | 
   180             for mismatch in library_dict[length]:
 | 
| 
 | 
   181                 if mismatch == 'total valid reads':
 | 
| 
 | 
   182                     continue
 | 
| 
 | 
   183                 library_dict[length][mismatch]=library_dict[length][mismatch]/float(library_dict[length]['total valid reads'])*100
 | 
| 
 | 
   184             del library_dict[length]['total valid reads']
 | 
| 
 | 
   185         df=pd.DataFrame(library_dict)
 | 
| 
 | 
   186         barplot(df, library, axes),
 | 
| 
 | 
   187         axes.set_ylabel('Mismatch count / all valid reads * 100')
 | 
| 
 | 
   188     fig.savefig(args.output_pdf, format='pdf')    
 | 
| 
 | 
   189 
 | 
| 
 | 
   190 def setup_MismatchFrequencies(args):
 | 
| 
 | 
   191     resultDict=OrderedDict()
 | 
| 
 | 
   192     kw_list=[{'result_dict' : resultDict, 
 | 
| 
 | 
   193              'alignment_file' :alignment_file, 
 | 
| 
 | 
   194              'name' : name, 
 | 
| 
 | 
   195              'minimal_readlength' : args.min, 
 | 
| 
 | 
   196              'maximal_readlength' : args.max,
 | 
| 
 | 
   197              'number_of_allowed_mismatches' : args.n_mm,
 | 
| 
 | 
   198              'ignore_5p_nucleotides' : args.five_p, 
 | 
| 
 | 
   199              'ignore_3p_nucleotides' : args.three_p} 
 | 
| 
 | 
   200              for alignment_file, name in zip(args.input, args.name)]
 | 
| 
 | 
   201     return (kw_list, resultDict)
 | 
| 
 | 
   202 
 | 
| 
 | 
   203 def run_MismatchFrequencies(args):
 | 
| 
 | 
   204     kw_list, resultDict=setup_MismatchFrequencies(args)
 | 
| 
 | 
   205     [MismatchFrequencies(**kw_dict) for kw_dict in kw_list]
 | 
| 
 | 
   206     return resultDict
 | 
| 
 | 
   207 
 | 
| 
 | 
   208 def main():
 | 
| 
 | 
   209     result_dict=run_MismatchFrequencies(args)
 | 
| 
 | 
   210     df=result_dict_to_df(result_dict)
 | 
| 
 | 
   211     plot_result(result_dict, args)
 | 
| 
 | 
   212     df_to_tab(df, args.output_tab)
 | 
| 
 | 
   213 
 | 
| 
 | 
   214 if __name__ == "__main__":
 | 
| 
 | 
   215     
 | 
| 
 | 
   216     parser = argparse.ArgumentParser(description='Produce mismatch statistics for BAM/SAM alignment files.')
 | 
| 
 | 
   217     parser.add_argument('--input', nargs='*', help='Input files in SAM/BAM format')
 | 
| 
 | 
   218     parser.add_argument('--name', nargs='*', help='Name for input file to display in output file. Should have same length as the number of inputs')
 | 
| 
 | 
   219     parser.add_argument('--output_pdf', help='Output filename for graph')
 | 
| 
 | 
   220     parser.add_argument('--output_tab', help='Output filename for table')
 | 
| 
 | 
   221     parser.add_argument('--min', '--minimal_readlength', type=int, help='minimum readlength')
 | 
| 
 | 
   222     parser.add_argument('--max', '--maximal_readlength', type=int, help='maximum readlength')
 | 
| 
 | 
   223     parser.add_argument('--n_mm', '--number_allowed_mismatches', type=int, default=1, help='discard reads with more than n mismatches')
 | 
| 
 | 
   224     parser.add_argument('--five_p', '--ignore_5p_nucleotides', type=int, default=0, help='when calculating nucleotide mismatch frequencies ignore the first N nucleotides of the read')
 | 
| 
 | 
   225     parser.add_argument('--three_p', '--ignore_3p_nucleotides', type=int, default=1, help='when calculating nucleotide mismatch frequencies ignore the last N nucleotides of the read')
 | 
| 
 | 
   226     #args = parser.parse_args(['--input', '3mismatches_ago2ip.bam', '2mismatch.bam', '--name', 'Siomi1', 'Siomi2' , '--five_p', '3','--three_p','3','--output_pdf', 'out.pdf', '--output_tab', 'out.tab', '--min', '21', '--max', '21'])
 | 
| 
 | 
   227     args = parser.parse_args()
 | 
| 
 | 
   228     main()
 | 
| 
 | 
   229 
 |