#!/usr/bin/python
# Copyright 2012,2013,2014,2015,2016,2017,2018,2019,2020 Cumulus Networks, Inc.
#
# ecmpcalc
#
# Utility to calculate the egress port for a given frame and ECMP hash
# configuration.
#
# Will query hardware for hash configuration and ECMP groups if executing on a
# Cumulus switch.
#
# See usage (ecmpcalc -h) for instructions
#
import ipaddr
import argparse
import re
import os
import sys
import string
import struct
import socket
import dpkt
import cumulus.porttab
import cumulus.platforms
from cumulus.ecmpcalc_util import modport2interface, asic2linux, asictrunkmbr2linux, linux2modport

########################################################################
# Exceptions
########################################################################
class MissingFrameInformation(RuntimeError):
    '''
    Raised when a frame lacks fields the hash needs; the exception argument
    is the set of descriptions of the missing fields.
    '''
class ParseError(RuntimeError):
    '''
    Raised when a textual value cannot be parsed into a frame field,
    address or interface name.
    '''
class UnsupportedPlatformError(RuntimeError):
    '''
    Raised when an operation is not available on the detected switch chip.
    '''

# BCM chip ID -> chip name.  Keys ending in 'X' name a whole family: any
# chip ID that differs only in its last digit (e.g. 56850 -> '5685X').
_chip_id_map = { '56845' : 'Trident',
                '56846' : 'Trident+',
                '5685X' : 'Trident2',
                '5696X' : 'Tomahawk',
                '56340' : 'Helix4',
                '5686X' : 'Trident2+',
                '5676X' : 'Maverick',
                '5687X' : 'Trident3',
                '5615X' : 'Hurricane2',
                '5697X' : 'Tomahawk2',
                '56771' : 'Maverick2'}

def chipname(bcm):
    '''
    Return the name of the switch chip behind the BCM shell bcm.

    The chip ID is parsed from "show unit" output ("Unit N chip BCM<id>...")
    and looked up in _chip_id_map, first exactly and then by family
    (last digit replaced by 'X').  Unrecognised IDs yield 'Unknown(<id>)'.

    Raises RuntimeError if the "show unit" output cannot be parsed.
    '''
    unit_str = bcm.run('show unit')
    # search (not match) so leading whitespace/newlines in the shell output
    # do not hide the unit line.
    m = re.search(r'Unit \d+ chip BCM(?P<chip_id>\d+)', unit_str)
    if not m:
        raise RuntimeError('could not determine the platform from unit string: %s\n' % unit_str)
    chip_id = m.group('chip_id')

    chip_name = _chip_id_map.get(chip_id)
    if not chip_name:
        chip_name = _chip_id_map.get(chip_id[:-1] + 'X')
        if not chip_name:
            chip_name = 'Unknown(%s)' % chip_id

    return chip_name

def is_chip_supported(bcm):
    '''
    Return a (supported, chip_name) tuple for the chip behind bcm.

    supported is 0 for chips ecmpcalc cannot handle (Helix4, Hurricane2)
    and 1 otherwise.
    '''
    name = chipname(bcm)
    supported = 0 if name in ('Helix4', 'Hurricane2') else 1
    return (supported, name)

def is_chip_pkt_trace_supported(chip):
    '''Return True if visibility packet tracing is available on chip.'''
    return chip in ('Tomahawk', 'Spectrum', 'Tomahawk2')

########################################################################
# IP Frame Field Objects
########################################################################
class LinuxInterface(str):
    '''
    A Linux switch-port interface name; only "swp..." names are accepted.
    '''
    @classmethod
    def from_string(cls, string):
        if string.startswith('swp'):
            return cls(string)
        raise ParseError('%s is not a valid Linux interface name'
                         ', must start with "swp"' % string)

class FrameField:
    '''
    Represents a 16-bit field from an IP frame.

    Holds a single integer in .value; subclasses normally only override
    name (used in str() and in parse errors).
    '''
    name = 'Unknown Frame Field'

    def __init__(self, num):
        self.value = num

    @classmethod
    def from_int(cls, num):
        return cls(num)

    @classmethod
    def from_string(cls, string):
        # Base 0 lets int() honour prefixes such as 0x.
        try:
            parsed = int(string, 0)
        except ValueError:
            raise ParseError('%s is not an int for frame field %s' %
                             (string, cls.name))
        return cls(parsed)

    def __str__(self):
        return ':'.join((self.name, str(self.value)))
    __repr__ = __str__

    def asInt(self):
        return self.value

class ProtocolID(FrameField):
    '''
    IP protocol number (IPv4 protocol / IPv6 next header).

    from_string accepts a name from protocolMap (case-insensitive) or a
    decimal protocol number, and raises ParseError for anything else.
    '''
    name = 'Protocol ID'
    protocolMap = { 'icmp' : 1,
                    'tcp'  : 6,
                    'udp'  : 17,
                    'icmp6': 58,
                  }

    @property
    def isTCPorUDP(self):
        return self.asInt() in (self.protocolMap['tcp'],
                                self.protocolMap['udp'])

    @classmethod
    def from_string(cls, string):
        key = string.strip().lower()
        if key in cls.protocolMap:
            return cls(cls.protocolMap[key])

        # Report bad input as ParseError like every other from_string,
        # instead of leaking a bare ValueError from int().
        try:
            number = int(key)
        except ValueError:
            raise ParseError('%s is not a protocol name (%s) or number for '
                             'frame field %s' %
                             (string, ', '.join(sorted(cls.protocolMap)),
                              cls.name))
        return cls.from_int(number)

class L4Port(FrameField):
    '''TCP/UDP source or destination port.'''
    name = 'L4 Port'

class VLANID(FrameField):
    '''802.1Q VLAN ID.'''
    name = 'VLAN ID'

class HardwarePort(FrameField):
    '''Hardware (ASIC) port number.'''
    name = 'Hardware Port'

class HardwareModule(FrameField):
    '''Hardware module ID.'''
    name = 'Hardware Module'

class HashSeed(FrameField):
    '''RTAG7 hash seed (packed as a 32-bit value by RTAG7Hash).'''
    name = 'Hash Seed'

class Ip6FlowLabel(FrameField):
    '''IPv6 flow label.'''
    name = 'Ip6 flow label'

