#!/usr/bin/env python
# Copyright 2016-2020 Cumulus Networks, Inc
#
# mroute-check: validates asic mroutes against kernel mroutes

import shlex
import subprocess
import argparse
import os
import sys
import json
import cumulus.mroute_check_kern
import cumulus.mroute_check_frr
import cumulus.mroute_check_ul_mcast
import cumulus.platforms
import re


FRR = 'FRR'
KERNEL = 'KERNEL'
HARDWARE = 'HARDWARE'
BLACK_LISTED = 'BLACK_LISTED'


def run(command, ignore_return_code=False):
    """Run a command without a shell and return its output as text.

    command            -- command line as a single string; it is split with
                          shlex.split() and executed directly (no shell)
    ignore_return_code -- when True, a non-zero exit status is not an error
                          and whatever the command printed is returned

    Returns the command's stdout with stderr merged into it.

    Raises subprocess.CalledProcessError when the command exits non-zero and
    ignore_return_code is False. Any other failure (for example the binary
    not being found) propagates unchanged.
    """
    argv = shlex.split(command)
    try:
        # stderr is folded into stdout: some error chatter is expected and
        # callers only ever look at the combined text
        output = subprocess.check_output(argv, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        if not ignore_return_code:
            raise
        # The command failed but still produced output; hand that back
        output = e.output

    # check_output() returns bytes on Python 3, and str() of bytes is the
    # "b'...'" repr rather than the text, which breaks callers doing things
    # like startswith('active'). On Python 2 bytes is str, so this is skipped.
    if isinstance(output, bytes) and not isinstance(output, str):
        return output.decode('utf-8', 'replace')
    return str(output)


def add_to_mroute_dict(mroute_dict, vrf, grp, source, from_dict):
    """Record that the mroute (vrf, grp, source) was seen in one table.

    mroute_dict is a nested dict of the form
    {vrf: {grp: {source: {FRR: bool, KERNEL: bool, HARDWARE: bool,
    BLACK_LISTED: bool}}}}. Missing levels are created on demand; a newly
    created source entry starts with every flag False.

    from_dict is the flag to set to True (one of FRR, KERNEL, HARDWARE or
    BLACK_LISTED).

    The dict is updated in place and also returned.
    """
    sources_for_group = mroute_dict.setdefault(vrf, {}).setdefault(grp, {})

    if source not in sources_for_group:
        # First sighting of this S,G: it is absent everywhere until proven otherwise
        sources_for_group[source] = {
            FRR: False,
            KERNEL: False,
            HARDWARE: False,
            BLACK_LISTED: False,
        }

    sources_for_group[source][from_dict] = True
    return mroute_dict


def build_mroute_map(frr_mroutes, kernel_mroutes, hardware_mroutes, blacklisted_mroutes):
    """Merge the mroute keys of every table into one presence map.

    Each argument is an iterable of mroute keys of the form
    '<vrf>, (<source>,<group>)' (the format produced by build_mroute_key()).

    Returns a dict {vrf: {group: {source: flags}}} where flags maps FRR,
    KERNEL, HARDWARE and BLACK_LISTED to True when the mroute was found in
    the corresponding input, False otherwise.

    Raises ValueError when a key does not have the expected format.
    """
    # Raw string: '\S' and '\(' are regex escapes, not Python string escapes
    key_pattern = re.compile(r"(?P<vrf>\S+), \((?P<source>\S+),(?P<grp>\S+)\)")

    # Processed in the same sequence as before (kernel, FRR, hardware,
    # blacklist) so the map is populated in an unchanged order.
    tables = (
        (KERNEL, kernel_mroutes),
        (FRR, frr_mroutes),
        (HARDWARE, hardware_mroutes),
        (BLACK_LISTED, blacklisted_mroutes),
    )

    tmp_mroute_map = {}
    for table_name, mroute_keys in tables:
        for mroute_key in mroute_keys:
            match = key_pattern.match(mroute_key)
            if match is None:
                # Fail with the offending key instead of an AttributeError on None
                raise ValueError(
                    'Unrecognized {0} mroute key: {1!r}'.format(table_name, mroute_key))

            tmp_mroute_map = add_to_mroute_dict(
                tmp_mroute_map,
                match.group('vrf'),
                match.group('grp'),
                match.group('source'),
                table_name)

    return tmp_mroute_map


def print_mismatch_table(error_dict, basic_table=False):
    """Print the recorded mroute errors to stdout as a column-aligned table.

    error_dict  -- dict of error lists; it must contain the keys
                   'failed_present_check', 'frr_kernel_differ',
                   'kernel_asic_differ' and 'star_g_s_g_oil_mismatch'.
                   Every error is a dict with at least 'src_grp' and
                   'readable'; the detailed tables also read the IIF/OIL
                   fields stored with the error.
    basic_table -- force the generic Key/ERROR listing

    At most one table is printed, chosen in this order:
      1. a Key/ERROR listing of every recorded error, when there are
         presence or kernel/ASIC errors, or when basic_table is True;
      2. otherwise the FRR versus kernel IIF/OIL table, if such errors exist;
      3. otherwise the (*,G) versus (S,G) OIL table, if such errors exist.
    """
    text_types = (str, type(u''))  # also covers unicode on Python 2

    def as_cell(value):
        # Render one cell as text: sequences are comma-joined, anything that
        # is not already text (e.g. a None IIF) is stringified, and an empty
        # result is shown as 'None' so the column is never blank.
        if isinstance(value, (list, tuple)):
            value = ','.join(
                item if isinstance(item, text_types) else str(item) for item in value)
        elif not isinstance(value, text_types):
            value = str(value)
        return value if value != '' else 'None'

    class Table(object):
        """Rows of text cells plus the running width of each column."""

        def __init__(self, table_keys):
            # The header is simply the first row
            self.rows = [list(table_keys)]
            self.table_column_size = [len(key) + 2 for key in table_keys]
            self.total_keys = len(table_keys)

        def add_to_table(self, table_line):
            cells = [as_cell(value) for value in table_line]
            self.rows.append(cells)

            # Widen any column this row does not fit into (2 chars of padding)
            for i in range(self.total_keys):
                needed = len(cells[i]) + 2
                if needed > self.table_column_size[i]:
                    self.table_column_size[i] = needed

        def print_table(self):
            for row in self.rows:
                print(''.join(
                    '{str:{width}s} '.format(width=width, str=cell)
                    for width, cell in zip(self.table_column_size, row)))

    # Only two error types carry enough detail for a dedicated table (FRR/kernel
    # mismatch and the RPT-adjusted OIL mismatch). Everything else, or a
    # mixture, is shown as a flat list of failures.
    if error_dict['failed_present_check'] or error_dict['kernel_asic_differ'] or basic_table:
        table = Table(['Key', 'ERROR'])

        for entry_errors in error_dict.values():
            for entry_error in entry_errors:
                table.add_to_table([entry_error['src_grp'], entry_error['readable']])

        table.print_table()

    elif error_dict['frr_kernel_differ']:
        table = Table(['Key', 'FRR IIF', 'FRR OIL', 'KERN IIF', 'KERN OIL'])

        for entry_error in error_dict['frr_kernel_differ']:
            table.add_to_table([
                entry_error['src_grp'],
                entry_error['frr_iif'],
                entry_error['frr_oil'],
                entry_error['kern_iif'],
                entry_error['kern_oil'],
            ])

        table.print_table()

    elif error_dict['star_g_s_g_oil_mismatch']:
        table = Table(['Key', 'FRR IIF', 'FRR OIL', 'CORRECT OIL'])

        for entry_error in error_dict['star_g_s_g_oil_mismatch']:
            table.add_to_table([
                entry_error['src_grp'],
                entry_error['frr_iif'],
                entry_error['frr_oil'],
                entry_error['rpt_joined_accounted_oil'],
            ])

        table.print_table()

def build_mroute_key(vrf, sip, dip):
    """Return the lookup key for an mroute: '<vrf>, (<sip>,<dip>)'."""
    key_template = '{0}, ({1},{2})'
    return key_template.format(vrf, sip, dip)

def run_per_table_checks(vrf, vrf_info, lo_addrs, mlag_params, errors):
    """Check every mroute of one VRF and record any inconsistency found.

    Besides its arguments this function reads the module-level objects set up
    by the __main__ section: frr, kernel, asic, ul_mcast, args and, unless
    --skip-pim-state was given, frr_rpt and frr_source_join.

    vrf         -- VRF name, used to build the mroute lookup keys
    vrf_info    -- {group: {source: {FRR/KERNEL/HARDWARE/BLACK_LISTED: bool}}}
                   for this VRF (one value of the map from build_mroute_map())
    lo_addrs    -- container of addresses tested with "in"; an S,G whose source
                   is one of them is not expected to inherit "ipmr-lo"
    mlag_params -- indexable; [0] is used as the MLAG anycast IP, [1] as the
                   MLAG local IP and [2] as the peerlink RIF name
    errors      -- dict of lists keyed by 'failed_present_check',
                   'frr_kernel_differ', 'kernel_asic_differ' and
                   'star_g_s_g_oil_mismatch'; error dicts are appended in place

    Returns 1 if at least one error was recorded for this VRF, otherwise 0.
    """
    rc = 0
    for grp_addr, grp_info in vrf_info.items():
        # Look up the *,G entry for this group (source 0.0.0.0). If there is no usable *,G OIL,
        # the S,G OIL check further down is skipped.
        src_grp_key = build_mroute_key(vrf, '0.0.0.0', grp_addr)
        _, frr_star_grp_oil = frr.get_mroute_iif_and_oil(src_grp_key)
        # Remove pimreg as this is not a 'real' OIF
        try:
            frr_star_grp_oil = [tmp_iface for tmp_iface in frr_star_grp_oil if 'pimreg' not in tmp_iface]
        except Exception:
            # Best effort: if the OIL could not be filtered it is left exactly as returned
            pass

        mlag_anycast_ip = mlag_params[0]
        mlag_local_ip = mlag_params[1]
        mlag_peerlink_rif = mlag_params[2]

        # Need to iterate through all source addrs in this group
        for source_addr, source_info in grp_info.items():
            build_error_string = 'Present in {0} but missing from {1}'
            present_in = []
            not_present = []
            ul_mroute = 0
            src_grp_key = build_mroute_key(vrf, source_addr, grp_addr)

            # Entries FRR flags via is_ul_mroute() are counted as present in every table below
            if frr.is_ul_mroute(src_grp_key):
                ul_mroute = 1

            if source_info[BLACK_LISTED]:
                # Not actually stored into hardware, skip, used for comparing S,G entries OIL anyways
                continue

            # HARDWARE is only part of the presence check when an ASIC object exists
            for state_dict in ([FRR, KERNEL, HARDWARE] if asic else [FRR, KERNEL]):
                if source_info[state_dict] or ul_mroute:
                    present_in.append(state_dict)
                else:
                    not_present.append(state_dict)

            # If not_present is empty, the entry exists in every table that is being compared
            if not_present:
                rc = 1
                build_error_string = build_error_string.format(present_in, not_present)

                tmp_error_dict = {}  # Error record consumed by the JSON / table output
                tmp_error_dict['readable'] = build_error_string
                tmp_error_dict['src_grp'] = src_grp_key
                errors['failed_present_check'].append(tmp_error_dict)
                # No point comparing IIF/OIL of an entry that is missing somewhere
                continue

            # If they are at least all present, start to verify mroute status. First, just make sure all entries
            # IIF and OIL do agree
            kern_iif, kern_oil = kernel.get_mroute_iif_and_oil(src_grp_key)
            if not frr.mroute_eq(src_grp_key, kern_iif, kern_oil):
                rc = 1
                build_error_string = 'Mismatch of FRR/KERNEL IIF/OIL'

                tmp_error_dict = {}  # Used to print a table of mismatched IIF OIL
                tmp_error_dict['readable'] = build_error_string
                tmp_error_dict['src_grp'] = src_grp_key
                tmp_error_dict['kern_iif'] = kern_iif
                tmp_error_dict['kern_oil'] = kern_oil
                tmp_error_dict['frr_iif'], tmp_error_dict['frr_oil'] = frr.get_mroute_iif_and_oil(src_grp_key)
                errors['frr_kernel_differ'].append(tmp_error_dict)
                # Note: no 'continue' here, the remaining checks still run for this entry

            # If there is an ASIC object to compare
            if asic:
                if not asic.mroute_eq(src_grp_key, kern_iif, kern_oil, ul_mcast):
                    rc = 1
                    build_error_string = 'Mismatch of KERNEL/ASIC IIF/OIL'

                    tmp_error_dict = {}  # Error record consumed by the JSON / table output
                    tmp_error_dict['readable'] = build_error_string
                    tmp_error_dict['src_grp'] = src_grp_key
                    errors['kernel_asic_differ'].append(tmp_error_dict)

            # If there is a *,G group OIL, make sure the S,G entry has all expected OIL
            if frr_star_grp_oil and not args.skip_pim_state:
                frr_iif, frr_oil = frr.get_mroute_iif_and_oil(src_grp_key)
                # pimreg is not compared (it was filtered from the *,G OIL above as well)
                try:
                    frr_oil.remove("pimreg")
                except ValueError:
                    pass
                # Work on a copy: the *,G OIL is reused for every source of this group
                frr_star_g_oil_copy = list(frr_star_grp_oil)

                try:
                    frr_star_g_oil_copy.remove(frr_iif)
                except ValueError:
                    # S,G IIF is already not in star_g OIL, okay
                    pass

                # frr mutes inherited ipmr-lo if SIP is local
                if source_addr in lo_addrs:
                    try:
                        frr_star_g_oil_copy.remove("ipmr-lo")
                    except ValueError:
                        pass

                # frr statically adds peerlink_rif into the OIL of originating mroutes.
                # XXX: this check currently ignores the peerlink RIF instead of enforcing it;
                # to be fixed up in a subsequent commit
                if source_addr == mlag_anycast_ip or\
                    source_addr == mlag_local_ip:
                    try:
                        frr_oil.remove(mlag_peerlink_rif)
                    except ValueError:
                        pass

                # Get any RPT oils we may have for this
                rpt_oils = frr_rpt.get(vrf, {})
                rpt_oils = rpt_oils.get(grp_addr, {})
                rpt_oils = rpt_oils.get(source_addr, [])

                # Get any JOINED oils we may have for this
                join_oils = frr_source_join.get(vrf, {})
                join_oils = join_oils.get(grp_addr, {})
                join_oils = join_oils.get(source_addr, [])

                # Expected S,G OIL: (adjusted *,G OIL plus the JOINED oils) minus the RPT oils
                rpt_accounted_oil = (set(frr_star_g_oil_copy) | set(join_oils)) - set(rpt_oils)

                # Compared as sets, so the S,G OIL must match the expected OIL exactly
                if set(frr_oil) != rpt_accounted_oil:
                    rc = 1
                    build_error_string = 'OIL {0} != (*, {1}) with correct OIL: {2} '.format(
                        frr_oil,
                        grp_addr,
                        list(rpt_accounted_oil))

                    tmp_error_dict = {}  # Used to print a table of mismatched IIF OIL
                    tmp_error_dict['readable'] = build_error_string
                    tmp_error_dict['src_grp'] = src_grp_key
                    tmp_error_dict['rpt_joined_accounted_oil'] = list(rpt_accounted_oil)
                    tmp_error_dict['frr_iif'], tmp_error_dict['frr_oil'] = frr.get_mroute_iif_and_oil(src_grp_key)
                    errors['star_g_s_g_oil_mismatch'].append(tmp_error_dict)
    return rc

if __name__ == '__main__':
    # Script entry point: gather the mroutes known to FRR, the kernel and (when
    # the platform has a supported switch chip) the ASIC, cross-check them per
    # VRF and exit with 0 when everything agrees or 1 when any check failed.
    #
    # The names args, asic, ul_mcast, kernel, frr, frr_rpt and frr_source_join
    # are module globals on purpose: run_per_table_checks() reads them.

    # Ensure that the user is root/has sudo
    if os.getuid():
        sys.stderr.write("need to run as root\n")
        sys.exit(1)

    # Check if frr is active
    frr_up = run('systemctl is-active frr', ignore_return_code=True)
    if not frr_up.startswith('active'):
        sys.stderr.write("FRR must be running\n")
        sys.exit(1)

    # command line args
    parser = argparse.ArgumentParser(
        description="Validate FRR/KERNEL/ASIC mroutes and ensure they make sense with regards to routing",
    )

    # All options live in one mutually exclusive group: at most one may be given
    option = parser.add_mutually_exclusive_group(required=False)
    option.add_argument('-j', '--json', default=False, action='store_true', help='JSON output')
    option.add_argument('-v', '--verbose', default=False, action='store_true', help='Verbose output')
    option.add_argument('-V', '--very-verbose', default=False, action='store_true', help='Very verbose output')
    option.add_argument('-s', '--skip-pim-state', default=False, action='store_true', help='Skip protocol level state check')
    option.add_argument('-t', '--use-mem-table', default=False, action='store_true', help=argparse.SUPPRESS)
    option.add_argument('-d', '--debug', default=False, action='store_true', help='Enable debug logs')
    args = parser.parse_args()

    # Very verbose output is the verbose output plus a dump of every table
    if args.very_verbose:
        args.verbose = True

    # Build our platform object
    platform_object = cumulus.platforms.probe()
    # create the ASIC object (stays None when there is no supported switch chip)
    asic = None
    ul_mcast = None
    if platform_object.switch:
        chip = platform_object.switch.chip
        # The chip specific modules are imported lazily, only for the detected chip
        if chip.sw_base == 'bcm':
            if chip.dnx:
                import cumulus.mroute_check_bcmdnx
                asic = cumulus.mroute_check_bcmdnx.Asic(args.debug)
            else:
                import cumulus.mroute_check_bcm
                asic = cumulus.mroute_check_bcm.Asic(args.use_mem_table, args.debug)
                # create the underlay multicast routing object
                ul_mcast = cumulus.mroute_check_ul_mcast.Ul_mcast(args.debug)
        elif chip.sw_base == 'mlx':
            import cumulus.mroute_check_mlx
            asic = cumulus.mroute_check_mlx.Asic()

    # create the kernel object
    kernel = cumulus.mroute_check_kern.Kernel(asic is not None)

    # create the routing object
    frr = cumulus.mroute_check_frr.Frr()

    # Get all of the mroute information
    frr_mroutes = frr.get_mroutes()
    kern_mroutes = kernel.get_mroutes()
    kern_blist_mroutes = kernel.get_blist_mroutes()
    # Always define the PIM state maps so they exist even when the protocol
    # level check is skipped
    frr_rpt = {}
    frr_source_join = {}
    if not args.skip_pim_state:
        frr_rpt, frr_source_join = frr.get_source_rpt_and_join_state()
    if asic:
        hw_mroutes = asic.get_mroutes()
    else:
        hw_mroutes = []

    # Build our base mroute map that will be used to compare all mroutes
    all_mroute_map = build_mroute_map(frr_mroutes, kern_mroutes, hw_mroutes, kern_blist_mroutes)

    # One list per error category; filled in by run_per_table_checks()
    errors = {
        'failed_present_check': [],
        'frr_kernel_differ': [],
        'kernel_asic_differ': [],
        'star_g_s_g_oil_mismatch': []
    }

    # Loop through all mroutes, and ensure each has the correct criteria
    rc = 0
    lo_addrs = kernel.get_lo_addrs()
    # Keep only the address part of each "address/prefixlen" entry
    lo_addrs = [net.split("/")[0] for net in lo_addrs]
    mlag_params = frr.get_mlag_params()
    for vrf, vrf_info in all_mroute_map.items():
        if run_per_table_checks(vrf, vrf_info, lo_addrs, mlag_params, errors):
            rc = 1

    if args.json:
        # print() call instead of the Python 2 only print statement; with a
        # single argument it behaves the same on Python 2 and Python 3
        print(json.dumps(errors, indent=4))
    elif args.verbose:
        print_mismatch_table(errors)

        if args.very_verbose:
            print('\n')
            # Print out ALL of our dictionaries as well
            print('All IP multicast information')
            print('%s' % frr)
            print('%s' % kernel)
            if asic:
                print('%s' % asic)
    else:
        # Default output: only say something when a check failed
        if rc:
            print_mismatch_table(errors, basic_table=True)

    sys.exit(rc)
