/*
 * lib/route/mroute.c	Multicast Routes
 *
 *	This library is free software; you can redistribute it and/or
 *	modify it under the terms of the GNU Lesser General Public
 *	License as published by the Free Software Foundation version 2.1
 *	of the License.
 *
 * Copyright (c) 2016-2017 Roopa Prabhu <roopa@cumulusnetworks.com>
 */

/**
 * @ingroup rtnl
 * @defgroup mroute Multicast Routing
 * @brief Netlink cache support for multicast (IPMR) routes
 * @{
 */

#include <netlink-private/netlink.h>
#include <netlink/netlink.h>
#include <netlink/cache.h>
#include <netlink/utils.h>
#include <netlink/data.h>
#include <netlink/route/rtnl.h>
#include <netlink/route/route.h>
#include <netlink/route/mroute.h>
#include <netlink/route/link.h>

/*
 * Tentative declaration of the cache operations defined near the end of
 * this file.
 */
static struct nl_cache_ops rtnl_mroute_ops;

/** @cond SKIP */
/*
 * Attribute-presence bits stored in route->ce_mask (see the ce_mask
 * assignments in rtnl_mroute_parse()).
 * NOTE(review): struct rtnl_route and its rtnl_route_set_*() helpers belong
 * to the generic route object code, so these values presumably must stay in
 * sync with the ROUTE_ATTR_* bits defined there. Confirm before changing.
 */
#define ROUTE_ATTR_FAMILY    0x000001
#define ROUTE_ATTR_TOS       0x000002
#define ROUTE_ATTR_TABLE     0x000004
#define ROUTE_ATTR_PROTOCOL  0x000008
#define ROUTE_ATTR_SCOPE     0x000010
#define ROUTE_ATTR_TYPE      0x000020
#define ROUTE_ATTR_FLAGS     0x000040
#define ROUTE_ATTR_DST       0x000080
#define ROUTE_ATTR_SRC       0x000100
#define ROUTE_ATTR_IIF       0x000200
#define ROUTE_ATTR_OIF       0x000400
#define ROUTE_ATTR_GATEWAY   0x000800
#define ROUTE_ATTR_PRIO      0x001000
#define ROUTE_ATTR_PREF_SRC  0x002000
#define ROUTE_ATTR_METRICS   0x004000
#define ROUTE_ATTR_MULTIPATH 0x008000
#define ROUTE_ATTR_REALMS    0x010000
#define ROUTE_ATTR_CACHEINFO 0x020000
/** @endcond */

static struct nla_policy mroute_policy[RTA_MAX+1] = {
	[RTA_IIF]	= { .type = NLA_U32 },
	[RTA_OIF]	= { .type = NLA_U32 },
	[RTA_PRIORITY]	= { .type = NLA_U32 },
	[RTA_FLOW]	= { .type = NLA_U32 },
	[RTA_CACHEINFO]	= { .minlen = sizeof(struct rta_cacheinfo) },
	[RTA_METRICS]	= { .type = NLA_NESTED },
	[RTA_MULTIPATH]	= { .type = NLA_NESTED },
};

/*
 * Parse an RTA_MULTIPATH attribute and attach one nexthop per record to
 * @route.
 *
 * The payload is a packed sequence of struct rtnexthop records. Each record
 * may be followed by nested RTA_* attributes; RTA_GATEWAY and RTA_FLOW are
 * consumed here.
 *
 * Parsing stops at the first record whose rtnh_len is smaller than its own
 * header or larger than the remaining payload. This mirrors the kernel's
 * rtnh_ok() end-of-list check.
 *
 * Returns 0 on success or a negative NLE_* error code. Nexthops attached
 * before an error remain on @route; the caller releases the route.
 */