class Frame:
    '''
    Represents an IP frame with Trident-specific attributes.

    Any field may be None when unknown.  src/dst must provide high()/low()
    (read by the RTAG7 hash bins) and addrstr (read when building trace
    packets); the remaining fields expose .value and/or asInt().
    '''
    field_desc = { 'src'      : 'Source IP Address',
                   'dst'      : 'Destination IP Address',
                   'vid'      : 'VLAN ID',
                   'sport'    : 'Source Layer 4 Port',
                   'dport'    : 'Destination Layer 4 Port',
                   'protocol' : 'Layer 3 Protocol',
                   'sportid'  : 'Source Hardware Port ID',
                   'smodid'   : 'Source Hardware Module ID',
                   'dportid'  : 'Destination Hardware Port ID',
                   'dmodid'   : 'Destination Hardware Module ID',
                   'ip6_flow_label' : 'Ipv6 flow label'
                   }
    def __init__(self, src, dst, vid, sport, dport, protocol, sportid, smodid,
                 dportid, dmodid, ip6_flow_label):
        self.src = src # IP source address (IPv4 or IPv6)
        self.dst = dst # IP destination address (IPv4 or IPv6)
        self.vid = vid # vlan ID
        self.sport = sport # L4 source port
        self.dport = dport # L4 destination port
        self.protocol = protocol # IP protocol
        self.sportid = sportid # hardware source port
        self.smodid = smodid # hardware source module
        self.dportid = dportid # hardware destination port
        self.dmodid = dmodid # hardware destination module
        self.ip6_flow_label = ip6_flow_label # IPv6 flow label

    # High (most significant 16 bits) and low (least significant 16 bits)
    # halves of the source/destination address, or None when unset.
    @property
    def srcH(self):
        return None if self.src is None else self.src.high()
    @property
    def srcL(self):
        return None if self.src is None else self.src.low()
    @property
    def dstH(self):
        return None if self.dst is None else self.dst.high()
    @property
    def dstL(self):
        return None if self.dst is None else self.dst.low()

    def __str__(self):
        validFields = []
        for fieldName in ('src', 'dst', 'vid', 'sport', 'dport', 'protocol',
                          'sportid', 'smodid', 'dportid', 'dmodid', 'ip6_flow_label'):
            fieldValue = getattr(self, fieldName)
            if fieldValue is not None:
                validFields.append(fieldName + ':(' + str(fieldValue) + ')')

        return 'IP Frame: ' + ' '.join(validFields)
    __repr__ = __str__

    def _buildEnetHdr(self, smac, dmac, vlan, version):
        '''
        Return an Ethernet header as a hex string: dmac, smac, an 802.1Q tag
        carrying vlan (only when vlan is non-zero), then the IPv4 or IPv6
        ethertype depending on version.
        '''
        # Built by hand because dpkt only constructs untagged Ethernet headers.
        data = dmac.replace(":", "") + smac.replace(":", "")

        if vlan:
            data += "%04x" % dpkt.ethernet.ETH_TYPE_8021Q
            data += "%04x" % vlan

        if version == 4:
            data += "%04x" % dpkt.ethernet.ETH_TYPE_IP
        else:
            data += "%04x" % dpkt.ethernet.ETH_TYPE_IP6

        return data

    def _buildv6Packet(self, ehdr, ip):
        '''
        Return ehdr followed by the hex-encoded IPv6 packet ip.

        The IPv6 bytes are assembled manually because the dpkt module has
        issues with the v6 header; the L4 checksum is computed here over the
        IPv6 pseudo-header and payload.
        '''
        p = bytes(ip.data)
        s = struct.pack('>16s16sxBH', ip.src, ip.dst, ip.nxt, len(p))
        s = dpkt.in_cksum_add(0, s)
        s = dpkt.in_cksum_add(s, p)
        try:
            ip.data.sum = dpkt.in_cksum_done(s)
        except AttributeError:
            # payload without a checksum field
            pass
        v6 = ip.pack_hdr() + bytes(ip.data)
        # NOTE(review): str.encode('hex') is Python 2 only.
        return ehdr + v6.encode('hex')

    def buildVisibilityPkt(self, smac, dmac, vid):
        '''
        Build a trace packet for this frame as a hex string.

        Only TCP and UDP frames are supported; returns None for any other
        protocol or for an address that is neither IPv4 nor IPv6.  An unset
        IPv6 flow label is sent as 0.
        '''
        # ip payload
        if self.protocol.value == dpkt.ip.IP_PROTO_UDP:
            ipl = dpkt.udp.UDP(sport=self.sport.value, dport=self.dport.value)
        elif self.protocol.value == dpkt.ip.IP_PROTO_TCP:
            ipl = dpkt.tcp.TCP(sport=self.sport.value, dport=self.dport.value)
        else:
            return None

        # ip pkt
        dstAddr = ipaddr.IPAddress(self.dst.addrstr)
        srcAddr = ipaddr.IPAddress(self.src.addrstr)
        if dstAddr.version == 4:
            ip = dpkt.ip.IP(
                    src=socket.inet_pton(socket.AF_INET, str(srcAddr)),
                    dst=socket.inet_pton(socket.AF_INET, str(dstAddr)),
                    p=self.protocol.value, data=ipl)
            ip.len += len(ipl)
        elif dstAddr.version == 6:
            if self.ip6_flow_label is None:
                flow = 0
            else:
                flow = self.ip6_flow_label.asInt()
            ip = dpkt.ip6.IP6(
                    src=socket.inet_pton(socket.AF_INET6, str(srcAddr)),
                    dst=socket.inet_pton(socket.AF_INET6, str(dstAddr)),
                    nxt=self.protocol.value, data=ipl, plen=len(ipl),
                    flow=flow,
                    hlim=255)
        else:
            return None

        # ethernet pkt
        ehdr = self._buildEnetHdr(smac, dmac, vid, dstAddr.version)
        if dstAddr.version == 6:
            return self._buildv6Packet(ehdr, ip)
        # NOTE(review): str.encode('hex') is Python 2 only.
        return ehdr + str(ip).encode("hex")

    def visibilityPktMissingFields(self):
        '''
        Return the set of descriptions of fields buildVisibilityPkt needs
        but this frame does not have.
        '''
        return set(self.field_desc[fieldName]
                   for fieldName in ('src', 'dst', 'sport', 'dport', 'protocol')
                   if getattr(self, fieldName) is None)


########################################################################
# IP and Ethernet Addresses
########################################################################
class Address(FrameField):
    '''
    Base class for addresses given as text.

    The raw text is kept in addrstr; construction raises ParseError unless
    the subclass's validate() returns exactly True.
    '''
    def __init__(self, string):
        self.addrstr = string

        valid = self.validate()
        if valid is not True:
            raise ParseError('"%s" is not a %s' % (self.addrstr, self.name))

    @classmethod
    def from_string(cls, string):
        return cls(string)

    def __str__(self):
        return ':'.join((self.name, self.addrstr))
    __repr__ = __str__

    def _checkchars(self, valid):
        '''
        Return True when every character of addrstr is in valid; otherwise
        raise ParseError naming the first offending character.
        '''
        for ch in self.addrstr:
            if ch in valid:
                continue
            raise ParseError('invalid character "%s" in %s: "%s"' %
                             (ch, self.name, self.addrstr))
        return True

    def validate(self):
        raise NotImplementedError

    def asInt(self):
        raise NotImplementedError

class MACAddress(Address):
    '''
    Colon-separated Ethernet MAC address, e.g. "00:11:22:aa:bb:cc".
    '''
    name = 'MAC Address'
    def validate(self):
        octets = self.addrstr.split(':')
        if len(octets) != 6:
            return False
        # Raises ParseError naming the first non-hex character.
        self._checkchars(string.hexdigits + ':')
        # Each octet must be one or two hex digits; the character check alone
        # lets empty octets (":::::") and over-long ones ("123:...") through.
        return all(1 <= len(octet) <= 2 for octet in octets)

