#!/usr/bin/env python
# Copyright 2016-2020 Cumulus Networks, Inc.
#
# mroute-check kernel helper

import re
import os
import subprocess
import sys

import cumulus.porttab

class Kmroute:
    """A single kernel multicast route parsed from ``ip mroute show``.

    Attributes:
        sip, dip: source / group address strings taken from the route's
            '(S, G)' key (may be None if the key could not be parsed).
        iif: incoming interface name.
        oil: list of outgoing interface names (empty when none).
        is_ulm_orig, is_ulm_term: 0/1 flags, initialised to 0 here and
            set afterwards by the kernel-table parser for underlay
            multicast originating / terminating routes.
        orig_str: the raw output line; str() returns it verbatim.
    """

    def __init__(self, sip, dip, iif, oil, orig_str):
        self.sip, self.dip = sip, dip
        self.iif, self.oil = iif, oil
        # underlay-mcast flags start cleared; the parser fills them in
        self.is_ulm_orig = self.is_ulm_term = 0
        self.orig_str = orig_str

    def __str__(self):
        # echo the route exactly as the kernel reported it
        return self.orig_str

class Kernel:
    """Snapshot of the kernel's multicast routing state.

    Construction runs ``ip`` (and ``clagctl`` when available) and parses
    ``ip mroute show table all`` into three dicts keyed by
    ``'<table>, (<sip>, <dip>)'`` whose values are Kmroute objects:

    * ``mroutes`` - resolved entries that can be installed in hardware
    * ``blist_mroutes`` - resolved entries that cannot be installed in
      hardware (kept so the key set can still be matched up with FRR)
    * ``unresolved_mroutes`` - unresolved cache entries, local to the kernel
    """

    # link kinds (first word of line 3 of 'ip -d link show') accepted for hw
    _devs_supported = ('bridge', 'bond', 'vlan')
    # iif of underlay-mcast originating routes; replaced per instance by the
    # clag peer-link device when clagctl reports the peer is alive
    ORIG_DEV = "lo"
    # oif of underlay-mcast terminating routes
    TERM_DEV = "ipmr-lo"

    # '(sip,dip)' once all spaces have been removed from the route key
    _KEY_PAT = re.compile(r'\((?P<sip>\S+),(?P<dip>\S+)\)')
    # 'ip mroute show table all' lines, with and without an outgoing list
    _MROUTE_OIL_PAT = re.compile(
        r'(?P<key>.*)Iif: (?P<iif>\S+)\s+Oifs: (?P<oil>.*)\s+'
        r'State: (?P<state>\S+)\s+Table: (?P<table>\S+)')
    _MROUTE_NO_OIL_PAT = re.compile(
        r'(?P<key>.*)Iif: (?P<iif>\S+)\s+'
        r'State: (?P<state>\S+)\s+Table: (?P<table>\S+)')
    # 'inet <addr>' tokens in 'ip addr show' output
    _LO_ADDR_PAT = re.compile(r'inet (?P<sip>\S+) ')

    def __init__(self, asic_present):
        """Collect loopback addresses and the kernel mroute table.

        :param asic_present: when true, hw-installability of each route is
            checked against the Linux port names from cumulus.porttab; when
            false every resolved route is treated as installable.
        """
        self.desc = "kernel"
        self.asic_present = asic_present
        if asic_present:
            self.pt = cumulus.porttab.porttab()
            self.linux_ports = self.pt.get_linux_ports()
        # key: devName info: link_kind (cache of 'ip -d link show' lookups)
        self.dev_kind = {}

        # locate all possible VTEP local ips
        self._parse_lo_addrs()

        # dict key: '<vrf-devname>, (1.1.1.1, 239.1.1.1)', info: Kmroute
        self.mroutes = {}
        # entries that cannot be added to the hw are maintained in
        # a black list. this is needed to match up with FRR.
        self.blist_mroutes = {}
        self.unresolved_mroutes = {}
        # pull the kernel mroute table
        self._mroute_parse_iproute2_output()

    def __str__(self):
        """Render the hw, skip-hw and unresolved tables, each in key order.

        The skip-hw and unresolved sections are omitted when empty.
        """
        out = ['Kernel ip mroutes:\n']
        out.extend('%s\n' % self.mroutes[key] for key in sorted(self.mroutes))
        optional_sections = (
            ('Kernel ip mroutes (skip-hw):\n', self.blist_mroutes),
            ('Kernel ip mroutes (unresolved):\n', self.unresolved_mroutes),
        )
        for title, table in optional_sections:
            if table:
                out.append(title)
                out.extend('%s\n' % table[key] for key in sorted(table))
        return ''.join(out)

    @staticmethod
    def _get_cmd_output(cmdList, dispStr=None):
        """Run *cmdList* and return its stdout as text.

        :param cmdList: argv list (no shell is involved).
        :param dispStr: value returned when the command exits non-zero.
        :raises OSError: if the program cannot be executed at all.
        """
        try:
            with open(os.devnull, 'w') as fnull:
                # universal_newlines: return str (not bytes on Python 3) so
                # callers can match it against str patterns
                return subprocess.check_output(cmdList, stderr=fnull,
                                               universal_newlines=True)
        except subprocess.CalledProcessError:
            return dispStr

    def _mroute_install_hw_dev_ok(self, dev):
        """Return True if *dev* can be used by a hw-installed mroute.

        Switch ports and TERM_DEV are always accepted; other devices are
        accepted when their link kind is in _devs_supported. The link kind
        is looked up once with 'ip -d link show' and cached in dev_kind.
        """
        if dev in self.linux_ports or dev == self.TERM_DEV:
            return True

        kind = self.dev_kind.get(dev)
        if not kind:
            dispStr = self._get_cmd_output(
                ['/bin/ip', '-d', 'link', 'show', dev], '')
            lines = dispStr.splitlines()
            if len(lines) < 3 or 'does not exist' in lines[0]:
                return False
            fields = lines[2].split()
            if not fields:
                # unexpected output layout; treat as unsupported
                return False
            kind = fields[0]
            self.dev_kind[dev] = kind

        return kind in self._devs_supported

    def _mroute_install_hw_ok(self, mroute):
        """Return True if *mroute* can be installed in hardware.

        Side effect: an underlay terminating route with an unsupported oif
        has its is_ulm_term flag cleared.
        """
        if not self.asic_present:
            # no chip inside; nothing to blacklist
            return True

        if mroute.is_ulm_orig:
            # Underlay mcast routes are special, handle differently
            return True

        if not self._mroute_install_hw_dev_ok(mroute.iif):
            return False

        for oif in mroute.oil:
            if not self._mroute_install_hw_dev_ok(oif):
                mroute.is_ulm_term = 0
                return False

        return True

    def _mroute_key2sip_and_dip(self, key):
        """Split a '(sip, dip)' key into (sip, dip); (None, None) if malformed."""
        obj = self._KEY_PAT.match(key.replace(" ", ""))
        if obj:
            return obj.group('sip'), obj.group('dip')
        return None, None

    @staticmethod
    def _clag_peer_dev(dispStrClag):
        """Return the peer-link device from 'clagctl -v' output, or None.

        None is returned unless the first line reports the peer alive and a
        well-formed 'Peer Interface and IP:' line is present. Sample output:

        The peer is alive
        Our Priority, ID, and Role: 4096 54:bf:64:ba:49:e2 primary
        Peer Priority, ID, and Role: 8192 00:02:00:00:00:32 secondary
         Peer Interface and IP: peerlink-3.4094 169.254.0.10
              VxLAN Anycast IP: 36.0.0.11
                     Backup IP: 27.0.0.12 (active)
        ...
        """
        lines = dispStrClag.splitlines() if dispStrClag else []
        if not lines or lines[0].strip() != 'The peer is alive':
            return None
        for line in lines:
            if 'Peer Interface and IP:' in line:
                fields = line.split()
                # fields: Peer Interface and IP: <dev> <ip>
                return fields[4] if len(fields) > 4 else None
        return None

    def mroute_clag_check_and_get_orig_dev(self):
        """Use the clag peer-link as ORIG_DEV when the clag peer is alive.

        Best effort: if clagctl cannot be run, or its output has no usable
        peer-link, ORIG_DEV is left unchanged.
        """
        try:
            dispStrClag = self._get_cmd_output(['/usr/bin/clagctl', '-v'], '')
        except Exception:
            # e.g. clagctl is not installed on this system
            return

        peer_dev = self._clag_peer_dev(dispStrClag)
        if peer_dev:
            self.ORIG_DEV = peer_dev

    def _mroute_parse_iproute2_output(self):
        """Populate the mroute dicts from 'ip mroute show table all'."""
        # ORIG_DEV may be replaced by the clag peer-link; do it before
        # classifying routes, which compares iifs against ORIG_DEV
        self.mroute_clag_check_and_get_orig_dev()
        cmd = '/bin/ip mroute show table all'
        self._mroute_parse_lines(self._get_cmd_output(cmd.split(), ''))

    def _mroute_parse_lines(self, dispStr):
        """Classify each 'ip mroute show table all' line into the dicts.

        Sample format (two possible patterns):
        (33.1.1.1, 239.1.1.1)   Iif: swp1   Oifs: swp2 swp4  State: resolved Table: <vrf>
        (33.1.1.1, 239.1.1.1)   Iif: swp1   State: resolved Table: <vrf>
        Lines matching neither pattern are ignored.
        """
        for line in dispStr.splitlines():
            oilPresent = 'Oifs' in line
            pat = self._MROUTE_OIL_PAT if oilPresent else self._MROUTE_NO_OIL_PAT
            obj = pat.match(line)
            if not obj:
                continue

            sgKey = obj.group('key').strip()
            sip, dip = self._mroute_key2sip_and_dip(sgKey)
            iif = obj.group('iif')
            oil = obj.group('oil').split() if oilPresent else []
            state = obj.group('state')
            key = '%s, %s' % (obj.group('table'), sgKey)

            mroute = Kmroute(sip, dip, iif, oil, line)

            # check for underlay multicast special routes
            if iif == self.ORIG_DEV and sip != '0.0.0.0':
                mroute.is_ulm_orig = 1
            elif self.TERM_DEV in oil and sip == '0.0.0.0':
                mroute.is_ulm_term = 1

            # unresolved cache entries are local to the kernel
            if state == 'unresolved':
                self.unresolved_mroutes[key] = mroute
                continue
            # mroutes that cannot be added to the hw are black listed
            if self._mroute_install_hw_ok(mroute):
                self.mroutes[key] = mroute
            else:
                self.blist_mroutes[key] = mroute

    def _parse_lo_addrs(self):
        """Store every 'inet <addr>' token of the lo device in lo_addrs.

        NOTE(review): tokens are taken verbatim from 'ip addr' output, so
        presumably in addr/prefixlen form - confirm against callers.
        """
        cmd = '/bin/ip -d addr show dev lo'
        dispStr = self._get_cmd_output(cmd.split(), '')
        self.lo_addrs = self._LO_ADDR_PAT.findall(dispStr)

    def get_mroutes(self):
        """Return the set of hw-installable mroute keys."""
        return set(self.mroutes)

    def get_blist_mroutes(self):
        """Return the set of black-listed (skip-hw) mroute keys."""
        return set(self.blist_mroutes)

    def find_mroute(self, key):
        """Look *key* up in the hw table, then the black list; None if absent."""
        mroute = self.mroutes.get(key)
        if not mroute:
            # look in the black list
            mroute = self.blist_mroutes.get(key)
        return mroute

    def get_mroute_iif_and_oil(self, key):
        """Return (iif, oil) for *key*, or (None, None) if not found."""
        mroute = self.find_mroute(key)
        if mroute:
            return mroute.iif, mroute.oil
        return None, None

    def get_lo_addrs(self):
        """Return the addresses collected from the lo device."""
        return self.lo_addrs

    def dump_mroutes(self, filter_mroutes=None):
        """Write mroutes to stdout, one per line.

        With no filter, all hw-installable then all black-listed routes are
        written in key order; otherwise only the given keys that are found
        in either table are written, in the order given.
        """
        if not filter_mroutes:
            for table in (self.mroutes, self.blist_mroutes):
                for key in sorted(table):
                    sys.stdout.write('%s\n' % table[key])
            return

        for key in filter_mroutes:
            mroute = self.find_mroute(key)
            if mroute:
                sys.stdout.write('%s\n' % mroute)