#!/usr/bin/python
# Copyright 2014,2016,2019 Cumulus Networks, Inc.
#
# portsamp
#
# Utility to configure packet sampling on physical interfaces.
#
import sys
import cumulus.porttab as porttab
import os.path
import subprocess

try:
    import bcmshell
    HAS_BCM_SHELL = True
except ImportError:
    HAS_BCM_SHELL = False

#Function to detect the vendor platform

def _chip_detect():
    """
    Detect the switch vendor platform.

    Scans /proc/bus/pci/devices for a known switching ASIC and falls back
    to asking /usr/lib/cumulus/cl-platform whether this is a VX image.

    Returns 'BCM' (Broadcom ASIC, or no PCI devices listed), 'MLX'
    (Mellanox PCI vendor id) or 'VX'.  If no vendor is recognized an error
    is written to stderr and the program exits with status 10.
    """
    # PCI device ids (vendor 14e4) of the supported Broadcom switch ASICs.
    _chip_map = {
        'b768': 'Maverick',
        'b762': 'Maverick',
        'b760': 'Maverick',
        'b963': 'Tomahawk',
        'b962': 'Tomahawk',
        'b960': 'Tomahawk',
        'b965': 'TomahawkPlus',
        'b967': 'TomahawkPlus',
        'b970': 'Tomahawk2',
        'b980': 'Tomahawk3',
        'b870': 'Trident3X7',
        'b873': 'Trident3X7',
        'b864': 'Trident2Plus',
        'b860': 'Trident2Plus',
        'b854': 'Trident2',
        'b850': 'Trident2',
        'b846': 'Trident',
        'b845': 'Trident',
        'b340': 'Helix4',
        'b150': 'Hurricane2',
        '8375': 'QumranMX',
    }
    # Use a context manager so the file handle is closed promptly.
    with open('/proc/bus/pci/devices', 'r') as pci_file:
        devices = pci_file.readlines()
    if not devices:
        # No PCI bus on embedded ARM platforms,
        # currently all these use Helix4
        return 'BCM'

    for device in devices:
        fields = device.split(None, 2)
        if len(fields) < 2:
            # Short/malformed line: there is no vendor/device id to inspect.
            continue
        # The second column is the vendor id followed by the device id.
        pci_id = fields[1]
        vendor = pci_id[:4]
        product = pci_id[4:]
        if vendor == '14e4' and product in _chip_map:
            return 'BCM'
        elif vendor == '15b3':
            return 'MLX'

    # universal_newlines gives text on both Python 2 and 3; strip() drops
    # the trailing newline so the comparison below can actually match.
    try:
        platform_name = subprocess.check_output(
            '/usr/lib/cumulus/cl-platform', universal_newlines=True)
    except (OSError, subprocess.CalledProcessError):
        # Helper missing or failed: fall through to the unknown-vendor exit.
        platform_name = ''
    if platform_name.strip() == 'cumulus,vx':
        return 'VX'

    sys.stderr.write("Unknown Vendor\n")
    sys.exit(10)

def usage():
    """Write the command-line synopsis to stderr and exit with status 10."""
    synopsis = ("Usage: %s (swp##|swp*) ingr_rate egr_rate "
                "psample_group\n" % sys.argv[0])
    explanation = (
        "  On average, every 1/rate packets will be sampled.  A\n",
        "  rate of 0 means no sampling, 1 means all packets sampled.\n",
        "  Most recommend monitoring only ingress traffic, but if\n",
        "  ingr_rate and egr_rate are specified then a psample_group\n",
        "  + 1 is allocated for egress samples.\n",
    )
    # Emit the synopsis first, then each explanatory line in order.
    for text in (synopsis,) + explanation:
        sys.stderr.write(text)
    sys.exit(10)