class IPAddress(Address):
    '''
    IPv4 or IPv6 address.

    asInt() returns the 32-bit value the hash operates on: the address
    itself for IPv4, the XOR-folded address for IPv6.
    '''
    name = 'IP Address'

    def __init__(self, string):
        Address.__init__(self, string)
        self.addr = ipaddr.IPAddress(string)

    def validate(self):
        '''Return True for a parseable address; raise ParseError otherwise.'''
        try:
            ipaddr.IPAddress(self.addrstr)
        except ValueError as e:
            raise ParseError(e)

        return True

    def _fold6(self, addrint):
        # IPv6 addresses are "folded" into 32bits: IP[127:96] ^ IP[96:64] ^
        # IP[63:32] ^ IP[31:0]. Some BCM parts support other methods of
        # folding, but this is the default.
        return (((addrint >> 96) & 0xffffffff) ^
                ((addrint >> 64) & 0xffffffff) ^
                ((addrint >> 32) & 0xffffffff) ^
                ((addrint >> 0) & 0xffffffff))

    def asInt(self):
        if self.addr.version == 4:
            return int(self.addr)
        elif self.addr.version == 6:
            return self._fold6(int(self.addr))
        else:
            raise RuntimeError('%s is not a IPv4 or IPv6 address' % str(self.addr))

    def high(self):
        '''Return this address as an IPAddressHigh (upper 16 bits).'''
        return IPAddressHigh(self.addrstr)

    def low(self):
        '''Return this address as an IPAddressLow (lower 16 bits).'''
        return IPAddressLow(self.addrstr)

class IPAddressHigh(IPAddress):
    '''
    Bits 31..16 of the 32-bit address value (IPv6 addresses are folded to
    32 bits first by IPAddress.asInt).
    '''
    name = 'IP Address (High 16-bits)'

    def asInt(self):
        folded = IPAddress.asInt(self)
        return (folded >> 16) & 0xffff

class IPAddressLow(IPAddress):
    '''
    Bits 15..0 of the 32-bit address value (IPv6 addresses are folded to
    32 bits first by IPAddress.asInt).
    '''
    name = 'IP Address (Low 16-bits)'

    def asInt(self):
        folded = IPAddress.asInt(self)
        return folded & 0xffff


########################################################################
# RTAG7 Hash Objects
########################################################################
class RTAG7HashBin:
    '''
    RTAG7 Hash Bins (IPV4/IPV6 Frame Fields)

    A bin is one 16-bit hash input.  name is one of bins and is also the
    Frame attribute that supplies the value; width is the size in bytes.
    '''
    bins = ('srcH',
            'srcL',
            'dstH',
            'dstL',
            'vid',
            'sport',
            'dport',
            'protocol',
            'sportid',
            'smodid',
            'dportid',
            'dmodid')
    descriptions = { 'srcH'     : 'Source IP Address',
                     'srcL'     : 'Source IP Address',
                     'dstH'     : 'Destination IP Address',
                     'dstL'     : 'Destination IP Address',
                     'vid'      : 'VLAN ID',
                     'sport'    : 'Source Layer 4 Port',
                     'dport'    : 'Destination Layer 4 Port',
                     'protocol' : 'Layer 3 Protocol',
                     'sportid'  : 'Source Hardware Port ID',
                     'smodid'   : 'Source Hardware Module ID',
                     'dportid'  : 'Destination Hardware Port ID',
                     'dmodid'   : 'Destination Hardware Module ID',
                   }
    width = 2

    def __init__(self, name, description=None):
        self.name = name
        # Always set a description: RTAG7Hash.hashable() reads it when
        # reporting missing frame fields.
        if description is None:
            description = self.descriptions.get(name, name)
        self.description = description

    @classmethod
    def from_string(cls, string):
        if string not in cls.bins:
            raise ParseError('"%s" is not a valid hash bin (%s)' %
                             (string, ','.join(cls.bins)))
        return cls(string, cls.descriptions[string])

    def __str__(self):
        return 'RTAG7HashBin: %s' % (self.name)
    __repr__ = __str__

class RTAG7CRC:
    '''
    RTAG7 CRCs

    Bitwise CRC-16 shift register.  update() consumes the data last byte
    first and each byte least significant bit first.  The value property
    appends 16 zero bits to a copy of the register (no final XOR), so it
    does not disturb the running state.
    '''
    def __init__(self, initial=0, poly=0x1021):
        self.initial = initial
        self.crc = initial
        self.poly = poly

    def update(self, data):
        '''Feed data (a sequence of 1-character strings) into the CRC.'''
        for byte in reversed(data):
            for shift in range(8):
                msBit = self.crc & 0x8000
                self.crc = (self.crc << 1) | ((ord(byte) >> shift) & 0x01)
                if msBit:
                    self.crc ^= self.poly
            self.crc &= 0xffff

    @property
    def value(self):
        final = self.crc
        for _ in range(16):
            msBit = final & 0x8000
            final = (final << 1)
            if msBit:
                final ^= self.poly
        return final & 0xffff

    def reset(self):
        self.crc = self.initial

class RTAG7Hash:
    '''
    RTAG7 Hash Calculation

    CRC over a 28-byte input: the 32-bit seed followed by twelve 16-bit
    bins, with unselected bins left zero.
    '''
    bmapOffsets = {
        'srcH'     : 11,
        'srcL'     : 10,
        'dstH'     : 9,
        'dstL'     : 8,
        'vid'      : 7,
        'sport'    : 6,
        'dport'    : 5,
        'protocol' : 4,
        'sportid'  : 3,
        'smodid'   : 2,
        'dportid'  : 1,
        'dmodid'   : 0,
    }
    # Prototype engines; each RTAG7Hash gets its own copy in __init__.
    hashFunctions = {
        'crc16-ccitt' : RTAG7CRC(poly=0x1021),
        'crc16-bisync' : RTAG7CRC(poly=0x8005),
    }
    # 224-bit hash input, in bytes (floor division keeps it an int).
    hashInputWidth = 224 // 8

    def __init__(self, func, seed, bins):
        self.func = func
        self.seed = seed
        self.bins = bins

        if func not in self.hashFunctions:
            raise RuntimeError('no such hash function: %s' % func)

        # Private engine so instances never share running CRC state.
        proto = self.hashFunctions[func]
        self.crc = RTAG7CRC(initial=proto.initial, poly=proto.poly)

        # Selected bins as a bitmap using the bmapOffsets bit layout.
        self.bmap = self._calcbmap()

    def _calcbmap(self):
        bmap = 0
        for hbin in self.bins:
            bmap |= 1 << self.bmapOffsets[hbin.name]
        return bmap

    def _assembleHashInput(self, frame):
        '''
        Return the hash input as a list of 1-character strings: the seed
        (big-endian, 4 bytes) then each selected bin at its fixed slot.
        '''
        buf = bytearray(self.hashInputWidth)
        struct.pack_into('!I', buf, 0, self.seed.asInt())

        for hbin in self.bins:
            # Bin with bitmap offset 11 comes first, offset 0 last.
            offset = 4 + (11 - self.bmapOffsets[hbin.name]) * hbin.width
            struct.pack_into('!H', buf, offset,
                             getattr(frame, hbin.name).asInt())

        return [chr(b) for b in buf]

    def hashable(self, frame):
        '''
        Return True if frame has every selected bin; otherwise raise
        MissingFrameInformation with the set of missing descriptions.
        '''
        missing = set()
        for hbin in self.bins:
            if getattr(frame, hbin.name) is None:
                missing.add(hbin.description)
        if missing:
            raise MissingFrameInformation(missing)
        return True

    def calculate(self, frame):
        '''Return the 16-bit RTAG7 hash of frame.'''
        self.hashable(frame)
        hashInput = self._assembleHashInput(frame)
        self.crc.reset()
        self.crc.update(hashInput)
        return self.crc.value