static int parse_multipath(struct rtnl_route *route, struct nlattr *attr)
{
	struct rtnl_nexthop *nh = NULL;
	struct rtnexthop *rtnh = nla_data(attr);
	size_t tlen = nla_len(attr);
	int err;

	while (tlen >= sizeof(*rtnh)) {
		size_t rlen = rtnh->rtnh_len;

		/*
		 * If rlen < sizeof(*rtnh), rtnh would never advance (rlen == 0
		 * loops forever) or the records would be misparsed. If
		 * rlen > tlen, parsing would read past the end of @attr.
		 */
		if (rlen < sizeof(*rtnh) || rlen > tlen)
			break;

		nh = rtnl_route_nh_alloc();
		if (!nh)
			return -NLE_NOMEM;

		rtnl_route_nh_set_weight(nh, rtnh->rtnh_hops);
		rtnl_route_nh_set_ifindex(nh, rtnh->rtnh_ifindex);
		rtnl_route_nh_set_flags(nh, rtnh->rtnh_flags);

		if (rlen > sizeof(*rtnh)) {
			struct nlattr *ntb[RTA_MAX + 1];

			err = nla_parse(ntb, RTA_MAX, (struct nlattr *)
					RTNH_DATA(rtnh),
					rlen - sizeof(*rtnh),
					mroute_policy);
			if (err < 0)
				goto errout;

			if (ntb[RTA_GATEWAY]) {
				struct nl_addr *addr;

				addr = nl_addr_alloc_attr(ntb[RTA_GATEWAY],
							  route->rt_family);
				if (!addr) {
					err = -NLE_NOMEM;
					goto errout;
				}

				rtnl_route_nh_set_gateway(nh, addr);
				nl_addr_put(addr);
			}

			if (ntb[RTA_FLOW]) {
				uint32_t realms;

				realms = nla_get_u32(ntb[RTA_FLOW]);
				rtnl_route_nh_set_realms(nh, realms);
			}
		}

		/* Ownership passes to @route; never free nh on a later error. */
		rtnl_route_add_nexthop(route, nh);
		nh = NULL;

		/*
		 * Records are padded to RTNH_ALIGNTO, but the last one may be
		 * unpadded. Stop here rather than let the unsigned tlen wrap
		 * around.
		 */
		if (RTNH_ALIGN(rlen) >= tlen)
			break;

		tlen -= RTNH_ALIGN(rlen);
		rtnh = RTNH_NEXT(rtnh);
	}

	err = 0;
errout:
	if (err && nh)
		rtnl_route_nh_free(nh);

	return err;
}

/**
 * Parse an rtnetlink route message into a newly allocated route object.
 * @arg nlh		Netlink message carrying a struct rtmsg header.
 * @arg result		Receives the new route on success; not written on error.
 *
 * The struct rtmsg header fields are always copied. Optional attributes
 * (RTA_DST, RTA_SRC, RTA_TABLE, RTA_IIF, RTA_PRIORITY, RTA_PREFSRC,
 * RTA_METRICS, RTA_MULTIPATH, RTA_CACHEINFO) are applied when present.
 *
 * A legacy single nexthop built from RTA_OIF, RTA_GATEWAY and RTA_FLOW is
 * added only when RTA_MULTIPATH supplied no nexthops. Otherwise it must
 * compare equal to the first multipath nexthop.
 *
 * @return 0 on success or a negative NLE_* error code. Every failure path
 *         returns a negative code, releases the partially built route and
 *         leaves *result untouched.
 */
