#!/usr/bin/python

# Copyright 2020.  Cumulus Networks, Inc.
#
# nfr2ebtables --
#    tool to convert IPFilterRules to EBTABLES format based
#    on RFC-6733. This is mainly used by hostapd to convert
#    NAS-Filter-Rule attributes into installable EBTABLES ACLs
#


try:
    import re
    import sys
    import argparse
    import syslog
    import os
    import ipaddr
    import pprint
except:
    print("ERROR: could not import required modules.")
    exit(-1)

'''
The IPFilterRule syntax comes from RFC-6733, page 48
IPFilterRule filters MUST follow the format:

            action dir proto from src to dst [options]

where the optional [options] come at the end of the rule.

mark 2: normal dot1x filter; matching traffic goes to the dot1x slice.
mark 3: MAB only; special dot1x rule in the dot1x slice that leaks 2 packets
        to the CPU and then drops.

By default everything is denied, regardless of what the rules specify or
what the default policy is.

IPFilterRule:
    permit in tcp from any to 8.8.8.8/32 67
    deny   in ip from any to any

EBTABLES rules (note the insertion of the ifname and source MAC address):
    -A FORWARD -i swp2s1 -s 00:02:00:00:00:07 -p IPV4 --ip-dst 8.8.8.8/32 --ip-protocol TCP --ip-dport 67 -j mark --set-mark 2
    -A FORWARD -i swp2s1 -s 00:02:00:00:00:07 -p IPV4 --ip-dst 8.8.8.8/32 --ip-protocol TCP --ip-dport 67 -j ACCEPT


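The pre-auth filters must already end with a deny-all for the port, e.g.:
    -A FORWARD -i swp1 -j mark --set-mark 3
    -A FORWARD -i swp1 -j DROP
A deny-all in the IPFilterRule input is still converted, but the generated
rules only apply to the given source MAC address and switchport.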

'''

def logger(level=syslog.LOG_DEBUG, message=''):
    '''
    Log a message to syslog at the given priority and echo it to stdout.

    level   -- syslog priority (e.g. syslog.LOG_ERR); defaults to LOG_DEBUG
    message -- text to log.  None is logged as an empty string and other
               non-string objects are converted with str(), because
               syslog.syslog() only accepts strings and would otherwise
               raise TypeError.
    '''
    if message is None:
        message = ''
    elif not isinstance(message, (str, type(u''))):
        message = str(message)
    syslog.syslog(level, message)
    print(message)

def parser_init():
    '''
    Build the command-line interface and parse sys.argv.

    Returns the argparse namespace with the attributes switchport, srcmac,
    inputfilename, outputfilename (all None when not given) and debug
    (a boolean flag).
    '''
    cli = argparse.ArgumentParser(
        description="nfr2ebtables: convert from IPFilterRule to EBTABLES syntax")

    # (short flag, long flag, help text) for every string-valued option
    string_options = (
        ("-s", "--switchport",
         "switchport interface name to use in ebtable filter (e.g. swp1)"),
        ("-m", "--srcmac",
         "source MAC address to use in ebtable filter (e.g. 00:e0:ec:27:52:6b)"),
        ("-i", "--inputfilename",
         "input IPFilterRule filename (e.g. inputfile.txt)"),
        ("-o", "--outputfilename",
         "output EBTABLES filename (e.g. outputebtable.rule)"),
    )
    for short_flag, long_flag, help_text in string_options:
        cli.add_argument(short_flag, long_flag, help=help_text)

    cli.add_argument("-d", "--debug", help="log DEBUG messages", action="store_true")

    parsed = cli.parse_args()
    return parsed

def rule_checker(m):
    '''
    Check a parsed IPFilterRule for constructs this tool cannot convert.

    m -- match object produced by matching the module-level `pattern`
         against one rule line; the named groups action, dir, src, dst and
         options are inspected.

    Every problem found is logged at LOG_ERR.
    Returns True if there are errors, False if no problems.
    '''
    if m.group("dir").lower() != "in":
        logger(level=syslog.LOG_ERR, message="ERROR: IPFilterRule only supports 'in' direction")
        logger(level=syslog.LOG_ERR, message="ERROR:     This filters traffic from the user port")
        return True

    # we support any protocol as this is an IPTABLES filter, we will take
    # numbers or specific L4 protocols (or "ip"), so "proto" is not checked.

    if m.group("src").lower() != "any":
        logger(level=syslog.LOG_ERR, message="ERROR: IPFilterRule only supports 'any' src. This is replaced by the SRC MAC address")
        return True

    if m.group("action").lower() not in ("permit", "deny"):
        logger(level=syslog.LOG_ERR, message="ERROR: IPFilterRule only supports permit or deny actions.")
        logger(level=syslog.LOG_ERR, message="ERROR:     '{}' is not a supported action".format(m.group("action")))
        return True

    dst = m.group("dst")
    # The dst group is optional in the regex, so it can be None.
    if dst is None:
        logger(level=syslog.LOG_ERR, message="ERROR: IPFilterRule is missing a destination (use 'any' or an IPv4 address/network)")
        return True

    # Compare case-insensitively, the same way "src" is checked above.
    if dst.lower() != "any":
        # only support IPv4 dst network and addresses or just any
        try:
            ipaddr.IPv4Network(dst)
        except ValueError:
            # NOTE(review): assumes ipaddr's AddressValueError and
            # NetmaskValueError subclass ValueError, as they do in ipaddr-py
            # and the stdlib ipaddress module -- confirm for the installed version.
            logger(level=syslog.LOG_ERR, message="ERROR: IPFilterRule only supports destination IPv4 address or network")
            logger(level=syslog.LOG_ERR, message="ERROR:    address {} invalid".format(dst))
            return True

    options = m.group("options")
    # split() on any whitespace, so trailing blanks do not create empty tokens
    tokens = options.split() if options else []
    if tokens:
        if len(tokens) > 1:
            therest = ' '.join(tokens[1:])
            logger(level=syslog.LOG_ERR, message="ERROR: only a single L4 destination port option is supported.")
            logger(level=syslog.LOG_ERR, message="ERROR:     '{}' is not supported".format(therest))
            return True

        port = tokens[0]
        if '-' in port or '/' in port or ',' in port:
            logger(level=syslog.LOG_ERR, message="ERROR: L4 destination port option does not support ranges")
            return True

        # now we make sure the option contains just an integer
        try:
            port_number = int(port)
        except ValueError:
            logger(level=syslog.LOG_ERR, message="ERROR: L4 destination port option must be an integer")
            logger(level=syslog.LOG_ERR, message="ERROR:     {} is not an integer".format(port))
            return True

        if not 0 <= port_number <= 65535:
            logger(level=syslog.LOG_ERR, message="ERROR: L4 destination port {} is out of range (0-65535)".format(port_number))
            return True

        # NOTE(review): ebtables only accepts --ip-dport together with an
        # --ip-protocol such as tcp/udp; a port with proto "ip" is not
        # rejected here -- consider validating proto as well.

    return False


def get_ebtables_lines(iface=None, srcmac=None, match=None):
    '''
    Convert one parsed IPFilterRule into the pair of EBTABLES rules for it.

    iface  -- switchport name inserted as "-i <iface>" (e.g. swp2s1)
    srcmac -- source MAC address inserted as "-s <srcmac>"
    match  -- match object from the module-level `pattern`, already
              validated by rule_checker() (a multi-token "options" group
              would make int() raise ValueError here)

    Example: the rule
        permit in tcp from any to 8.8.8.8/32 67
    with iface=swp2s1 and srcmac=00:02:00:00:00:07 generates
        -A FORWARD -i swp2s1 -s 00:02:00:00:00:07 -p IPV4 --ip-dst 8.8.8.8/32 --ip-protocol TCP --ip-dport 67 -j mark --set-mark 2
        -A FORWARD -i swp2s1 -s 00:02:00:00:00:07 -p IPV4 --ip-dst 8.8.8.8/32 --ip-protocol TCP --ip-dport 67 -j ACCEPT

    Returns [mark_rule, action_rule]; an empty list when the action is
    neither permit nor deny; None (after logging an error) when iface,
    srcmac or match is missing.
    '''
    if not match or not iface or not srcmac:
        logger(level=syslog.LOG_ERR, message="ERROR: missing iface, srcmac, or matched regex to convert")
        return None

    # Common match part shared by the mark rule and the ACCEPT/DROP rule.
    rule = '-A FORWARD -i {} -s {} -p IPV4'.format(iface, srcmac)

    # add in a --ip-dst if it is not any since any can be left out
    dst = match.group("dst")
    if dst and dst.lower() != 'any':
        rule = '{} --ip-dst {}'.format(rule, dst)

    # add in a protocol if it's not ip. This group always exists because it was parsed.
    proto = match.group("proto")
    if proto.lower() != 'ip':
        rule = '{} --ip-protocol {}'.format(rule, proto.upper())

    # if a dst port exists, it has to be a number in options.
    options = match.group("options")
    if options is not None:
        rule = '{} --ip-dport {}'.format(rule, int(options))

    # now add the DROP or ACCEPT target. The action was checked earlier.
    action = match.group("action").lower()
    if action == "permit":
        target = "ACCEPT"
    elif action == "deny":
        target = "DROP"
    else:
        return []

    return ['{} -j mark --set-mark 2'.format(rule),
            '{} -j {}'.format(rule, target)]



# Regex for one IPFilterRule line:
#     action dir proto from src to dst [options]
# Each named group captures one field.  "options" stops at the first ','
# or ';', but "dst" (\S+) does not, so a comma directly after the
# destination is NOT excluded by the regex itself -- the main loop strips
# trailing commas before matching.
pattern = (r'(?P<action>\S+)\s+(?P<dir>\S+)\s+(?P<proto>\S+)\s+'
           r'from\s+(?P<src>\S+)\s+to\s+(?P<dst>\S+)?'
           r'(?P<space>\s+)?(?P<options>[^,;]+)?')

# directory for generated rule files when -o/--outputfilename is not given
default_policy_dir = "/etc/cumulus/acl/policy.d/tmp/"

if __name__ == "__main__":
    # Script entry point: read IPFilterRules from the input file, validate
    # and convert each one, then write the EBTABLES rule file.
    parser_args = parser_init()

    if parser_args.debug:
        logger(message="DEBUG: sys args = {}".format(sys.argv))

    if not parser_args.inputfilename or not parser_args.srcmac or not parser_args.switchport:
        logger(level=syslog.LOG_ERR, message="ERROR: inputfilename, switchport, and src MAC are required")
        # parser_init() only returns the parsed namespace (there is no parser
        # object here), so point the user at --help instead.
        logger(level=syslog.LOG_ERR, message="ERROR: run '{} --help' for usage".format(sys.argv[0]))
        sys.exit(-1)

    if not os.path.isfile(parser_args.inputfilename):
        logger(level=syslog.LOG_ERR, message="ERROR: file {} is missing".format(parser_args.inputfilename))
        sys.exit(-1)

    if parser_args.debug:
        logger(message="DEBUG: args given: {}".format(parser_args))

    ebtablesbuf = ["######## hostapd generated Dynamic ACL EBTABLES rule file ########",
                   "[ebtables]"]
    # number of header lines, used to detect that no rules were generated
    header_len = len(ebtablesbuf)

    with open(parser_args.inputfilename) as infile:
        for line in infile:
            line = line.strip()
            # skip blank lines and comments
            if not line or line.startswith('#'):
                continue

            # just remove trailing commas the easy way (and any blanks before
            # the comma, so they do not end up in the options group)
            if line.endswith(','):
                line = line[:-1].rstrip()

            if '!' in line:
                logger(level=syslog.LOG_ERR, message="ERROR: IPFilterRule cannot contain inverse '!': {}".format(line))
                sys.exit(-1)

            if "cnt" in line:
                logger(level=syslog.LOG_WARNING, message="WARNING: 'cnt' ignored in IPFilterRule")
                line = line.replace("cnt", '').strip()

            if parser_args.debug:
                logger(message="DEBUG: input line: {}".format(line))

            # we leave the deny all since this should be ok for that MAC and swp
            m = re.match(pattern, line)
            if not m:
                logger(level=syslog.LOG_ERR, message="Error: cannot parse invalid IPFilterRule {}".format(line))
                sys.exit(-1)

            if parser_args.debug:
                logger(message="DEBUG: match found: {}".format(m.groups()))
                logger(message="DEBUG: match found: {}".format(m.groupdict()))

            if rule_checker(m):
                logger(level=syslog.LOG_ERR, message="ERROR: Invalid IPFilter rule line {}".format(line))
                sys.exit(-1)

            eblines = get_ebtables_lines(iface=parser_args.switchport, srcmac=parser_args.srcmac, match=m)
            if eblines:
                ebtablesbuf.extend(eblines)

    if len(ebtablesbuf) == header_len:
        # The file is still written (as before), so that any existing file at
        # the output path is overwritten with a rule-less version.
        logger(level=syslog.LOG_WARNING,
               message="WARNING: no EBTABLES rules generated from {}".format(parser_args.inputfilename))

    if parser_args.debug:
        # pformat returns the text; pprint would print it and return None
        logger(message="DEBUG: output ebtables filter:\n{}".format(pprint.pformat(ebtablesbuf)))

    if parser_args.outputfilename:
        # an explicitly given output filename is used as-is
        outputfilename = parser_args.outputfilename
    else:
        # default: <default_policy_dir>/150_dot1x_dacl_<switchport>_<MAC without colons>.rules
        outputfilename = "{}/150_dot1x_dacl_{}_{}.rules".format(default_policy_dir,
                                                                parser_args.switchport,
                                                                parser_args.srcmac.replace(':', ''))
        outputfilename = os.path.normcase(os.path.normpath(outputfilename))

    if parser_args.debug:
        logger(message="DEBUG: writing output ebtables file: {}".format(outputfilename))

    try:
        with open(outputfilename, "w") as outfile:
            outfile.write("".join("{}\n".format(rule) for rule in ebtablesbuf))
    except (IOError, OSError) as e:
        # only OS-level errors carry errno/strerror
        logger(level=syslog.LOG_ERR, message="Error: Could not write EBTABLES rule file {}".format(outputfilename))
        logger(level=syslog.LOG_ERR, message="Error {0}: {1}".format(e.errno, e.strerror))
        sys.exit(-1)









