#!/usr/bin/env python
# Copyright 2016-2020 Cumulus Networks, Inc.
#
# mroute-check asic backend helper

import sys

try:
    from cumulus.mlx import mlx_open_connection
    from cumulus.mlx import mlx_close_connection
    from cumulus.mlx import mlx_get_mc_route
    from cumulus import mroute_check_util
except ImportError:
    raise NotImplementedError

class Hmroute:
    """A single multicast route entry read back from the ASIC.

    Instances are built by Asic._get_mlx_mroute; every constructor argument
    is stored unchanged as an attribute of the same name:

      sip   -- source address string (left half of the SDK '<src>/<grp>' key)
      dip   -- group address string (right half of that key)
      iif   -- value of the entry's 'iif' field; must be an integer for
               __str__, which formats it with %d
      erifs -- value of the entry's 'egress_rifs' field
      vrf   -- VRF name resolved from the entry's 'vrid'
    """

    def __init__(self, sip, dip, iif, erifs, vrf):
        (self.sip, self.dip, self.iif, self.erifs, self.vrf) = (sip, dip, iif, erifs, vrf)

    def __str__(self):
        """Render as '<vrf>, (<sip>, <dip>), Iif: <iif>'."""
        template = '%s, (%s, %s), Iif: %d'
        fields = (self.vrf, self.sip, self.dip, self.iif)
        return template % fields

class Asic:
    """Snapshot of the multicast routes programmed in the ASIC.

    The route table is read once, during construction, through the
    cumulus.mlx SDK helpers. All other methods only inspect that cached
    snapshot (self.mroutes).
    """

    def __init__(self):
        self.desc = "asic"
        self.vrf_db = mroute_check_util.VrfDb()
        # Maps '<vrf>, (<src>,<grp>)' -> Hmroute; filled by _get_mlx_mroute.
        self.mroutes = {}
        mlx_open_connection()
        try:
            self._get_mlx_mroute()
        finally:
            # Always release the SDK connection, even if reading or parsing
            # the route table raises.
            mlx_close_connection()

    def _get_mlx_mroute(self):
        """Populate self.mroutes from the SDK multicast route table.

        mcroute_table format (as printed):
        0.0.0.0/239.1.1.3 {'egress_rifs': [36L, 4L], 'iif': 5L, 'vrid': 0, 'egress_rif_cnt': 2}

        Each entry's 'vrid' is translated to a VRF name through self.vrf_db,
        and the route is stored under the key '<vrf>, (<src>,<grp>)'.
        """
        mcroute_table = mlx_get_mc_route()
        for route_key, entry in mcroute_table.items():
            # NOTE(review): the table key is expected to carry '<src>/<grp>'
            # as its second whitespace-separated field; confirm against the
            # cumulus.mlx implementation.
            src, grp = route_key.split()[1].split('/')
            vrf = self.vrf_db.vrf_id2name(int(entry.get('vrid')))
            iif = entry.get('iif')
            erifs = entry.get('egress_rifs')
            # Key format (no space after the comma between src and grp) is
            # presumably shared with the other mroute-check backends for set
            # comparison -- confirm before changing it.
            self.mroutes['%s, (%s,%s)' % (vrf, src, grp)] = Hmroute(
                src, grp, iif, erifs, vrf)

    def __str__(self):
        """Return a header line followed by every cached route, sorted by key."""
        body = ''.join(self._dump_one_mroute(self.mroutes[key])
                       for key in sorted(self.mroutes))
        return 'Asic ip mroutes:\n' + body

    def _dump_one_mroute(self, mroute):
        """Format one route as '<route>, Oifs: <erifs>' terminated by a newline."""
        return '%s, Oifs: %s\n' % (mroute, mroute.erifs)

    def mroute_eq(self, key, iif, oil, ul_mcast):
        """Compare a cached route against an expected iif / outgoing list.

        Not implemented: all arguments are ignored and True is always
        returned, so every ASIC route is treated as matching.
        """
        # TODO: need to get linux interface name from RIF ID to compare
        return True

    def get_mroutes(self):
        """Return the cached route keys as a set."""
        return set(self.mroutes)

    def dump_mroutes(self, filter_mroutes=None):
        """Write cached routes to stdout.

        If filter_mroutes is None or empty, every cached route is written in
        sorted key order. Otherwise only the keys from filter_mroutes that
        exist in the cache are written, in the filter's iteration order;
        unknown keys are silently skipped.

        Each route's text already ends in '\\n' and an extra '\\n' is
        appended, so routes are separated by a blank line.
        """
        if not filter_mroutes:
            for key in sorted(self.mroutes):
                sys.stdout.write('%s\n' % self._dump_one_mroute(self.mroutes[key]))
            return

        for key in filter_mroutes:
            mroute = self.mroutes.get(key)
            if mroute is not None:
                sys.stdout.write('%s\n' % self._dump_one_mroute(mroute))
