#!/usr/bin/python
#
# Copyright 2017.  Cumulus Networks, Inc.
#
# dbg-nl
#
# Debug utility for tracking netlink overflows. Does the following -
# 1. grabs the output of /proc/net/netlink.
# 2. stashes it under /run/cumulus/dbg_nl/ for computing drop deltas.
# 3. xlates the inode in the netlink to a proc_id and proc_name.
# 4. Displays the total_drops and delta_drops (drops since the last read)
#    per netlink socket (along with a user-understandable proc_name).
# Note: Default output only displays netlink sockets that have non-zero
# drops. "-v" shows all netlink sockets.
#
# See usage (dbg-nl -h) for instructions

import argparse
import json
import os
import re
import shutil
import sys

class ParseError(RuntimeError):
    '''
    Exception type that the __main__ block catches around argument
    parsing; it carries no state beyond the RuntimeError message.
    '''

class NlEntry:
    '''
    Record describing a single netlink socket: the port id and inode taken
    from /proc/net/netlink, the owning process resolved via
    /proc/<pid>/fd/..., the total drop counter and the drops seen since the
    previous run (new_drops, filled in by the caller; starts at 0).
    '''
    def __init__(self, port_id, inode, proc_id, proc_name, drops):
        # socket identity
        self.port_id, self.inode = port_id, inode
        # owning process
        self.proc_name, self.proc_id = proc_name, proc_id
        # counters; the delta is computed later against the saved snapshot
        self.drops = drops
        self.new_drops = 0

    def get_json_obj(self):
        '''
        Return the fields exposed in JSON output as a plain dict. The inode
        is intentionally not part of the result.
        '''
        return {
            'port_id': self.port_id,
            'proc_id': self.proc_id,
            'proc_name': self.proc_name,
            'total_drops': self.drops,
            'new_drops': self.new_drops,
        }

