#!/usr/bin/env python
# Copyright 2019 Cumulus Networks, Inc.
#
# kernel_nlhelpers -- kernel objects abstractions and helpers
#
# Authors:
#       Andy Roulin, aroulin@cumulusnetworks.com

from sys import stderr
from nlmanager import NetlinkManager, Link, RTM_GETNEIGH, socket, RTM_GETLINK
from nlpacket import Route, Neighbor


class Kernel:
    """Snapshot of kernel networking objects gathered over netlink.

    Creating an instance dumps links, routes, neighbors, bridge links and
    bridge FDB entries through a NetlinkManager (see __init__) and wraps
    them in the helper classes below. All dumps are taken once, in __init__.
    """

    class Interface:
        """View of one link, built from an nlmanager Link object."""

        @property
        def ifname(self):
            """Get interface name (IFLA_IFNAME)
            Return type: str """
            return self.__ifname

        @property
        def ifindex(self):
            """Get interface index (ifindex)
            Return type: int """
            return self.__ifindex

        @property
        def mac(self):
            """Get interface's L2 MAC address (IFLA_ADDRESS)
            Return type: str """
            return self.__mac

        @property
        def is_admin_up(self):
            """Get interface admin state (IFF_UP/DOWN)
            Return type: bool """
            return self.__admin_up

        @property
        def is_up(self):
            """Get interface state (UP/DOWN), derived from IFLA_OPERSTATE
            Return type: bool """
            return self.__is_up

        @property
        def kind(self):
            """Get interface link kind (IFLA_INFO_KIND)
            Return type: str ("" if no kind) """
            return self.__kind

        @property
        def mtu(self):
            """Get interface MTU value (IFLA_MTU)
            Return type: int """
            return self.__mtu

        @property
        def bridge(self):
            """Get the bridge the interface is enslaved to
            Return type: Kernel.Interface (the bridge)
                         None if not enslaved to a bridge, or if the master
                         is not part of the kernel's interface table """
            return self.__master_of_kind("bridge")

        @property
        def vrf(self):
            """Get the VRF the interface belongs to
            Return type: Kernel.Interface (the VRF)
                         None if not associated to a VRF, or if the master
                         is not part of the kernel's interface table """
            return self.__master_of_kind("vrf")

        @property
        def vids(self):
            """Get VLANs the interface belongs to (VLAN ids)
            Return type: list(int)
                         None until the owning Kernel has assigned them """
            return self.__vids

        # Kernel.__update_vids() assigns this attribute right after the link
        # dump; without a setter that assignment fails on Python 3.
        @vids.setter
        def vids(self, vids):
            self.__vids = vids

        @property
        def ifobject(self):
            """Get the underlying nlmanager link object
            Return type: nlmanager.nlpacket.Link """
            return self.__ifobject

        def __master_of_kind(self, kind):
            """Return the master Interface when the slave kind equals `kind`.

            Returns None when the interface has another (or no) slave kind,
            or when the master ifindex matches no known interface.
            """
            if self.__slave_kind != kind:
                return None
            master_ifindex = self.__ifobject.attributes[Link.IFLA_MASTER].value
            return self.__kernel._find_iface(master_ifindex)[1]

        def __init__(self, ifobject, kernel):
            """Cache the attributes of one link.

            ifobject: nlmanager Link object describing the interface
            kernel:   the owning Kernel, used to resolve master interfaces
            """
            attributes = ifobject.attributes

            self.__kernel = kernel
            self.__ifobject = ifobject
            self.__ifname = attributes[Link.IFLA_IFNAME].get_pretty_value(str)
            self.__ifindex = ifobject.ifindex
            self.__mac = attributes[Link.IFLA_ADDRESS].get_pretty_value(str)
            self.__admin_up = ifobject.is_up()
            # NOTE(review): every operstate value other than 0 is reported as
            # "up" here -- confirm that this is the intended mapping.
            self.__is_up = attributes[Link.IFLA_OPERSTATE].value != 0

            self.__linkinfo = None
            if Link.IFLA_LINKINFO in attributes:
                self.__linkinfo = attributes[Link.IFLA_LINKINFO].value

            self.__kind = ""
            self.__slave_kind = None
            if self.__linkinfo is not None:
                if Link.IFLA_INFO_KIND in self.__linkinfo:
                    self.__kind = self.__linkinfo.get(Link.IFLA_INFO_KIND)
                if Link.IFLA_INFO_SLAVE_KIND in self.__linkinfo:
                    self.__slave_kind = self.__linkinfo.get(Link.IFLA_INFO_SLAVE_KIND)

            self.__mtu = ifobject.get_attribute_value(Link.IFLA_MTU)
            # Filled in later by the owning Kernel through the vids setter.
            self.__vids = None

    class Neighbor:
        """View of one neighbor entry returned by an RTM_GETNEIGH dump."""

        @property
        def ip(self):
            """
            :return: the Neighbor's IP address (NDA_DST) as string
            """
            return "%s" % (self.__neighobj.get_attribute_value(Neighbor.NDA_DST))

        @property
        def mac(self):
            """
            :return: the Neighbor's mac address (LLADDR) as string
            e.g., "3333.FF46.1699"
            """
            return self.__neighobj.get_attribute_value(Neighbor.NDA_LLADDR)

        @property
        def svi(self):
            """
            :return: the name of the interface the Neighbor is attached to,
            or None if its ifindex matches no known interface
            e.g., "eth0"
            """
            return self.__kernel._find_iface(self.__neighobj.ifindex)[0]

        @property
        def state(self):
            """
            :return: the Neighbor's current state as string
            e.g., "NUD_STALE"
            """
            return self.__neighobj.get_state_string(self.__neighobj.state)

        @property
        def vlan(self):
            """
            :return: the value of the Neighbor's NDA_VLAN attribute
            """
            return self.__neighobj.get_attribute_value(Neighbor.NDA_VLAN)

        @property
        def master(self):
            """
            :return: the value of the Neighbor's NDA_MASTER attribute
            """
            return self.__neighobj.get_attribute_value(Neighbor.NDA_MASTER)

        @property
        def vni(self):
            """
            :return: the value of the Neighbor's NDA_VNI attribute
            """
            return self.__neighobj.get_attribute_value(Neighbor.NDA_VNI)

        @property
        def iface(self):
            """
            :return: the ifindex of the interface the Neighbor is attached to
            """
            return self.__neighobj.ifindex

        def __init__(self, neighobject, kernel):
            """
            :param neighobject: the underlying netlink neighbor object
            :param kernel: the owning Kernel, used to resolve interface names
            """
            self.__neighobj = neighobject
            self.__kernel = kernel

    class Route:
        """View of one route returned by the NetlinkManager routes dump."""

        @property
        def dst_ip(self):
            """
            :return: the route's destination IP and prefix as string, or "" if none
            IPv4 example: "10.0.0.0/22"
            IPv6 example: "fe80::3e2c:30ff:fe4b:b00/128"
            """
            if Route.RTA_DST not in self.__rtobj.attributes:
                return ""
            # NOTE(review): the prefix length is read from src_len although the
            # address is RTA_DST -- confirm against nlpacket's field naming.
            return "%s/%d" % (self.__rtobj.get_attribute_value(Route.RTA_DST), self.__rtobj.src_len)

        @property
        def gateway(self):
            """
            :return: the gateway IP as string or None
            IPv4 example: "10.0.0.1"
            """
            if Route.RTA_GATEWAY not in self.__rtobj.attributes:
                return None
            return str(self.__rtobj.attributes[Route.RTA_GATEWAY].value)

        @property
        def output_iface(self):
            """
            :return: the route's output interface as string, or None if the
            route has no RTA_OIF attribute or it matches no known interface
            e.g., "eth0"
            """
            attributes = self.__rtobj.attributes
            if Route.RTA_OIF not in attributes:
                return None
            return self.__kernel._find_iface(attributes[Route.RTA_OIF].value)[0]

        @property
        def vrf(self):
            """
            :return: the vrf associated with the route as string or None
            e.g., "mgmt" or "vrf1"
            """
            tbl_id = self.__rtobj.attributes[Route.RTA_TABLE].value
            for vrfname, vrf in self.__kernel.vrfs.items():
                if tbl_id == vrf.tbl_id:
                    return vrfname
            return None

        @property
        def get_nhs(self):
            """
            :return: whatever the underlying route object's get_nexthops() returns
            """
            return self.__rtobj.get_nexthops()

        def __init__(self, rtobject, kernel):
            """
            :param rtobject: the underlying netlink route object
            :param kernel: the owning Kernel, used to resolve interfaces and vrfs
            """
            self.__rtobj = rtobject
            self.__kernel = kernel

    class Vrf:
        """A VRF: its Interface plus the routing table id from its link data."""

        def __init__(self, iface):
            """
            :param iface: the Interface of kind "vrf" to wrap
            """
            assert iface.kind == "vrf"
            self.__iface = iface
            linkinfo = iface.ifobject.attributes[Link.IFLA_LINKINFO].value
            linkdata = linkinfo.get(Link.IFLA_INFO_DATA)
            self.__tbl_id = linkdata.get(Link.IFLA_VRF_TABLE)

        @property
        def ifname(self):
            """
            :return: vrf's interface name
            """
            return self.__iface.ifname

        @property
        def tbl_id(self):
            """
            :return: vrf's table ID
            """
            return self.__tbl_id

        @property
        def iface(self):
            """
            :return: the vrf's associated Interface
            """
            return self.__iface

    def __init__(self):
        """Dump the kernel state over netlink and build the lookup tables."""
        self.nlmanager = NetlinkManager()

        # get all links in {ifname: Interface} dictionary
        self.__ifaces = {self.__get_ifname(link_object): Kernel.Interface(link_object, self)
                         for link_object in self.nlmanager.link_dump()}
        for iface in self.__ifaces.values():
            self.__update_vids(iface)

        # get all vrfs in {vrfname: Vrf} dictionary
        self.__vrfs = {ifname: Kernel.Vrf(iface)
                       for (ifname, iface) in self.__ifaces.items() if iface.kind == "vrf"}

        # get all routes in {dst_ip: List(Route)} dictionary; several routes
        # can share one destination, hence the list values
        self.__routes = {}
        for rtobject in self.nlmanager.routes_dump():
            key = self.__get_dst_ip(rtobject)
            self.__routes.setdefault(key, []).append(Kernel.Route(rtobject, self))

        # get all neighbors in {ip: List(Neighbor)} dictionary
        self.__neighbors = {}
        for neighobject in self.nlmanager.request_dump(RTM_GETNEIGH, socket.AF_UNSPEC, False):
            key = self.__get_ip(neighobject)
            self.__neighbors.setdefault(key, []).append(Kernel.Neighbor(neighobject, self))

        # raw AF_BRIDGE dumps, exposed as returned by the NetlinkManager
        self.bridges = self.nlmanager.request_dump(RTM_GETLINK, socket.AF_BRIDGE, False)
        self.fdb_entries = self.nlmanager.request_dump(RTM_GETNEIGH, socket.AF_BRIDGE, False)

    @property
    def ifaces(self):
        """
        :return: dictionary {ifname: Interface}
        """
        return self.__ifaces

    @property
    def routes(self):
        """
        :return: dictionary {dst_ip: List(Route)}
        """
        return self.__routes

    @property
    def neighbors(self):
        """
        :return: dictionary {ip: List(Neighbor)}
        """
        return self.__neighbors

    @property
    def vrfs(self):
        """
        :return: dictionary {vrf_ifname: Vrf}
        """
        return self.__vrfs

    def _find_iface(self, ifindex):
        """
        Look up an interface by its ifindex
        :param ifindex: the interface index to search for
        :return: tuple (ifname, Interface) of the first match,
                 or (None, None) if no interface has that ifindex
        """
        for ifname, iface in self.ifaces.items():
            if iface.ifindex == ifindex:
                return ifname, iface
        return None, None

    def __get_ifname(self, ifobject):
        # Name of a raw link object, used as key of the ifaces dictionary.
        return ifobject.attributes[Link.IFLA_IFNAME].get_pretty_value(str)

    def __get_dst_ip(self, rtobject):
        # "dst/prefixlen" of a raw route object ("" when it has no RTA_DST),
        # used as key of the routes dictionary.
        if Route.RTA_DST not in rtobject.attributes:
            return ""
        return "%s/%d" % (rtobject.get_attribute_value(Route.RTA_DST), rtobject.src_len)

    def __get_ip(self, neighobject):
        # NDA_DST of a raw neighbor object as string, used as key of the
        # neighbors dictionary.
        return "%s" % (neighobject.get_attribute_value(Neighbor.NDA_DST))

    def __vlan_get(self, filter_ifindex=None, filter_vlanid=None, compress_vlans=None):
        # Thin pass-through to NetlinkManager.vlan_get().
        return self.nlmanager.vlan_get(filter_ifindex, filter_vlanid, compress_vlans)

    def __update_vids(self, iface):
        # Store the interface's L2 VLAN ids on the Interface object.
        iface.vids = self.get_l2_vlans(iface.ifname)

    def get_pvid(self, ifname):
        """Get the PVID of an interface, i.e. the first VLAN id whose flags
        carry BRIDGE_VLAN_INFO_PVID
        Return type: int
                     None if no VLAN of the interface carries the PVID flag
        Raises KeyError if ifname is unknown or missing from the VLAN dump """
        vlan_info = self.__vlan_get([self.ifaces[ifname].ifindex])
        for vlan_id, vlan_flag in vlan_info[ifname]["vlans"]:
            if vlan_flag & Link.BRIDGE_VLAN_INFO_PVID:
                return vlan_id
        return None

    def get_l2_vlans(self, ifname):
        """Get VLANs the interface belongs to (VLAN ids)
        Ranges reported as a RANGE_BEGIN/RANGE_END pair are expanded into
        every VLAN id they contain.
        Return type: list(int) """
        assert ifname in self.ifaces, ifname + ": unknown interface"

        iface_vlans = self.__vlan_get([self.ifaces[ifname].ifindex])
        range_begin_vlan_id = None
        vids = []

        for (iface_name, vlan_dict) in sorted(iface_vlans.items()):
            assert ifname == iface_name
            for (vlan_id, vlan_flag) in sorted(vlan_dict['vlans']):
                if vlan_flag & Link.BRIDGE_VLAN_INFO_RANGE_BEGIN:
                    range_begin_vlan_id = vlan_id

                elif vlan_flag & Link.BRIDGE_VLAN_INFO_RANGE_END:
                    # "is None" rather than truthiness so that a range
                    # starting at VLAN id 0 is not mistaken for a missing one
                    if range_begin_vlan_id is None:
                        stderr.write("Netlink Error: BRIDGE_VLAN_INFO_RANGE_END is %d but we never saw a "
                                     "BRIDGE_VLAN_INFO_RANGE_BEGIN\n " % vlan_id)
                        range_begin_vlan_id = vlan_id

                    # every member of the inclusive range, not just its end
                    vids.extend(range(range_begin_vlan_id, vlan_id + 1))
                    range_begin_vlan_id = None

                else:
                    vids.append(vlan_id)

        return vids

    def get_l3_vlan_dev(self, vid):
        """Get L3 VLAN device associated with given VLAN id
        Return type: Kernel.Interface
                     None if there is no such interface """
        for iface in self.ifaces.values():
            if iface.kind != "vlan":
                continue
            linkinfo = iface.ifobject.attributes[Link.IFLA_LINKINFO]
            if Link.IFLA_INFO_DATA in linkinfo.value and \
                    linkinfo.value.get(Link.IFLA_INFO_DATA)[Link.IFLA_VLAN_ID] == vid:
                return iface
        return None

    def get_svi_vid(self, svi_iface):
        """
        Get the VLAN ID of a SVI
        :param svi_iface: svi Interface object
        :return: vlan id, or None if the link carries no IFLA_INFO_DATA
        :raises KeyError: if the interface is not of kind "vlan"
        """
        if svi_iface.kind != "vlan":
            raise KeyError(svi_iface.ifname + " is not a vlan interface")

        linkinfo = svi_iface.ifobject.attributes[Link.IFLA_LINKINFO]
        if Link.IFLA_INFO_DATA not in linkinfo.value:
            return None

        return linkinfo.value.get(Link.IFLA_INFO_DATA)[Link.IFLA_VLAN_ID]

    def get_vxlan_vni(self, vxlan_iface):
        """
        Get the vni from a vxlan Interface
        :param vxlan_iface: vxlan Interface object
        :return: vni, or None if the link carries no IFLA_INFO_DATA
        :raises KeyError: if the interface is not of kind "vxlan"
        """
        if vxlan_iface.kind != "vxlan":
            raise KeyError(vxlan_iface.ifname + " is not a vxlan interface")

        linkinfo = vxlan_iface.ifobject.attributes[Link.IFLA_LINKINFO]
        if Link.IFLA_INFO_DATA not in linkinfo.value:
            return None

        return linkinfo.value.get(Link.IFLA_INFO_DATA)[Link.IFLA_VXLAN_ID]

    def get_vni_vxlan(self, vni):
        """
        Get the vxlan Interface from a vni
        :param vni: the vxlan's vni
        :return: vxlan Interface, or None if no vxlan interface has this vni
        """
        for iface in self.ifaces.values():
            if iface.kind != "vxlan":
                continue
            linkinfo = iface.ifobject.attributes[Link.IFLA_LINKINFO]
            if Link.IFLA_INFO_DATA in linkinfo.value and \
                    linkinfo.value.get(Link.IFLA_INFO_DATA)[Link.IFLA_VXLAN_ID] == vni:
                return iface
        return None
