#!/usr/bin/python
# Copyright (C) 2017, 2018, 2019 Cumulus Networks, Inc. all rights reserved
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; version 2.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
#
# https://www.gnu.org/licenses/gpl-2.0-standalone.html
#
# Author:
#       Julien Fortin, julien@cumulusnetworks.com
#
# ifupdown2 --
#    tool to configure network interfaces
#

import os
import re
import sys
import json
import fcntl
import signal
import socket
import daemon
import logging
import argparse
import datetime
import threading

from io import StringIO

try:
    import cumulus.sdnotify
except:
    pass

try:
    from ifupdown2.lib.io import SocketIO
    from ifupdown2.lib.status import Status
    from ifupdown2.lib.log import LogManager, root_logger
    from ifupdown2.lib.exceptions import ExitWithStatusAndError, ExitWithStatus

    from ifupdown2.ifupdown.argv import Parse
    from ifupdown2.ifupdown.main import Ifupdown2
    from ifupdown2.ifupdown.exceptions import ArgvParseError, ArgvParseHelp
except:
    from .lib.io import SocketIO
    from .lib.status import Status
    from .lib.log import LogManager, root_logger
    from .lib.exceptions import ExitWithStatusAndError, ExitWithStatus

    from .ifupdown.argv import Parse
    from .ifupdown.main import Ifupdown2
    from .ifupdown.exceptions import ArgvParseError, ArgvParseHelp


