#!/usr/bin/env python
# Copyright 2016,2017,2018,2019,2020 Cumulus Networks, Inc.
#
# mroute-check frr helper

import re
import json
import os
import subprocess
import sys
import shlex

class Qmroute:
    """A single multicast route entry as reported by FRR.

    Holds the (S,G) pair, incoming interface, outgoing interface list and
    VRF, plus flags for underlay-multicast origination/termination routes
    and a list of human-readable mismatch descriptions.
    """

    def __init__(self, sip, dip, iif, oil, vrf):
        # route identity
        self.vrf = vrf
        self.sip = sip
        self.dip = dip
        # forwarding info
        self.iif = iif
        self.oil = oil
        # underlay-multicast markers (0/1), set by the parser
        self.is_ulm_orig = 0
        self.is_ulm_term = 0
        # errors appended by consistency checks
        self.mismatch = []

    def __str__(self):
        """Render the route on one line, followed by any recorded errors."""
        pieces = ['%s, (%s,%s) iif: %s' % (self.vrf, self.sip, self.dip, self.iif)]
        if self.oil:
            pieces.append(' oil:%s' % ' '.join(self.oil))
        if self.mismatch:
            pieces.append('\n  errors:')
            pieces.extend('\n    %s' % err for err in self.mismatch)
        return ''.join(pieces)