class RTAG7Config:
    '''
    Hardware RTAG7 configuration

    Reads the ASIC's RTAG7 registers through the BCM shell.  Call load()
    before any get*() accessor.
    '''

    bmapOffsets = {
        'srcH'     : 11,
        'srcL'     : 10,
        'dstH'     : 9,
        'dstL'     : 8,
        'vid'      : 7,
        'sport'    : 6,
        'dport'    : 5,
        'protocol' : 4,
        'sportid'  : 3,
        'smodid'   : 2,
        'dportid'  : 1,
        'dmodid'   : 0,
    }
    hashFunctions = { 3 : 'crc16-bisync',
                      9 : 'crc16-ccitt',
                    }
    # Registers load() reads on every chip.
    _registers = ('hash_control',
                  'rtag7_hash_seed_a',
                  'rtag7_ipv4_tcp_udp_hash_field_bmap_2',
                  'rtag7_ipv4_tcp_udp_hash_field_bmap_1',
                  'rtag7_hash_field_bmap_1',
                  'rtag7_hash_control_3')
    # Extra Trident/Trident+ registers; _registersUnderstood requires them zero.
    _tridentEcmpRegisters = ('rtag7_hash_ecmp(0)', 'rtag7_hash_ecmp(1)')

    def __init__(self, bcmshell):
        self.bcm = bcmshell
        self.chip = chipname(self.bcm)

    def load(self):
        '''Read each register into an attribute of the same name.'''
        registers = self._registers
        if self.chip in ('Trident', 'Trident+'):
            registers = registers + self._tridentEcmpRegisters
        for register in registers:
            setattr(self, register, self.bcm.getreg(register, fields=True))

    def _registersUnderstood(self):
        """
        Make sure we understand the RTAG7/ECMP configuration.  Raise an
        exception when our assumptions are broken.
        """
        if self.hash_control['ECMP_HASH_USE_RTAG7'] != 1:
            raise NotImplementedError('unsupported hardware configuration, '
                                      'hash_control.ECMP_HASH_USE_RTAG7 != 1')
        if (self.rtag7_hash_control_3['HASH_A0_FUNCTION_SELECT']
                not in self.hashFunctions):
            raise NotImplementedError('unsupported hardware configuration, '
                                      'rtag7_hash_control_3.HASH_A0_FUNCTION_SELECT '
                                      'not in: %s' % sorted(self.hashFunctions.keys()))
        if self.chip in ('Trident', 'Trident+'):
            ecmpBits = sum(sum(getattr(self, reg).values())
                           for reg in self._tridentEcmpRegisters)
            if ecmpBits != 0:
                raise NotImplementedError('unsupported hardware configuration, '
                                          'rtag7_hash_ecmp({0,1}) must be zero')

    def getHashSeed(self, frame):
        self._registersUnderstood()
        return self.rtag7_hash_seed_a['HASH_SEED_A']

    def getHashFunction(self, frame):
        self._registersUnderstood()
        return self.hashFunctions[self.rtag7_hash_control_3['HASH_A0_FUNCTION_SELECT']]

    def getHashFields(self, frame):
        '''
        Return the RTAG7HashBin list the hardware hashes for frame, ordered
        by bitmap offset from srcH (11) down to dmodid (0).
        '''
        self._registersUnderstood()

        if frame.protocol is not None and frame.protocol.isTCPorUDP:
            # dport and sport are used by the hardware to select a hash bin
            # bitmap even if the fields are not part of the hash calculation.
            # Therefore, we require both for tcp and udp frames.
            if frame.sport is None or frame.dport is None:
                # NOTE(review): `parser` is not defined in this class; this
                # presumably relies on a module-level argparse parser created
                # by the CLI entry point - confirm.
                parser.error('--sport and --dport required for TCP and UDP frames')

            if frame.sport.value == frame.dport.value:
                bmap = self.rtag7_ipv4_tcp_udp_hash_field_bmap_1[
                       'IPV4_TCP_UDP_SRC_EQ_DST_FIELD_BITMAP_A']
            else:
                bmap = self.rtag7_ipv4_tcp_udp_hash_field_bmap_2[
                       'IPV4_TCP_UDP_FIELD_BITMAP_A']
        else:
            bmap = self.rtag7_hash_field_bmap_1['IPV4_FIELD_BITMAP_A']

        ordered = sorted(self.bmapOffsets.items(),
                         key=lambda item: item[1], reverse=True)
        return [RTAG7HashBin.from_string(field)
                for field, offset in ordered
                if bmap & (1 << offset)]


########################################################################
# Egress Objects
########################################################################
class EgressObject:
    '''
    Result of a route lookup: the matched prefix (ipv4net, which may hold
    an IPv6 network despite its name) and the egress object ID it uses.
    '''
    def __init__(self, ipv4net, egressId):
        self.egressId = egressId
        self.ipv4net = ipv4net

class EgressTable:
    '''
    Route lookups against the hardware L3 tables via the BCM shell.
    '''
    def __init__(self, bcmshell, verbose=False):
        self.bcm = bcmshell
        self.verbose = verbose

    def findEgressObjectByDst(self, dst):
        '''
        Return an EgressObject for the most specific route covering dst.

        dst is an ipaddr address; its version selects the IPv4 (defip) or
        IPv6 (ip6route) table.  Raises RuntimeError for any other version or
        when no route with a non-zero egress interface matches.
        '''
        if dst.version == 4:
            table = self.bcm.run('l3 defip show')
        elif dst.version == 6:
            table = self.bcm.run('l3 ip6route show')
        else:
            raise RuntimeError('Error: invalid IP version: %s' % dst)

        bestIntf = 0
        bestNet = None
        for line in table.split('\n'):
            # Route rows carry the prefix in the 3rd column and the egress
            # interface in the 5th; headers and other lines fail to unpack
            # or parse and are skipped.
            try:
                (_, _, netaddr, _, intfstr, _) = line.split(None, 5)
                net = ipaddr.IPNetwork(netaddr)
                intf = int(intfstr)
            except ValueError:
                continue

            # Longest-prefix match: prefer the route with the longer prefix.
            if dst in net and (bestIntf == 0 or net.prefixlen > bestNet.prefixlen):
                bestNet = net
                bestIntf = intf

        if bestIntf == 0:
            raise RuntimeError('Error: no egress object for destination: %s' % dst)
        return EgressObject(bestNet, bestIntf)