int rtnl_mroute_parse(struct nlmsghdr *nlh, struct rtnl_route **result)
{
	struct rtmsg *rtm;
	struct rtnl_route *route;
	struct nlattr *tb[RTA_MAX + 1];
	struct nl_addr *src = NULL, *dst = NULL, *addr;
	struct rtnl_nexthop *old_nh = NULL;
	int err, family;

	route = rtnl_mroute_alloc();
	if (!route) {
		err = -NLE_NOMEM;
		goto errout;
	}

	route->ce_msgtype = nlh->nlmsg_type;
	route->ce_msgflags = nlh->nlmsg_flags;

	err = nlmsg_parse(nlh, sizeof(struct rtmsg), tb, RTA_MAX, mroute_policy);
	if (err < 0)
		goto errout;

	rtm = nlmsg_data(nlh);
	route->rt_family = family = rtm->rtm_family;
	route->rt_tos = rtm->rtm_tos;
	route->rt_table = rtm->rtm_table;
	route->rt_type = rtm->rtm_type;
	route->rt_scope = rtm->rtm_scope;
	route->rt_protocol = rtm->rtm_protocol;
	route->rt_flags = rtm->rtm_flags;
	route->rt_prio = 0;

	route->ce_mask |= ROUTE_ATTR_FAMILY | ROUTE_ATTR_TOS |
			  ROUTE_ATTR_TABLE | ROUTE_ATTR_TYPE |
			  ROUTE_ATTR_SCOPE | ROUTE_ATTR_PROTOCOL |
			  ROUTE_ATTR_FLAGS | ROUTE_ATTR_PRIO;

	/* A destination is always set; without RTA_DST it is an empty
	 * address of the message's family. */
	if (tb[RTA_DST]) {
		if (!(dst = nl_addr_alloc_attr(tb[RTA_DST], family)))
			goto errout_nomem;
	} else {
		if (!(dst = nl_addr_alloc(0)))
			goto errout_nomem;
		nl_addr_set_family(dst, rtm->rtm_family);
	}

	nl_addr_set_prefixlen(dst, rtm->rtm_dst_len);
	err = rtnl_route_set_dst(route, dst);
	if (err < 0) {
		nl_addr_put(dst);
		goto errout;
	}

	nl_addr_put(dst);

	if (tb[RTA_SRC]) {
		if (!(src = nl_addr_alloc_attr(tb[RTA_SRC], family)))
			goto errout_nomem;
	} else if (rtm->rtm_src_len) {
		if (!(src = nl_addr_alloc(0)))
			goto errout_nomem;
	}

	if (src) {
		nl_addr_set_prefixlen(src, rtm->rtm_src_len);
		rtnl_route_set_src(route, src);
		nl_addr_put(src);
	}

	if (tb[RTA_TABLE])
		rtnl_route_set_table(route, nla_get_u32(tb[RTA_TABLE]));

	if (tb[RTA_IIF])
		rtnl_route_set_iif(route, nla_get_u32(tb[RTA_IIF]));

	if (tb[RTA_PRIORITY])
		rtnl_route_set_priority(route, nla_get_u32(tb[RTA_PRIORITY]));

	if (tb[RTA_PREFSRC]) {
		if (!(addr = nl_addr_alloc_attr(tb[RTA_PREFSRC], family)))
			goto errout_nomem;
		rtnl_route_set_pref_src(route, addr);
		nl_addr_put(addr);
	}

	if (tb[RTA_METRICS]) {
		struct nlattr *mtb[RTAX_MAX + 1];
		int i;

		err = nla_parse_nested(mtb, RTAX_MAX, tb[RTA_METRICS], NULL);
		if (err < 0)
			goto errout;

		for (i = 1; i <= RTAX_MAX; i++) {
			if (mtb[i] && nla_len(mtb[i]) >= sizeof(uint32_t)) {
				uint32_t m = nla_get_u32(mtb[i]);

				/* err is still 0 here; record the setter's
				 * error so the failure is reported. */
				err = rtnl_route_set_metric(route, i, m);
				if (err < 0)
					goto errout;
			}
		}
	}

	if (tb[RTA_MULTIPATH])
		if ((err = parse_multipath(route, tb[RTA_MULTIPATH])) < 0)
			goto errout;

	if (tb[RTA_CACHEINFO]) {
		nla_memcpy(&route->rt_cacheinfo, tb[RTA_CACHEINFO],
			   sizeof(route->rt_cacheinfo));
		route->ce_mask |= ROUTE_ATTR_CACHEINFO;
	}

	/* Legacy single-nexthop attributes are collected into old_nh.
	 * Allocation failures must report -NLE_NOMEM: err is 0 at this
	 * point, and jumping to errout would return "success" with the
	 * route already freed. */
	if (tb[RTA_OIF]) {
		if (!old_nh && !(old_nh = rtnl_route_nh_alloc()))
			goto errout_nomem;

		rtnl_route_nh_set_ifindex(old_nh, nla_get_u32(tb[RTA_OIF]));
	}

	if (tb[RTA_GATEWAY]) {
		if (!old_nh && !(old_nh = rtnl_route_nh_alloc()))
			goto errout_nomem;

		if (!(addr = nl_addr_alloc_attr(tb[RTA_GATEWAY], family)))
			goto errout_nomem;

		rtnl_route_nh_set_gateway(old_nh, addr);
		nl_addr_put(addr);
	}

	if (tb[RTA_FLOW]) {
		if (!old_nh && !(old_nh = rtnl_route_nh_alloc()))
			goto errout_nomem;

		rtnl_route_nh_set_realms(old_nh, nla_get_u32(tb[RTA_FLOW]));
	}

	if (old_nh) {
		rtnl_route_nh_set_flags(old_nh, rtm->rtm_flags & 0xff);
		if (route->rt_nr_nh == 0) {
			/* If no nexthops have been provided via RTA_MULTIPATH
			 * we add it as regular nexthop to maintain backwards
			 * compatibility */
			rtnl_route_add_nexthop(route, old_nh);
			old_nh = NULL;
		} else {
			/* Kernel supports new style nexthop configuration,
			 * verify that it is a duplicate and discard nexthop. */
			struct rtnl_nexthop *first;

			first = nl_list_first_entry(&route->rt_nexthops,
						    struct rtnl_nexthop,
						    rtnh_list);
			if (!first)
				BUG();

			if (rtnl_route_nh_compare(old_nh, first,
						  old_nh->ce_mask, 0)) {
				err = -NLE_INVAL;
				goto errout;
			}

			rtnl_route_nh_free(old_nh);
			old_nh = NULL;
		}
	}

	*result = route;
	return 0;

errout:
	if (old_nh)
		rtnl_route_nh_free(old_nh);
	rtnl_route_put(route);
	return err;

errout_nomem:
	err = -NLE_NOMEM;
	goto errout;
}