class Daemon(SocketIO):
    """
    ifupdown2 daemon.

    Listens on a UNIX stream socket (self.server_address) and serves one
    request per connection: the client sends a JSON packet carrying "argv"
    and "stdin", the request is run in-process through Ifupdown2, and the
    captured stdout/stderr plus the exit status are sent back.
    """

    # Set by the signal handler; run() leaves its accept loop once set.
    shutdown_event = threading.Event()

    def __init__(self):
        SocketIO.__init__(self)

        # daemon specific argv parser
        argparser = argparse.ArgumentParser()
        argparser.add_argument(
            "-d", "--debug",
            dest="debug",
            action="store_true",
            help="enable debug logging"
        )
        argparser.add_argument(
            "-v", "--verbose",
            dest="verbose",
            action="store_true",
            help="enable verbose logging"
        )
        argparser.add_argument(
            "--console",
            dest="console",
            action="store_true",
            help="enable logging on stderr"
        )
        argparser.add_argument(
            "--no-daemon",
            dest="no_daemon",
            action="store_true",
            help="prevent process from daemonizing"
        )
        self.args = argparser.parse_args()
        LogManager.get_instance().start_daemon_logging(self.args)

        root_logger.info("ifupdown2 daemon initialization...")

        self.uds = None
        self.context = None
        self.server_address = "/var/run/ifupdown2d/uds"
        self.working_directory = "/var/run/ifupdown2d/"

        # signal value -> signal name (e.g. "SIGTERM"), used for logging.
        # "SIG_*" names (SIG_DFL, SIG_IGN, ...) are not signals.
        self.signal_str_map = dict(
            (attr_value, attr_name)
            for attr_name, attr_value in list(signal.__dict__.items())
            if attr_name.startswith("SIG") and not attr_name.startswith("SIG_")
        )

        if not os.path.exists(self.working_directory):
            root_logger.info("creating %s" % self.working_directory)
            os.makedirs(self.working_directory, mode=0o755)

        # a leftover socket file (e.g. from a previous run) would make bind() fail
        if os.path.exists(self.server_address):
            root_logger.info("removing uds %s" % self.server_address)
            os.remove(self.server_address)

        # Use the socket module's SO_PASSCRED/SO_PEERCRED when available,
        # otherwise fall back to the kernel's per-architecture values.
        # powerpc is the only non-generic we care about. alpha, mips,
        # sparc, and parisc also have non-generic values.
        # Both attributes are always set: SO_PASSCRED is needed by the
        # setsockopt() call below whichever source the values come from.
        if re.search(r'^(ppc|powerpc)', os.uname()[4]):
            default_passcred, default_peercred = 20, 21
        else:
            default_passcred, default_peercred = 16, 17
        self.SO_PASSCRED = getattr(socket, "SO_PASSCRED", default_passcred)
        self.SO_PEERCRED = getattr(socket, "SO_PEERCRED", default_peercred)

        self.__signal_map = {
            signal.SIGINT: self.__signal_handler,
            signal.SIGTERM: self.__signal_handler,
            signal.SIGQUIT: self.__signal_handler,
        }

        # for debugging purpose we accept the option "--no-daemon"
        if not self.args.no_daemon:
            root_logger.info("daemonizing ifupdown2d...")

            self.context = daemon.DaemonContext(
                working_directory=self.working_directory,
                signal_map=self.__signal_map,
                umask=0o22
            )
            self.context.open()
            LogManager.get_instance().disable_console()

        try:
            root_logger.info("opening UNIX socket")
            self.uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            fcntl.fcntl(self.uds.fileno(), fcntl.F_SETFD, fcntl.FD_CLOEXEC)
        except Exception as e:
            raise ExitWithStatusAndError(Status.Daemon.STATUS_SOCKET_ERROR, "socket: %s" % str(e))
        try:
            self.uds.bind(self.server_address)
        except Exception as e:
            raise ExitWithStatusAndError(Status.Daemon.STATUS_SOCKET_ERROR, "bind: %s" % str(e))
        try:
            # ask the kernel to attach sender credentials to messages
            self.uds.setsockopt(socket.SOL_SOCKET, self.SO_PASSCRED, 1)
        except Exception as e:
            raise ExitWithStatusAndError(Status.Daemon.STATUS_SOCKET_ERROR, "setsockopt: %s" % str(e))
        try:
            self.uds.listen(1)
        except Exception as e:
            raise ExitWithStatusAndError(Status.Daemon.STATUS_SOCKET_ERROR, "listen: %s" % str(e))

        # NOTE(review): the socket is accessible to every local user; access
        # control presumably relies on the peer uid handed to Ifupdown2 in
        # process_request() - confirm.
        os.chmod(self.server_address, 0o777)

        # we need to divert stdout and stderr
        self.__stdout_buffer = None
        self.__stderr_buffer = None
        self.__reset_buffers()

        root_logger.info("filling netlink cache")
        try:
            import ifupdown2.lib.nlcache as nlcache
        except ImportError:
            from .lib import nlcache

        # After the daemon is ready we can start our netlink listener
        nlcache.NetlinkListenerWithCache.init(logging.root.level)

        # save reference to nlcache
        nlcache_ref = nlcache.NetlinkListenerWithCache.get_instance()

        # start netlink listener and cache link/addr/netconf dumps
        nlcache_ref.start()

    @staticmethod
    def __notify_systemd_ready():
        """
        Cumulus specific: tell systemd the daemon is ready (best-effort).
        Failures are ignored, including cumulus.sdnotify not being
        importable (the module-level import is optional).
        :return:
        """
        try:
            # Tell systemd that we are initialized and ready
            cumulus.sdnotify.sd_notify(0, "READY=1")
            root_logger.info("ifupdown2d is ready to process incoming requests")
        except Exception:
            pass

    def __signal_handler(self, sig, frame):
        """
        Signal handler, shutdown the daemon on any incoming signal and logs.

        Since Python 3.5 (PEP 475) a blocking accept() interrupted by a
        signal is retried once the handler returns, so setting
        shutdown_event alone would leave run() blocked until the next
        client connects. The listening socket is therefore also shut down,
        which makes the pending (or next) accept() fail and lets run() exit.
        :param sig: signal number
        :param frame: current stack frame (unused)
        :return:
        """
        if sig == signal.SIGTERM:
            logger_func = root_logger.info
        else:
            logger_func = root_logger.warning

        logger_func(
            "received %s signal" %
            self.signal_str_map.get(sig, "UNKNOWN")
        )
        Daemon.shutdown_event.set()
        self.__interrupt_accept()

    def __interrupt_accept(self):
        """ Shut down the listening socket so a blocking accept() returns """
        if self.uds is None:
            # socket not created yet: run() checks shutdown_event before
            # ever calling accept()
            return
        try:
            self.uds.shutdown(socket.SHUT_RDWR)
        except OSError:
            # socket already closed: run() will still see shutdown_event
            # at the top of its next iteration
            pass

    def __shutdown_session(self, client_socket, status):
        """ TX stdout, stderr and exit status to client before closing socket """
        try:
            self.tx_data(client_socket, json.dumps({
                "stdout": self.__stdout_buffer.getvalue(),
                "stderr": self.__stderr_buffer.getvalue(),
                "status": status
            }))
        finally:
            client_socket.close()
            self.__reset_buffers()

    def __reset_buffers(self):
        """ Redirect sys.stdout/sys.stderr to fresh in-memory buffers """
        self.__stdout_buffer = sys.stdout = StringIO()
        self.__stderr_buffer = sys.stderr = StringIO()

    def __serve_client(self, client_socket, socket_peer_cred):
        """
        Handle one request on an accepted connection, then reply and close it.

        Errors caused by a single client (malformed request, client hanging
        up before the reply) are logged instead of propagated, so one bad
        client cannot take the whole daemon down.
        :param client_socket: accepted client connection
        :param socket_peer_cred: (pid, uid, gid) of the client process
        """
        LogManager.get_instance().start_stream()

        status = Status.Daemon.STATUS_INIT
        try:
            # RX client request (dict with argv and stdin)
            request = self.rx_json_packet(client_socket)

            start = datetime.datetime.now()

            # process request
            status = self.process_request(request, socket_peer_cred)

            end = datetime.datetime.now()

            # restore syslog handler
            LogManager.get_instance().enable_syslog()

            # set level to INFO to emit the exit status message
            LogManager.get_instance().set_level(logging.INFO, info=True)

            root_logger.info("exit status %d - in %ssecs" %
                             (status, (end - start).total_seconds()))
        except Exception as e:
            root_logger.exception("daemon: request: %s" % str(e))
            status = Status.Daemon.STATUS_REQUEST_EXCEPTION
        finally:
            # After processing the request and getting a status code
            # we need to send the status back to the client as well as
            # our stdout/stderr buffers.
            LogManager.get_instance().close_log_stream()
            try:
                self.__shutdown_session(client_socket, status)
            except Exception as e:
                # the client may already have hung up
                root_logger.error("daemon: unable to reply to client: %s" % str(e))
            # restore original daemon logging level
            LogManager.get_instance().set_daemon_logging_level(self.args)

    def run(self):
        """
        Serve client requests until a shutdown signal is received.
        The listening socket is always closed on exit.
        :return: 0
        """
        self.__notify_systemd_ready()
        try:
            for sig, sig_handler in list(self.__signal_map.items()):
                signal.signal(sig, sig_handler)

            while True:
                if Daemon.shutdown_event.is_set():
                    root_logger.info("shutdown signal RXed, breaking out loop")
                    break
                client_socket = None
                try:
                    client_socket, _ = self.uds.accept()
                    # sets the close-on-exec flag for the file descriptor,
                    # which causes the fd to be automatically closed when any
                    # of the exec-family functions succeed.
                    fcntl.fcntl(client_socket.fileno(), fcntl.F_SETFD, fcntl.FD_CLOEXEC)
                    socket_peer_cred = self.get_socket_peer_cred(client_socket)
                except socket.error as e:
                    # don't leak a connection accepted before the failure
                    if client_socket is not None:
                        client_socket.close()
                    if not Daemon.shutdown_event.is_set():
                        root_logger.exception("daemon: socket: accept: %s" % str(e))
                    else:
                        root_logger.info("shutdown signal RXed, breaking out loop")
                    break

                self.__serve_client(client_socket, socket_peer_cred)
        finally:
            self.uds.close()
        return 0

    @staticmethod
    def process_request(request, socket_peer_cred):
        """
        Run one ifupdown2 request in-process.
        :param request: decoded client packet; its "argv" and "stdin" entries
            are read (presumably a dict built by the client - verify)
        :param socket_peer_cred: (pid, uid, gid) of the client process; the
            uid is forwarded to Ifupdown2
        :return: exit status to report back to the client
        """
        try:
            (pid, uid, gid) = socket_peer_cred
            request_argv = request.get("argv")
            request_stdin = request.get("stdin")

            # this might get logged twice to syslog if the client has syslog on
            root_logger.info("processing incoming request from pid %s: %s"
                             % (pid, " ".join(request_argv)))

            ifupdown2 = Ifupdown2(daemon=True, uid=uid)
            ifupdown2.parse_argv(request_argv)

            # adjust the log level and handler for the request
            LogManager.get_instance().set_request_logging_level(ifupdown2.args)

            try:
                status = ifupdown2.main(request_stdin)
            except BaseException as e:
                root_logger.error('ifupdown2.main: %s' % str(e))
                status = 1
                return status

        except ArgvParseHelp:
            # on --help parse_args raise SystemExit we catch it and raise
            # a custom exception ArgvParseHelp so we can properly return 0
            status = 0
        except ArgvParseError as e:
            e.log_error()
            status = Status.Daemon.STATUS_REQUEST_PARSE_ERROR
        except Exception as e:
            root_logger.exception("exception: %s" % str(e))
            status = Status.Daemon.STATUS_REQUEST_EXCEPTION
        except BaseException as e:
            root_logger.error("base exception %s" % str(e))
            status = Status.Daemon.STATUS_REQUEST_BASE_EXCEPTION
        return status