########################################################################
# ECMP Objects
########################################################################
class ECMPConfig:
    '''
    Resolves an ECMP egress object and a frame hash to an egress port by
    reading the hardware tables (L3_ECMP_GROUP -> L3_ECMP -> ING_L3_NEXT_HOP,
    plus TRUNK_GROUP/TRUNK_MEMBER for trunk next hops) via the BCM shell.
    '''
    egressObjectOffset = 200000
    egressObjectMax = 204095

    # Chips whose L3_ECMP_GROUP entries have one BASE_PTR/COUNT pair and
    # whose ECMP hash mask comes from hash_control.
    _singleGroupChips = ('Trident+', 'Trident2', 'Trident2+', 'Trident3', 'Maverick2')
    # Chips on which getHdPort maps next hops through "l3 egress show".
    _l3EgressShowChips = ('Trident2+', 'Trident3', 'Maverick2')

    def __init__(self, bcmshell, verbose=False):
        self.bcm = bcmshell
        self.verbose = verbose
        self.chip = chipname(self.bcm)

    @classmethod
    def objectIsECMP(cls, egressId):
        '''Return True if egressId lies in the ECMP egress object range.'''
        return cls.egressObjectOffset <= egressId <= cls.egressObjectMax

    def getECMPGroup(self, egressId):
        '''Return the L3_ECMP_GROUP entry for ECMP egress object egressId.'''
        if not self.objectIsECMP(egressId):
            raise RuntimeError('egress object ID expected to be between %d and %d' %
                               (self.egressObjectOffset, self.egressObjectMax))
        groupOffset = egressId - self.egressObjectOffset
        group = self.bcm.gettable('L3_ECMP_GROUP',
                                  fields=True,
                                  start=groupOffset,
                                  entries=1)[0]
        if self.verbose:
            sys.stderr.write('ecmp group offset %d\n' % groupOffset)
        return group

    def pageLookup(self, egressId, pageNum):
        '''Return (base, count) of the member range for egressId / pageNum.'''
        group = self.getECMPGroup(egressId)
        if self.chip == 'Trident':
            # Trident keeps a BASE_PTR_n/COUNT_n pair per page.
            base = group['BASE_PTR_%d' % pageNum]
            count = group['COUNT_%d' % pageNum] + 1 # implicit +1 on count
        elif self.chip in self._singleGroupChips:
            base = group['BASE_PTR']
            count = group['COUNT'] + 1 # implicit +1 on count
        else:
            raise UnsupportedPlatformError('ecmp page is not available on %s' % self.chip)

        if self.verbose:
            sys.stderr.write('base,count: %d,%d\n' % (base, count))
        return (base, count)

    def getEcmpMask(self):
        '''Return the mask applied to the frame hash before the modulo.'''
        mask = {0: 0x3ff,
                1: 0x7ff,
                2: 0xfff,
                3: 0x1fff,
                4: 0x3fff,
                5: 0x7fff,
                6: 0xffff,
                7: 0}

        if self.chip == 'Trident':
            ecmpMask = 0x3ff
        elif self.chip in self._singleGroupChips:
            hash_control = self.bcm.getreg('hash_control', fields=True)
            ecmpMask = mask[hash_control['ECMP_HASH_FIELD_UPPER_BITS_COUNT']]
        else:
            raise UnsupportedPlatformError('ecmp mask is not available on %s' % self.chip)
        return ecmpMask

    def getECMP(self, egressId, frameHash):
        '''Return the L3_ECMP member entry frameHash selects in egressId.'''
        # Hash bits 11:10 select the page (used on Trident only).
        ecmpPage = (frameHash & 0x0c00) >> 10
        (ecmpBase, ecmpCount) = self.pageLookup(egressId, ecmpPage)
        ecmpMask = self.getEcmpMask()
        ecmpOffset = (frameHash & ecmpMask) % ecmpCount
        ecmpPtr = ecmpBase + ecmpOffset
        ecmp = self.bcm.gettable('L3_ECMP',
                                 fields=True,
                                 start=ecmpPtr,
                                 entries=1)[0]
        if self.verbose:
            sys.stderr.write('ecmpMask 0x%x L3_ECMP+%d\n' % (ecmpMask, ecmpPtr))
        return ecmp

    def getNextHop(self, egressId, frameHash):
        '''Return the ING_L3_NEXT_HOP entry of the selected ECMP member.'''
        ecmp = self.getECMP(egressId, frameHash)
        nextHop = self.bcm.gettable('ING_L3_NEXT_HOP',
                                    fields=True,
                                    start=ecmp['NEXT_HOP_INDEX'],
                                    entries=1)[0]
        if self.verbose:
            sys.stderr.write('ING_L3_NEXT_HOP+%d\n' % ecmp['NEXT_HOP_INDEX'])
        return nextHop

    def _firstTrunkMember(self, tgid):
        '''Return the TRUNK_MEMBER entry at the base of trunk group tgid.'''
        trunk = self.bcm.gettable('TRUNK_GROUP',
                                  fields=True,
                                  start=tgid,
                                  entries=1)[0]
        return self.bcm.gettable('TRUNK_MEMBER',
                                 fields=True,
                                 start=trunk['BASE_PTR'],
                                 entries=1)[0]

    def _nextHopPort(self, nextHop):
        '''Return (isTrunkMbr, modId, port) for an ING_L3_NEXT_HOP entry.'''
        if nextHop['T']:
            # Trunk next hop: report the trunk's first member port.
            member = self._firstTrunkMember(nextHop['TGID'])
            return True, member['MODULE_ID'], member['PORT_NUM']
        return False, nextHop['MODULE_ID'], nextHop['PORT_NUM']

    def getPort(self, egressId, frameHash):
        '''Return (isTrunkMbr, interface) chosen for frameHash in egressId.'''
        nextHop = self.getNextHop(egressId, frameHash)
        isTrunkMbr, modId, port = self._nextHopPort(nextHop)
        return isTrunkMbr, modport2interface(modId, port, self.chip)

    @staticmethod
    def _hdNextHopId(table):
        '''
        Return the next-hop ID from "hd" output: the first token of the last
        line whose second (single-space separated) token is a substring of
        "destination", or 0 if there is none.
        '''
        nhId = 0
        for line in table.splitlines():
            tokens = line.strip().split(' ')
            if len(tokens) == 1:
                continue
            # NOTE(review): substring test kept from the original; it also
            # matches "dest" and the empty token produced by doubled spaces.
            if tokens[1] in 'destination':
                try:
                    nhId = int(tokens[0])
                except ValueError:
                    # Not a next-hop row; keep scanning.
                    continue
        return nhId

    def _l3EgressPort(self, NHId):
        '''
        Look up next hop NHId in "l3 egress show" output and return
        (isTrunkMbr, modId, port).  Sample output:

          Entry  Mac                 Vlan INTF PORT MOD MPLS_LABEL ToCpu Drop RefCount L3MC
          100004  00:02:00:00:00:21 3087   87    87    0        -1   no   no    5   no
          100005  34:17:eb:f6:00:f5    3    3     1t   0        -1   no   no    2   no

        A PORT ending in 't' is a trunk ID; the port of the trunk's first
        member is returned for it.

        Raises RuntimeError if NHId is absent or its PORT is unrecognised.
        '''
        egress_table = self.bcm.run('l3 egress show')
        for line in egress_table.splitlines():
            if self.verbose:
                sys.stderr.write('%s\n' % line)

            cols = line.split()
            if len(cols) < 6:
                continue # blank or truncated line
            try:
                entryId = int(cols[0])
            except ValueError:
                continue # header row
            if entryId != NHId:
                continue

            portStr = cols[4]
            modId = int(cols[5])
            if portStr.isdigit():
                return False, modId, int(portStr)
            if portStr.endswith('t') and portStr[:-1].isdigit():
                member = self._firstTrunkMember(int(portStr[:-1]))
                return True, modId, member['PORT_NUM']
            raise RuntimeError('unrecognized PORT "%s" for next hop %d in '
                               '"l3 egress show"' % (portStr, NHId))

        raise RuntimeError('next hop %d not found in "l3 egress show" output' % NHId)

    # When resilient hashing is enabled use "bcm hd dest" command to get info
    def getHdPort(self, hd_cmd):
        '''
        Return (isTrunkMbr, interface) for the next hop reported by hd_cmd.

        Raises RuntimeError if hd_cmd output has no valid next-hop ID.
        '''
        NHObjectOffset = 100000
        table = self.bcm.run(hd_cmd)
        NHId = self._hdNextHopId(table)
        if self.verbose:
            sys.stderr.write('hd_cmd %s\n' % hd_cmd)
            sys.stderr.write('NextHopId %d NhIdOffset %d\n' % (NHId, NHObjectOffset))

        if NHId < NHObjectOffset:
            raise RuntimeError('unexpected output from: %s\n %s' % (hd_cmd, table))

        if self.chip in self._l3EgressShowChips:
            # Map the next hop to a port via "l3 egress show" rather than
            # reading the memory tables.
            isTrunkMbr, modId, port = self._l3EgressPort(NHId)
            if self.verbose:
                sys.stderr.write('NHId %d => Port %d\n' % (NHId, port))
            # Trunk ports are shown as swpX only as trunk information is not available
            return isTrunkMbr, modport2interface(modId, port, self.chip)

        NHIdOffset = NHId - NHObjectOffset
        nextHop = self.bcm.gettable('ING_L3_NEXT_HOP',
                                    fields=True,
                                    start=NHIdOffset,
                                    entries=1)[0]
        if self.verbose:
            sys.stderr.write('NextHopId %d Port_num %d\n' %
                             (NHId, nextHop['PORT_NUM']))
        isTrunkMbr, modId, port = self._nextHopPort(nextHop)
        return isTrunkMbr, modport2interface(modId, port, self.chip)

    def getLinuxIntf(self, isTrunkMbr, port):
        '''Translate an ASIC interface name to its Linux interface name.'''
        if self.verbose:
            sys.stderr.write('bcm interface %s\n' % port)
        return asictrunkmbr2linux(port) if isTrunkMbr else asic2linux(port)

    def getOffset(self, mcount, frameHash):
        '''Return the member index among mcount members for frameHash.'''
        return (frameHash & 0x3ff) % mcount