static int mroute_event_filter(struct nl_cache *cache, struct nl_object *obj)
{
	struct rtnl_route *route = (struct rtnl_route *) obj;

	if (route->rt_family != RTNL_FAMILY_IPMR ||
	    (route->rt_flags & RTNH_F_UNRESOLVED))
		return NL_SKIP;

	return NL_OK;
}

static int mroute_msg_match(struct nl_cache_ops *ops, struct nlmsghdr *nlh)
{
	struct rtmsg *rtm = nlmsg_data(nlh);
	int family = rtm->rtm_family;
	int i;

	for (i = 0; ops->co_groups[i].ag_family >= 0; i++) {
		if (ops->co_groups[i].ag_family == family)
			return 1;
	}

	return 0;
}

/*
 * co_msg_parser hook.
 *
 * Decodes one message with rtnl_mroute_parse() and hands the resulting route
 * to pp->pp_cb. The callback's return value is passed back to the caller, and
 * rtnl_route_put() is called on the route afterwards. If parsing fails, its
 * negative error code is returned and the callback is not invoked.
 */
static int mroute_msg_parser(struct nl_cache_ops *ops, struct sockaddr_nl *who,
			    struct nlmsghdr *nlh, struct nl_parser_param *pp)
{
	struct rtnl_route *parsed;
	int rc;

	rc = rtnl_mroute_parse(nlh, &parsed);
	if (rc >= 0) {
		rc = pp->pp_cb((struct nl_object *) parsed, pp);
		rtnl_route_put(parsed);
	}

	return rc;
}

static int mroute_request_update(struct nl_cache *c, struct nl_sock *h)
{
	struct rtmsg rhdr = {
		.rtm_family = c->c_iarg1,
	};

	return nl_send_simple(h, RTM_GETROUTE, NLM_F_DUMP, &rhdr, sizeof(rhdr));
}

/*
 * Address family / netlink multicast group pairs for this cache. The list
 * is terminated by END_OF_GROUP_LIST. mroute_msg_match() compares a
 * message's rtm_family against the ag_family of these entries.
 */
static struct nl_af_group mroute_groups[] = {
	{ RTNL_FAMILY_IPMR,	RTNLGRP_IPV4_MROUTE},
	{ END_OF_GROUP_LIST },
};

static struct nl_cache_ops rtnl_mroute_ops = {
	.co_name		= "route/mroute",
	.co_hdrsize		= sizeof(struct rtmsg),
	.co_msgtypes		= {
					{ RTM_NEWROUTE, NL_ACT_NEW, "new" },
					{ RTM_DELROUTE, NL_ACT_DEL, "del" },
					{ RTM_GETROUTE, NL_ACT_GET, "get" },
					END_OF_MSGTYPES_LIST,
				  },
	.co_protocol		= NETLINK_ROUTE,
	.co_groups		= mroute_groups,
	.co_request_update	= mroute_request_update,
	.co_msg_match		= mroute_msg_match,
	.co_msg_parser		= mroute_msg_parser,
	.co_event_filter        = mroute_event_filter,
	.co_obj_ops		= &mroute_obj_ops,
	.co_hash_size		= 16384,
};

/*
 * Register the "route/mroute" cache operations with the cache manager.
 * NOTE(review): __init presumably marks this as a library constructor in
 * the private headers. Confirm. The return value of
 * nl_cache_mngt_register() is ignored.
 */
static void __init mroute_init(void)
{
	struct nl_cache_ops *ops = &rtnl_mroute_ops;

	nl_cache_mngt_register(ops);
}

/*
 * Unregister the cache operations registered by mroute_init().
 * NOTE(review): __exit presumably marks this as a library destructor.
 * Confirm.
 */
static void __exit mroute_exit(void)
{
	struct nl_cache_ops *ops = &rtnl_mroute_ops;

	nl_cache_mngt_unregister(ops);
}

/** @} */