class Frr:
    """Snapshot of FRR's multicast state, collected from CLI JSON output.

    Construction runs ``net show pim mlag summary json`` and
    ``vtysh -c 'show ip mroute vrf all json'`` and caches the parsed
    results. Any command or JSON failure terminates the process via
    ``load_cmd``.
    """

    # An mroute with a specific source whose IIF contains this name
    # (substring match) is flagged as an underlay-multicast origination
    # route. Overridden per instance by the MLAG peerlink RIF when reported.
    ORIG_DEV = "lo"
    # A (*,G) mroute whose OIL contains this device is flagged as an
    # underlay-multicast termination route.
    TERM_DEV = "ipmr-lo"

    def __init__(self):
        self.desc = "frr"
        # key: '<vrf>, (<sip>,<dip>)', e.g. 'default, (1.1.1.1,239.1.1.1)';
        # value: Qmroute
        self.mroutes = {}
        # vrf -> group -> source -> [interfaces]; populated by
        # get_source_rpt_and_join_state()
        self.rpt_dict = {}
        self.joined_dict = {}
        # MLAG parsing must come first: it may override ORIG_DEV, which
        # _mroute_parse uses to classify origination routes.
        self._mlag_summary_parse()
        self._mroute_parse()

    def __str__(self):
        """Render mroutes, then JOINED and RPT (S,G) entries, all sorted."""
        out = ['FRR IP mroutes:\n']
        for key in sorted(self.mroutes):
            out.append('%s\n' % self.mroutes[key])
        out.append('\nFRR S,G JOINED Entries:\n')
        out.extend(self._fmt_sg_entries(self.joined_dict, 'JOINED'))
        out.append('\nFRR S,G RPT Entries:\n')
        out.extend(self._fmt_sg_entries(self.rpt_dict, 'RPT'))
        return ''.join(out)

    @staticmethod
    def _fmt_sg_entries(sg_dict, label):
        # Yield one '<vrf>, (<src>,<grp>) <label> = <ifaces>' line per entry,
        # iterating vrf, group and source in sorted order.
        for vrf in sorted(sg_dict):
            for grp in sorted(sg_dict[vrf]):
                for src in sorted(sg_dict[vrf][grp]):
                    yield '%s, (%s,%s) %s = %s\n' % (
                        vrf, src, grp, label, sg_dict[vrf][grp][src])

    @staticmethod
    def _get_cmd_output(cmdList, dispStr=None):
        """Return the stdout of ``cmdList``, or "" if it exits non-zero.

        stderr is discarded. ``dispStr`` is accepted for backward
        compatibility only; its value is not used.
        """
        try:
            with open(os.devnull, 'w') as fnull:
                return subprocess.check_output(cmdList, stderr=fnull)
        except subprocess.CalledProcessError:
            return ""

    def _mlag_summary_parse(self):
        """Read MLAG anycast/local VTEP IPs and the peerlink RIF.

        Missing fields fall back to "0.0.0.0" / "". When a peerlink RIF is
        reported, it replaces ORIG_DEV on this instance.
        """
        mlag = self.load_cmd("net show pim mlag summary json", easy_json=True)
        self.mlag_anycast_ip = mlag.get("anycastVtepIp", "0.0.0.0")
        self.mlag_local_ip = mlag.get("localVtepIp", "0.0.0.0")
        self.mlag_peerlink_rif = mlag.get("peerlinkRif", "")
        # Truthiness (not `!= ""`) so a JSON null doesn't set ORIG_DEV to
        # None, which would make the `in iif` test in _mroute_parse raise.
        if self.mlag_peerlink_rif:
            self.ORIG_DEV = self.mlag_peerlink_rif

    def _mroute_parse(self):
        """Populate self.mroutes from FRR's installed mroutes in all VRFs."""
        dispStr = "vtysh -c 'show ip mroute vrf all json'"

        # sample format -
        #{
        #  "Default-IP-Routing-Table":{
        #    "239.1.1.1":{
        #      "33.1.1.1":{
        #        "installed":1,
        #        "refCount":2,
        #        "oilSize":1,
        #        "OilInheritedRescan":0,
        #        "iif":"swp2s0",
        #        "oil":{
        #          "swp2s1":{
        #            >>> SNIP >>>
        #            "ttl":1,
        #          }
        #        }
        #      }
        #    }
        #  }
        #}
        ipmr = self.load_cmd(dispStr, easy_json=True)
        for vrf, vrf_info in ipmr.items():
            if vrf == 'Default-IP-Routing-Table':
                vrf = 'default'
            for dip, dip_info in vrf_info.items():
                for sip, sip_info in dip_info.items():
                    # only routes FRR reports as installed are compared
                    if not sip_info.get('installed', 0):
                        continue
                    iif = sip_info.get('iif', 'unk')
                    # XXX: ideally we should use iVifI and oVifI that would also
                    # validate iface info falling out of sync
                    oil = list(sip_info.get('oil', {}))
                    if sip == '*':
                        sip = '0.0.0.0'
                    mroute = Qmroute(sip, dip, iif, oil, vrf)
                    # check for underlay multicast special routes
                    # XXX - this is the wrong way of establishing orig/term
                    # mroutes (note: ORIG_DEV is a substring test on iif)
                    if self.ORIG_DEV in iif and sip != '0.0.0.0':
                        mroute.is_ulm_orig = 1
                    elif self.TERM_DEV in oil and sip == '0.0.0.0':
                        mroute.is_ulm_term = 1
                    key = '%s, (%s,%s)' % (vrf, sip, dip)
                    self.mroutes[key] = mroute

    def get_mroutes(self):
        """Return the set of mroute keys ('<vrf>, (<sip>,<dip>)')."""
        return set(self.mroutes)

    def get_mroute_iif_and_oil(self, key):
        """Return (iif, oil) for ``key``, or (None, None) if unknown."""
        mroute = self.mroutes.get(key)
        if mroute:
            return mroute.iif, mroute.oil
        return None, None

    def is_ul_mroute(self, key):
        """Return truthy if ``key`` is an underlay-multicast orig/term route."""
        mroute = self.mroutes.get(key)
        if mroute:
            return mroute.is_ulm_orig or mroute.is_ulm_term
        return False

    def mroute_eq(self, key, iif, oil):
        """Compare FRR's entry for ``key`` against a kernel (iif, oil).

        Each detected difference is appended to the route's ``mismatch``
        list. Returns True when ``key`` is not in FRR's table or when the
        route has no recorded mismatches; otherwise returns False.
        """
        mroute = self.mroutes.get(key)
        if not mroute:
            return True

        if mroute.iif != iif:
            mroute.mismatch.append('kernel IIF does not match FRR IIF')
        qoil = set(mroute.oil)
        koil = set(oil)
        # the kernel requires that iif be part of oil for (*,G)
        koil.discard(iif)
        # FRR can also set iif to be in oil for (*,G)
        qoil.discard(mroute.iif)
        # FRR removes the termination device and pimreg from the
        # origination mroute
        if mroute.is_ulm_orig:
            qoil.discard(self.TERM_DEV)
            qoil.discard("pimreg")

        if qoil - koil:
            mroute.mismatch.append('FRR has OIFs not in the kernel')
        if koil - qoil:
            mroute.mismatch.append('kernel has OIFs not in FRR')

        return not mroute.mismatch

    def dump_mroutes(self, filter_mroutes=None):
        """Write mroutes to stdout: all (sorted) or only keys in the filter."""
        if not filter_mroutes:
            for key in sorted(self.mroutes):
                sys.stdout.write('%s\n' % self.mroutes[key])
            return

        for key in filter_mroutes:
            mroute = self.mroutes.get(key)
            if mroute:
                sys.stdout.write('%s\n' % mroute)

    def get_mlag_params(self):
        """Return (anycast VTEP IP, local VTEP IP, peerlink RIF)."""
        return (self.mlag_anycast_ip, self.mlag_local_ip, self.mlag_peerlink_rif)

    def get_source_rpt_and_join_state(self):
        """Collect (S,G) JOIN and SGRpt(P) states from PIM join output.

        Fills ``self.joined_dict`` and ``self.rpt_dict`` as
        vrf -> group -> source -> [interfaces] and returns
        ``(rpt_dict, joined_dict)``. VRF names containing 'Default' are
        stored as 'default'. (*,G) entries are skipped.
        """
        cmd = 'vtysh -c "show ip pim vrf all join json"'
        joined_state = self.load_cmd(cmd, easy_json=True)

        for vrf, vrf_info in joined_state.items():
            if 'Default' in vrf:
                vrf = 'default'

            rpt_vrf = self.rpt_dict[vrf] = {}
            joined_vrf = self.joined_dict[vrf] = {}

            # RPT state would be shown in pim join
            for iface, pim_info in vrf_info.items():
                for grp_addr, grp_info in pim_info.items():
                    # Only dotted-quad keys are treated as groups; other keys
                    # (presumably per-interface fields) are skipped.
                    if not re.search(r'\d+\.\d+\.\d+\.\d+', grp_addr):
                        continue

                    for source_addr, source_info in grp_info.items():
                        if source_addr == '*':
                            continue

                        join_state = source_info.get('channelJoinName')
                        if join_state == 'SGRpt(P)':
                            target = rpt_vrf
                        elif join_state == 'JOIN':
                            target = joined_vrf
                        else:
                            continue
                        target.setdefault(grp_addr, {}) \
                              .setdefault(source_addr, []) \
                              .append(iface)
        return self.rpt_dict, self.joined_dict

    def load_cmd(self, cmd, easy_json=False):
        """Run ``cmd`` (a shell-style string, split with shlex) and return stdout.

        If ``easy_json`` is set, the output is parsed as JSON. Any failure
        to split or run the command, or to parse its JSON, prints a message
        and exits the process with status -1.
        """
        # Pre-bind so the error message can't raise UnboundLocalError when
        # shlex.split() itself fails (e.g. unbalanced quotes).
        args = cmd
        try:
            args = shlex.split(cmd)
            result = subprocess.check_output(args)
        except Exception:
            print('Could not run cmd {0}'.format(args))
            sys.exit(-1)

        if easy_json:
            try:
                result = json.loads(result)
            except ValueError:
                print('Could not load JSON for cmd {0}: {1}'.format(cmd, result))
                sys.exit(-1)

        return result