class VisibilityPktTrace:
    """Trace a frame through the ASIC pipeline with a visibility packet.

    Construction builds a trace packet from ``frame`` and injects it on the
    frame's source port with the diag shell ``tx`` command.  The trace
    output is parsed to record:

      knownL3UcPkt -- True if the trace reports 'KnownL3UcPkt'
      egressId     -- ECMP egress object id from 'ecmp_1_egress', with the
                      SDK offset removed; -1 when no ECMP egress was seen
    """

    # Port MAC address as printed by the diag shell 'port <name>' command.
    _PORT_MAC_RE = re.compile(r'Stad\((?P<pmac>\S+)\)')
    # Present in the 'tx' trace output when the packet was L3 switched.
    _KNOWN_L3_UC_RE = re.compile(r'KnownL3UcPkt')
    # ECMP egress object reported in the 'tx' trace output.
    _ECMP_EGRESS_RE = re.compile(r'ecmp_1_egress (?P<egressId>\d+)')
    # Offset included in reported egress ids; it is subtracted before the
    # id is used as an ING_L3_NEXT_HOP index.
    _EGRESS_SDK_OFFSET = 100000

    def __init__(self, bcmshell, frame, verbose=False):
        """Run the packet trace immediately.

        bcmshell -- diag shell handle providing run() and gettable()
        frame    -- frame to trace; frame.sportid.value is the source port
        verbose  -- write progress information to stderr
        """
        self.bcm = bcmshell
        self.verbose = verbose

        # trace source port
        self.pt = cumulus.porttab.porttab()
        self.logical_port = frame.sportid.value
        self.dport_name = self.pt.logical2dport(int(self.logical_port))

        # packet trace results; filled in by _runTrace()
        self.knownL3UcPkt = False
        self.egressId = -1
        self._runTrace(frame)

    def _getPortVid(self):
        """Return PORT_VID of the source port from the PORT table."""
        portEnt = self.bcm.gettable('PORT',
                                    fields=True,
                                    start=self.logical_port,
                                    entries=1)[0]
        pvid = portEnt['PORT_VID']
        if self.verbose:
            sys.stderr.write('logical_port %s has pvid %d\n' % (self.logical_port, pvid))
        return pvid

    def _getPortMac(self):
        """Return the source port MAC parsed from 'port <name>' output.

        Raises RuntimeError if the output contains no Stad(...) field.
        """
        port_str = self.bcm.run('port %s' % self.dport_name)
        obj = self._PORT_MAC_RE.search(port_str)
        if obj is None:
            raise RuntimeError('Error: could not determine MAC address of port %s'
                               % self.dport_name)
        pmac = obj.group('pmac')
        if self.verbose:
            sys.stderr.write('port %s has pmac %s\n' % (self.dport_name, pmac))
        return pmac

    def _runTrace(self, frame):
        """Inject the trace packet and record the L3/ECMP results."""
        # we need to build a routable ethernet header; to do that -
        # 1. vlan has to be valid i.e. source port has to be a member of the
        #    vlan. Best guess would be PVID.
        # 2. dmac has to be a mystation entry. Best guess would again be the
        #    port mac.
        # 3. trace packet will not trigger mac learning so it is fairly safe
        #    to use any mac as smac.
        pmac = self._getPortMac()
        pvid = self._getPortVid()
        pkt = frame.buildVisibilityPkt(smac=pmac, dmac=pmac, vid=pvid)
        if not pkt:
            raise RuntimeError('Error: visibility pkt trace generation failed')

        if self.verbose:
            sys.stderr.write('Trace pkt:\n  %s\n' % pkt)

        # XXX - Temporarily using the diag shell directly to inject the trace
        # packet into the bcm pipeline; will be transitioned to switchd fuse
        # shortly
        trace_str = self.bcm.run('tx 1 VisibilitySourcePort=%s DATA=%s' % (self.dport_name, pkt))

        # Check if the packet was L3 switched successfully
        if not self._KNOWN_L3_UC_RE.search(trace_str):
            return

        self.knownL3UcPkt = True
        if self.verbose:
            sys.stderr.write('Trace packet was L3 switched\n')

        # Check if the output was ecmp
        obj = self._ECMP_EGRESS_RE.search(trace_str)
        if obj:
            self.egressId = int(obj.group('egressId'))
            if self.verbose:
                sys.stderr.write('Trace packet was switched via egress %d\n' % self.egressId)
            # remove egress object sdk offset
            if self.egressId < self._EGRESS_SDK_OFFSET:
                raise RuntimeError('Error: egress id %d format is unexpected' % self.egressId)
            self.egressId = self.egressId - self._EGRESS_SDK_OFFSET

    def getL3PktRes(self):
        """Return (was L3 unicast switched, egressed via ECMP)."""
        return self.knownL3UcPkt, self.egressId != -1

    def getOutLinuxIntf(self):
        """Return the Linux name of the egress port chosen by the trace.

        For a trunk next hop, the first member of the trunk is reported.
        Raises RuntimeError if the trace did not resolve to an ECMP egress
        object (there is no next-hop index to look up).
        """
        if self.egressId < 0:
            raise RuntimeError('Error: trace packet did not egress via an ECMP object')

        nextHop = self.bcm.gettable('ING_L3_NEXT_HOP',
                                    fields=True,
                                    start=self.egressId,
                                    entries=1)[0]
        if nextHop['T']:
            # If trunk, return the first member port
            trunk = self.bcm.gettable('TRUNK_GROUP',
                                      fields=True,
                                      start=nextHop['TGID'],
                                      entries=1)[0]
            trunkMbr = self.bcm.gettable('TRUNK_MEMBER',
                                         fields=True,
                                         start=trunk['BASE_PTR'],
                                         entries=1)[0]
            port = trunkMbr['PORT_NUM']
            isTrunkMbr = True
        else:
            port = nextHop['PORT_NUM']
            isTrunkMbr = False

        port = self.pt.logical2dport(port)
        return asictrunkmbr2linux(port) if isTrunkMbr else asic2linux(port)