class Nl:
    '''
    Computes and displays delta/total drop stats per netlink socket.

    On construction the object:
      1. maps every pid under /proc to the socket inodes it holds open,
      2. parses the snapshot saved by the previous run (old_file_path),
      3. parses the live table (nl_file_path), computes per-socket drop
         deltas against the old snapshot and saves the live table as the
         new baseline.
    str(obj) renders the result as a text table or as JSON.
    '''
    old_dir = '/run/cumulus/dbg_nl'
    old_file_path = old_dir + '/old_stat'
    nl_file_path = '/proc/net/netlink'

    # Target of a /proc/<pid>/fd/<n> symlink that refers to a socket, e.g.
    # "socket:[2973552]". Compiled once instead of once per fd.
    _socket_link_re = re.compile(r'socket:\[(?P<id>\d+)\]')

    def __init__(self, json_out, verbose):
        '''
        json_out: render str(self) as JSON instead of a text table.
        verbose:  include sockets whose total drop count is zero.
        '''
        self.json = json_out
        self.verbose = verbose

        # port_id: NlEntry
        self.oldNl = {}
        self.newNl = {}

        # this cache is needed for xlating the inode displayed in
        # /proc/net/netlink to a proc_id/proc_name.
        # proc_id: [inodes]
        self.pid_inodes = {}

        # setup the caches; the inode cache must come first because parsing
        # either table resolves inodes to process names.
        self._load_inodes()
        self._load_old_nl()
        self._load_new_nl()

    def __str__(self):
        '''
        Render the current table. Sockets with zero total drops are hidden
        unless verbose is set.
        '''
        # TODO (carried over from the original code): the filter is meant to
        # look at new_drops eventually; today it looks at total drops.
        shown = [nle for nle in self.newNl.values()
                 if self.verbose or nle.drops]

        if self.json:
            jout = [nle.get_json_obj() for nle in shown]
            return json.dumps(jout, indent=4) + '\n'

        rows = ['port_id                          proc_name       total_drops         new_drops\n']
        for nle in shown:
            rows.append('%-16d  %24s  %16d  %16d\n' %
                        (nle.port_id, nle.proc_name, nle.drops, nle.new_drops))
        return ''.join(rows)

    def _inode_to_proc_name(self, port_id, inode):
        '''
        Resolve a socket inode to (proc_id, proc_name).

        Returns (0, 'kernel') for port id 0, (-1, 'unk') when no scanned
        process holds the inode, and (pid, 'unk') when the owner is known
        but its /proc/<pid>/comm cannot be read.
        '''
        if not port_id:
            # special handling for Pid=0
            return 0, 'kernel'

        proc_id = -1
        for pid, inodes in self.pid_inodes.items():
            if inode in inodes:
                # first holder found wins
                proc_id = pid
                break

        if proc_id == -1:
            return proc_id, 'unk'

        try:
            with open('/proc/%d/comm' % proc_id, 'r') as f:
                proc_name = f.readline().strip()
        except (IOError, OSError):
            # the comm file could not be read (for instance the process is
            # gone since the /proc scan)
            proc_name = 'unk'
        return proc_id, proc_name

    def _load_inodes(self):
        '''
        Populate self.pid_inodes with the socket inodes each process has
        open, by reading the symlinks under /proc/<pid>/fd.
        '''
        # sample output -
        #root@cel-sea-03:~# pidof switchd
        #2372
        #root@cel-sea-03:~# ls -l /proc/2372/fd/
        #total 0
        #lr-x------ 1 root root 64 May 23 12:24 0 -> /dev/null
        #lrwx------ 1 root root 64 May 23 12:24 1 -> socket:[2973552]
        #lrwx------ 1 root root 64 May 23 12:24 10 -> socket:[2974424]
        #>>>>>>>>>>>>>>>>>>>>>>>>>>>> SNIP >>>>>>>>>>>>>>>>>>>>>>>>>>>>
        pids = [int(name) for name in os.listdir('/proc') if name.isdigit()]
        for pid in pids:
            inodes = []
            try:
                fds = os.listdir('/proc/%d/fd' % pid)
            except OSError:
                # fd directory not listable; record the pid with no inodes
                fds = []
            for fd in fds:
                try:
                    inode_link = os.readlink('/proc/%d/fd/%s' % (pid, fd))
                except OSError:
                    # link not readable; skip this fd
                    continue
                obj = self._socket_link_re.match(inode_link)
                if obj:
                    inodes.append(int(obj.group('id')))
            self.pid_inodes[pid] = inodes

    def _load_old_nl(self):
        '''Parse the snapshot saved by the previous run into self.oldNl.'''
        self._load_nl_stat_from_file(self.old_file_path, self.oldNl)

    def _load_new_nl(self):
        '''
        Parse the live netlink table into self.newNl, compute new_drops
        against self.oldNl and save the live table as the next baseline.
        '''
        # Read the live table exactly once: the text that is parsed here is
        # the same text that gets saved, so the next run's delta is computed
        # against precisely what was displayed this time.
        raw_stat = self._read_stat_file(self.nl_file_path)
        self._parse_nl_stat(raw_stat, self.newNl)

        # compute stats delta; sockets absent from the old snapshot keep
        # new_drops == 0
        for port_id, nle in self.newNl.items():
            old_nle = self.oldNl.get(port_id)
            if old_nle is not None:
                nle.new_drops = nle.drops - old_nle.drops

        # replace old stats with new stats
        self._save_old_nl(raw_stat)

    def _save_old_nl(self, raw_stat):
        '''
        Write raw_stat to old_file_path (creating old_dir if needed). A
        failure is reported on stderr but does not abort the run, so the
        stats that were already computed can still be displayed.
        '''
        if not raw_stat:
            # nothing was read; keep whatever baseline is already saved
            return

        if not os.path.isdir(self.old_dir):
            try:
                os.makedirs(self.old_dir)
            except OSError:
                # either another instance created it first or creation
                # really failed; the write below reports the latter
                pass

        try:
            with open(self.old_file_path, 'w') as f:
                f.write(raw_stat)
        except (IOError, OSError) as err:
            sys.stderr.write('warning: unable to save %s: %s\n' %
                             (self.old_file_path, err))

    def _read_stat_file(self, file_path):
        '''Return the contents of file_path, or '' if it cannot be read.'''
        try:
            with open(file_path, 'r') as f:
                return f.read()
        except (IOError, OSError):
            return ''

    def _load_nl_stat_from_file(self, file_path, nlDb):
        '''
        Parse a netlink table in /proc/net/netlink format from file_path
        and add one NlEntry per socket to nlDb, keyed by port id. An
        unreadable file, a missing header column or a malformed data line
        is skipped silently.
        '''
        self._parse_nl_stat(self._read_stat_file(file_path), nlDb)

    def _parse_nl_stat(self, read_data, nlDb):
        '''
        Parse the text of a netlink table (header line + one line per
        socket) and store an NlEntry per socket in nlDb.
        '''
        # Sample output -
        #root@cel-sea-03:~# cat /proc/net/netlink
        #sk       Eth Pid    Groups   Rmem     Wmem     Dump     Locks     Drops     Inode
        #ffff88016ccc0800 0   19205  00000111 0        0        0 2        0        3454766
        #ffff880178d8f800 0   4294963165 00000000 0        0        0 2        0        15857
        #ffff88016cd92000 0   1371539780 00000000 0        0        0 2        0        2974421
        #ffff88015c7f9000 0   19215  00000005 0        0        0 2        0        3391269
        #ffff88015d8c9000 0   26264  00000555 0        0        0 2        0        3428187
        lines = read_data.splitlines()
        if len(lines) < 2:
            return

        # locate the columns by name rather than by fixed position
        header = lines[0].split()
        try:
            port_id_index = header.index('Pid')
            drops_index = header.index('Drops')
            inode_index = header.index('Inode')
        except ValueError:
            return
        max_index = max(port_id_index, drops_index, inode_index)

        for line in lines[1:]:
            words = line.split()
            # words[max_index] must exist, i.e. at least max_index + 1 words
            if len(words) <= max_index:
                continue
            try:
                port_id = int(words[port_id_index])
                drops = int(words[drops_index])
                inode = int(words[inode_index])
            except ValueError:
                # non-numeric field; ignore the malformed line
                continue
            proc_id, proc_name = self._inode_to_proc_name(port_id, inode)
            # NOTE(review): entries are keyed by port id alone, so lines that
            # share a Pid value overwrite each other (last one wins) -
            # confirm whether that is intended.
            nlDb[port_id] = NlEntry(port_id, inode, proc_id, proc_name, drops)


if __name__ == '__main__':
    # The tool walks /proc/<pid>/fd for every process and writes its
    # snapshot under /run/cumulus/dbg_nl, hence the root requirement.
    if os.getuid():
        sys.stderr.write("need to run as root\n")
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description='Displays netlink drop stats')
    parser.add_argument('-v', '--verbose', default=False, action='store_true',
                        help='Display all nl sockets (including ones with no drops)')
    parser.add_argument('-j', '--json', default=False, action='store_true',
                        help='JSON output')

    try:
        args = parser.parse_args()
    except ParseError as e:
        # "except X as e" is accepted by Python 2.6+ and Python 3, unlike
        # the old "except X, e" spelling which is a syntax error on Python 3.
        parser.error(str(e))

    # Building the object collects and saves the stats; printing renders them.
    nlInfo = Nl(args.json, args.verbose)

    sys.stdout.write('%s' % nlInfo)
    sys.exit(0)