def add_clsact_qdisc_if_not_exist(swp_intf):
    """
    Add a clsact qdisc to the specified interface if there is no such
    one already.

    swp_intf -- name of the interface to configure.

    Exits the program (via sys.exit with an error message) when the
    qdisc cannot be added.
    """
    # Commands are passed as argument lists (no shell), so the interface
    # name is never interpreted by a shell.
    try:
        show = subprocess.Popen(['tc', 'qdisc', 'show', 'dev', swp_intf],
                                stdout=subprocess.PIPE)
        # communicate() drains the pipe and reaps the child process.
        existing_qdiscs = show.communicate()[0]
    except OSError:
        # tc could not be executed; treat this as "no qdisc found" so the
        # add attempt below reports the failure.
        existing_qdiscs = b''

    if b'clsact' in existing_qdiscs:
        return

    try:
        ret = subprocess.call(['tc', 'qdisc', 'add', 'dev', swp_intf, 'clsact'])
    except OSError:
        ret = 1
    # A negative return code means tc was killed by a signal, which is
    # a failure as well.
    if ret != 0:
        sys.exit("Adding clsact qdisc for %s failed\n\n" % (swp_intf))

def add_matchall_filter_sample_action(swp_intf, rate, psample_group, hw_offload, for_ingress):
    """
    Add a matchall filter with sample action to the specified interface.
    If rate <= 0, delete the existing filters instead.

    swp_intf      -- name of the interface to configure.
    rate          -- sampling rate passed to the tc sample action.
    psample_group -- psample group number passed to the tc sample action.
    hw_offload    -- when true, the filter is installed with tc's skip_sw flag.
    for_ingress   -- True for the ingress hook, False for the egress hook.

    Exits the program (via sys.exit with an error message) when tc fails.
    """
    channel = "ingress" if for_ingress else "egress"

    # Build an argument list (no shell), so the interface name is never
    # interpreted by a shell.
    if rate > 0:
        tc_cmd = ['tc', 'filter', 'add', 'dev', swp_intf, channel, 'matchall']
        if hw_offload:
            tc_cmd.append('skip_sw')
        tc_cmd += ['action', 'sample',
                   'rate', '%d' % rate,
                   'group', '%d' % psample_group]
    else:
        tc_cmd = ['tc', 'filter', 'del', 'dev', swp_intf, channel]

    try:
        ret = subprocess.call(tc_cmd)
    except OSError:
        # tc could not be executed at all.
        ret = 1

    # A negative return code means tc was killed by a signal, which is
    # a failure as well.
    if ret != 0:
        sys.exit("tc sample filter could not be modified for %s %s\n\n" % (swp_intf, channel))


# Script body: validate the arguments, detect the platform, then program
# sampling on every requested interface.
if len(sys.argv) != 5:
    usage()

intf = sys.argv[1]
try:
    ingr_rate = int(sys.argv[2])
    egr_rate = int(sys.argv[3])
    psample_group = int(sys.argv[4])
except ValueError:
    # Only a failed int() conversion is expected here; anything else
    # should surface normally.
    sys.stderr.write("Rates and psample_group must be integers.\n\n")
    usage()
# Egress samples use the group right after the ingress one.
psample_group_egr = psample_group + 1
if psample_group < 0 or psample_group > 65535:
    # The message states the range this check actually accepts.
    sys.stderr.write("psample_group must be between 0 and 65535\n\n")
    usage()

pt = porttab.porttab()
if intf == "swp*":
    interfaces = pt.get_linux_ports()
else:
    interfaces = [intf]

platform = _chip_detect()
is_mlx_or_vx = platform in ('MLX', 'VX')
if not is_mlx_or_vx and not HAS_BCM_SHELL:
    sys.exit("Platform is not supported.")

if is_mlx_or_vx and egr_rate > 0:
    sys.stderr.write("Non-zero egr_rate not supported on Mellanox/VX platforms\n\n")
    sys.exit(10)

# Sampling is offloaded to the ASIC on every platform except VX.
hw_offload = platform != 'VX'

for swp_intf in interfaces:
    if hw_offload:
        # offload sampling to ASIC; the file is reopened and rewritten
        # for each interface
        with open('/cumulus/switchd/config/traffic/sflow/portsamprate_set', 'w') as f:
            f.write('%s: %d, %d\n' % (swp_intf, ingr_rate, egr_rate))

    # install tc rules to the kernel mirroring ASIC sampling configuration
    add_clsact_qdisc_if_not_exist(swp_intf)
    add_matchall_filter_sample_action(swp_intf, ingr_rate, psample_group, hw_offload, for_ingress=True)
    add_matchall_filter_sample_action(swp_intf, egr_rate, psample_group_egr, hw_offload, for_ingress=False)