def _cleanup_netlink_cache():
    """
    Best-effort release of the netlink listener/cache.

    The daemon may have failed before the cache was initialized (socket and
    bind errors are raised earlier in Daemon.__init__), in which case
    get_instance() may not return a usable object (unverified). Failures are
    logged rather than raised so they do not mask the daemon's exit status.
    """
    try:
        try:
            import ifupdown2.lib.nlcache as nlcache
        except ImportError:
            from .lib import nlcache
        nlcache.NetlinkListenerWithCache.get_instance().cleanup()
    except Exception as e:
        root_logger.error("netlink cache cleanup failed: %s" % str(e))


def main():
    """
    Entry point of the ifupdown2 daemon.

    Runs the daemon, maps any exception it raises to an exit status, then
    releases the netlink cache.
    :return: exit status
    """
    try:
        status = Daemon().run()
    except ExitWithStatusAndError as e:
        root_logger.error(e.message)
        status = e.status
    except ExitWithStatus as e:
        status = e.status
    except KeyboardInterrupt:
        status = Status.Daemon.STATUS_KEYBOARD_INTERRUPT
    except BaseException as e:
        root_logger.exception("%s" % str(e))
        status = Status.Daemon.STATUS_UNKNOWN
    _cleanup_netlink_cache()
    return status


if __name__ == "__main__":
    # main() already turns its own exceptions into an exit status; this
    # guard additionally covers a Ctrl-C raised outside main()'s try block.
    try:
        exit_code = main()
    except KeyboardInterrupt:
        exit_code = Status.Daemon.STATUS_KEYBOARD_INTERRUPT
    sys.exit(exit_code)