########################################################################
#
# MAIN
#
# Parses arguments, reads hardware values, calculates egress interface
#
########################################################################
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Command line definition
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description='Determine the ECMP egress interface for a frame')
    parser.add_argument('-v', '--verbose',
                        required=False,
                        action='store_true',
                        help='Verbose output')
    parser.add_argument('-p', '--protocol',
                        required=False,
                        type=ProtocolID.from_string,
                        help='IP protocol ("tcp", "udp", or integer)')
    parser.add_argument('-s', '--src',
                        required=False,
                        type=IPAddress.from_string,
                        help='Source IPv4 or IPv6 address')
    parser.add_argument('--sport',
                        required=False,
                        type=L4Port.from_string,
                        help='Source L4 port')
    parser.add_argument('-d', '--dst',
                        required=False,
                        type=IPAddress.from_string,
                        help='Destination IPv4 or IPv6 address')
    parser.add_argument('--dport',
                        required=False,
                        type=L4Port.from_string,
                        help='Destination L4 port')
    parser.add_argument('--vid',
                        required=False,
                        type=VLANID.from_string,
                        help='VLAN ID')
    parser.add_argument('-i', '--in-interface',
                        required=False,
                        type=LinuxInterface.from_string,
                        help='Input interface')
    parser.add_argument('--sportid',
                        required=False,
                        type=HardwarePort.from_string,
                        help='Hardware source port ID')
    parser.add_argument('--smodid',
                        required=False,
                        type=HardwareModule.from_string,
                        help='Hardware source module ID')
    parser.add_argument('-o', '--out-interface',
                        required=False,
                        type=LinuxInterface.from_string,
                        help='Output interface')
    parser.add_argument('--dportid',
                        required=False,
                        type=HardwarePort.from_string,
                        help='Hardware destination port ID')
    parser.add_argument('--dmodid',
                        required=False,
                        type=HardwareModule.from_string,
                        help='Hardware destination module ID')
    parser.add_argument('--ip6_flow_label',
                        required = False,
                        type=Ip6FlowLabel.from_string,
                        default="0",
                        help='Ipv6 flow label')
    # by default we try to probe the hardware; --nohardware turns it off
    parser.add_argument('--hardware',
                        required=False,
                        default=True,
                        action='store_true',
                        help='Force hardware query, ecmpcalc will read the '
                             'hash configuration from hardware')
    parser.add_argument('--nohardware',
                        required=False,
                        dest='hardware',
                        action='store_false',
                        help='Force no hardware query')
    parser.add_argument('-hs', '--hashseed',
                        required=False,
                        type=HashSeed.from_string,
                        help='RTAG7 hash seed')
    parser.add_argument('-hf', '--hashfields',
                        required=False,
                        nargs='+',
                        type=RTAG7HashBin.from_string,
                        help='list of RTAG7 hash fields, one or more of (%s)' %
                             ', '.join(RTAG7HashBin.bins))
    parser.add_argument('--hashfunction',
                        required=False,
                        choices=('crc16-ccitt', 'crc16-bisync'),
                        help='RTAG7 hash function')
    parser.add_argument('-e', '--egress',
                        required=False,
                        type=int,
                        help='multipath egress object as from "/usr/lib/cumulus/bcmcmd l3 defip show"')
    parser.add_argument('-c', '--mcount',
                        required=False,
                        type=int,
                        help='ECMP group member count')

    try:
        args = parser.parse_args()
    except ParseError as e:
        # argparse only turns ArgumentTypeError/TypeError/ValueError from a
        # type= converter into a usage error; ParseError (a RuntimeError)
        # propagates out of parse_args, so report it the same way here.
        parser.error(str(e))

    if (os.geteuid() != 0):
        sys.stderr.write('root privileges are needed to run ecmpcalc\n')
        sys.exit(-1)

    # ------------------------------------------------------------------
    # Platform detection: Mellanox uses cumulus.mlx_ecmpcalc, everything
    # else uses the Broadcom bcmshell.
    # ------------------------------------------------------------------
    platform_object = cumulus.platforms.probe()
    if platform_object.switch.chip.sw_base == 'mlx':
        mlx = True
        try:
            import cumulus.mlx_ecmpcalc
            asicShellAvailable = True
        except ImportError:
            asicShellAvailable = False
    else:
        mlx = False
        try:
            import bcmshell
            asicShellAvailable = True
        except ImportError:
            asicShellAvailable = False

    # Determine whether we will be querying the hardware.
    if args.hardware:
        if not asicShellAvailable:
            sys.stderr.write('ecmpcalc: asicTool not available, unable to query hardware\n')
            sys.exit(-1)

    # NOTE: `chip` is only assigned on the mlx path or when querying
    # Broadcom hardware; later uses of `chip` are gated on args.hardware.
    if mlx:
        chip = "Spectrum"
        if not args.hardware:
            sys.stderr.write('ecmpcalc: only hardware mode supported on mellanox platforms\n')
            sys.exit(-1)
        if not args.in_interface:
            sys.stderr.write('ecmpcalc: --in-interface is required on mellanox platforms\n')
            sys.exit(-1)
    else:
        if args.hardware:
            # check if we can actually connect to the switchd socket
            try:
                bcmshell.bcmshell().run('echo foo')
            except IOError as e:
                sys.stderr.write(('%s\n' % e))
                sys.exit(-1)

            (rv, chip) = is_chip_supported(bcmshell.bcmshell())
            if rv == 0:
                sys.stderr.write('ecmpcalc is not supported on this chip (%s)\n' % chip)
                sys.exit(-1)

    if args.hardware:
        sys.stderr.write('ecmpcalc: will query hardware\n')
    else:
        sys.stderr.write('ecmpcalc: will NOT query hardware\n')

    # ------------------------------------------------------------------
    # Convert interfaces to port/mod.
    # ------------------------------------------------------------------
    if args.in_interface and (args.sportid is not None or
                              args.smodid is not None):
        parser.error('require only one of --in-interface OR (--sportid/--smodid)')
    if args.out_interface and (args.dportid is not None or
                               args.dmodid is not None):
        parser.error('require only one of --out-interface OR (--dportid/--dmodid)')

    # Interface translation needs the hardware port table.
    if args.hardware is False and (args.in_interface is not None or
                                   args.out_interface is not None):
        parser.error('unable to query hardware for interface translation')

    if args.in_interface is not None:
        (mod, port) = linux2modport(args.in_interface, chip)
        args.smodid = HardwareModule.from_string(str(mod))
        # NOTE(review): the --sportid option is parsed with
        # HardwarePort.from_string, but HardwareModule is used here;
        # confirm whether HardwarePort was intended.
        args.sportid = HardwareModule.from_string(str(port))
        # For Trident3/Maverick2 we need to pass the interface name for
        # the input port; it doesn't accept a port id.
        if chip in ('Trident3', 'Maverick2'):
            port_tab = cumulus.porttab.porttab()
            args.sportid.value = port_tab.logical2dport(port)

    if args.out_interface is not None:
        (mod, port) = linux2modport(args.out_interface, chip)
        args.dmodid = HardwareModule.from_string(str(mod))
        # NOTE(review): see the sportid note above (HardwarePort?).
        args.dportid = HardwareModule.from_string(str(port))

    # Build the frame.
    frame = Frame(args.src, args.dst, args.vid, args.sport, args.dport,
                  args.protocol, args.sportid, args.smodid, args.dportid,
                  args.dmodid,args.ip6_flow_label)
    if args.verbose:
        sys.stderr.write('frame: %s\n' % frame)

    # ------------------------------------------------------------------
    # Packet-trace path: let the ASIC resolve the egress interface.
    # Always exits.
    # ------------------------------------------------------------------
    if args.hardware and is_chip_pkt_trace_supported(chip):
        # packet trace is not possible without input interface
        if not args.sportid:
            parser.error('require --in-interface OR (--sportid/--smodid)')
        missing = frame.visibilityPktMissingFields()
        if missing:
            parser.error('Frame information is incomplete.\n'
                     'Please specify additional options for: %s' % ', '.join(missing))
        if args.verbose:
            sys.stderr.write('using packet trace...\n')

        if mlx:
            linux_intf = cumulus.mlx_ecmpcalc.getEcmpMbrLinuxIntf(args.in_interface, frame, args.verbose)
        else:
            pktTrace = VisibilityPktTrace(bcmshell.bcmshell(), frame,
                                  verbose=args.verbose)
            isL3, isEcmp = pktTrace.getL3PktRes()
            if not isL3:
                raise RuntimeError('Error: traffic to %s will not be L3 switched' % args.dst.addrstr)
            if not isEcmp:
                raise RuntimeError('Error: traffic to %s will not ECMP' % args.dst.addrstr)
            linux_intf = pktTrace.getOutLinuxIntf()
        sys.stdout.write('%s\n' % linux_intf)
        sys.exit(0)

    # Everything below is Broadcom-only.
    if mlx:
        raise UnsupportedPlatformError('ecmpcalc is not available on %s' % chip)

    # ------------------------------------------------------------------
    # Read values from hardware to fill in missing hash configuration
    # information (command line values take precedence).
    # ------------------------------------------------------------------
    if args.hardware:
        rtag7config = RTAG7Config(bcmshell.bcmshell())
        rtag7config.load()

        if args.hashseed is None:
            setattr(args, 'hashseed', HashSeed.from_int(rtag7config.getHashSeed(frame)))
            if args.verbose:
                sys.stderr.write('hardware hashseed    : 0x%08x\n' % args.hashseed.value)
        if args.hashfunction is None:
            setattr(args, 'hashfunction', rtag7config.getHashFunction(frame))
            if args.verbose:
                sys.stderr.write('hardware hashfunction: %s\n' % args.hashfunction)
        if args.hashfields is None:
            setattr(args, 'hashfields', rtag7config.getHashFields(frame))
            if args.verbose:
                sys.stderr.write('hardware hashfields  : %s\n' % args.hashfields)

    # Complain when we don't have enough of the hash configuration to proceed.
    if args.hashseed is None:
        parser.error('--hashseed required')
    if args.hashfunction is None:
        parser.error('--hashfunction required')
    if args.hashfields is None:
        args.hashfields = []

    if args.protocol is not None and (args.protocol.value > 255 or
                                     args.protocol.value < 0):
        parser.error('Invalid protocol value %d\n' % (args.protocol.value))

    # Attempt to calculate the crc.
    rtag7hash = RTAG7Hash(args.hashfunction, args.hashseed, args.hashfields)
    try:
        frameHash = rtag7hash.calculate(frame)
    except MissingFrameInformation as missing:
        # The exception's single argument is the list of missing fields;
        # .args works on both Python 2 and 3 (iterating the exception
        # itself only works on Python 2).
        parser.error('unable to calculate hash.  Frame information is incomplete.\n'
                     'Please specify additional options for: %s' % ', '.join(*missing.args))

    if args.egress and not asicShellAvailable:
        parser.error('unable to load multipath egress object from hardware')

    if args.dst is None and args.egress is None and args.mcount is None:
        parser.error('need a destination IP OR multipath egress object OR a '
                     'ECMP group member count')

    if (args.src is None or args.dst is None or
            args.sport is None or args.dport is None or args.protocol is None
            or args.hardware is None or args.in_interface is None):
        parser.error('please specify following options for'
                ' hashing: -s, -d, --sport, --dport, -p, -i, '
                '--hardware')

    # ------------------------------------------------------------------
    # Resolve the ECMP group for the destination and ask the SDK's
    # 'hd dest get' which member port the flow hashes to.
    # ------------------------------------------------------------------
    if args.dst is not None:
        egress = EgressTable(bcmshell.bcmshell(timeout=60), verbose=args.verbose)
        dstAddr = ipaddr.IPAddress(args.dst.addrstr)
        egressObj = egress.findEgressObjectByDst(dstAddr)
        ecmp = ECMPConfig(bcmshell.bcmshell(), verbose=args.verbose)

        if ecmp.objectIsECMP(egressObj.egressId) is False:
            raise RuntimeError('Error: traffic to %s will not egress ECMP' % args.dst.addrstr)

        # Compare versions by value; `is` on ints relies on interpreter
        # small-int caching.
        if dstAddr.version == 4:
            ethertype = 0x0800
            hd_cmd = ('hd dest get Group=ECMP GID=%s Port=%s SrcIp=%s '
                    'DestIp=%s L4SrcPort=%s L4DstPort=%s Ethertype=%s '
                    'Protocol=%s' % (egressObj.egressId, args.sportid.value,
                        args.src.addrstr, args.dst.addrstr, args.sport.value,
                        args.dport.value, ethertype, args.protocol.value))
        elif dstAddr.version == 6:
            ethertype = 0x86dd
            srcNet = ipaddr.IPNetwork(args.src.addrstr)
            dstNet = ipaddr.IPNetwork(args.dst.addrstr)
            hd_cmd = ('hd dest get Group=ECMP GID=%s Port=%s SIp6=%s '
                    'DIp6=%s L4SrcPort=%s L4DstPort=%s Ethertype=%s '
                    'Protocol=%s' % (egressObj.egressId, args.sportid.value,
                        srcNet.ip.exploded, dstNet.ip.exploded, args.sport.value,
                        args.dport.value, ethertype, args.protocol.value))
        else:
            # Guard so hd_cmd is never used unbound.
            raise RuntimeError('Error: unsupported address family for %s' % args.dst.addrstr)
        isTrunkMbr, port = ecmp.getHdPort(hd_cmd)

        linux_intf = ecmp.getLinuxIntf(isTrunkMbr, port)
        sys.stdout.write('%s\n' % linux_intf)