/*
** k_dummy.c - Kernel access module for BSD-derived systems (libkvm + KERN_FILE sysctl)
**
** Copyright (c) 1997 Peter Eriksson <pen@lysator.liu.se>
**
** This program is free software; you can redistribute it and/or
** modify it as you wish - as long as you don't claim that you wrote
** it.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/

#include "config.h"

#include <sys/types.h>
#include <sys/param.h>
#include <sys/sysctl.h>

#include <limits.h>
#include <nlist.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>

#define _KERNEL
#include <sys/file.h>
#undef _KERNEL

#include <kvm.h>

#include <fcntl.h>

#include <net/route.h>

#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#ifdef INET6
#include <netinet6/ip6.h>
#endif
#include <netinet/in_pcb.h>
#ifdef INET6
#include <netinet6/in6_pcb.h>
#endif

#include "pidentd.h"

/*
** Per-open libkvm state. ka_open() allocates it and stores it through
** its "misc" argument; ka_lookup() receives it back as "vp".
*/
struct kainfo {
	kvm_t *kd;		/* libkvm handle on the running kernel */
	int nfile;		/* NOTE(review): not referenced in this file */
	/*
	** Symbol list for kvm_nlist(). ka_open() fills _filehead, _nfiles
	** and _tcbtable, plus _tcb6 when INET6 is defined, followed by the
	** empty-name terminator. That is up to index 4, so 5 slots are
	** needed.
	*/
	struct nlist nl[5];
};

/* Forward declarations of the file-local helpers defined below. */
static int getbuf(struct kainfo *kp, u_long addr, char *buf, u_int len,
		  char *what);
static struct socket *getlist(struct kainfo *kp,
			      struct inpcbtable *tcbtablep,
			      struct inpcbtable *ktcbtablep,
			      struct in_addr *faddr, int fport,
			      struct in_addr *laddr, int lport);
#ifdef INET6
static struct socket *getlist6(struct kainfo *kp, struct in6pcb *pcbp,
			       struct in6_addr *faddr, int fport,
			       struct in6_addr *laddr, int lport);
#endif

/*
** ka_init: one-time initialisation of this kernel access module.
**
** NOTE(review): the generic ka_init contract is to verify that the
** binary runs on a supported OS version. This implementation does no
** such check. It only sets the global kernel_threads to 1, because the
** kvm_* routines are not MT-safe. Presumably kernel_threads limits the
** number of concurrent kernel lookups; confirm against the caller.
**
** Always returns 0.
*/
int
ka_init(void)
{
    kernel_threads = 1;		/* libkvm must not be used concurrently */
    return 0;
}



/*
** ka_open should open any kernel file descriptors or other
** resources needed to access it, put it into some dynamically
** allocated structure and store it into the 'misc' pointer.
**
** It should return 0 if all was OK, else -1
*/
int
ka_open(void **misc)
{
    struct kainfo *kp = s_malloc(sizeof(struct kainfo));

    /*
    ** Open the kernel memory device
    */
    if ((kp->kd = (kvm_t *)kvm_openfiles(NULL, NULL,
					 NULL, O_RDONLY, NULL)) == NULL) {
	syslog(LOG_ERR, "kvm_open: %m");
	s_free(kp);
	return -1;
    }

#define N_FILE 0
#define N_NFILE 1
#define N_TCBTABLE 2
#ifdef INET6
#define N_TCB6 3
#endif
    kp->nl[N_FILE].n_name = "_filehead";
    kp->nl[N_NFILE].n_name = "_nfiles";
    kp->nl[N_TCBTABLE].n_name = "_tcbtable";
#ifdef INET6
    kp->nl[N_TCB6].n_name = "_tcb6";
    kp->nl[4].n_name = "";
#else
    kp->nl[3].n_name = "";
#endif
    /*
    ** Extract offsets to the needed variables in the kernel
    */
    if (kvm_nlist(kp->kd, kp->nl) < 0) {
	syslog(LOG_ERR, "kvm_nlist: %m");
	kvm_close(kp->kd);
	s_free(kp);
	return -1;
    }

    *misc = (void *) kp;
    return 0;
}



/*
** Copy len bytes of kernel memory at addr into buf.
** "what" is a short label used only in the error message.
**
** Returns 1 if exactly len bytes were read, else 0 (zero). Failures
** and short reads are logged with libkvm's own error text.
*/
static int
getbuf(struct kainfo *kp, u_long addr, char *buf, u_int len, char *what)
{
    ssize_t got;

    got = kvm_read(kp->kd, addr, buf, len);
    if (got < 0) {
	syslog(LOG_ERR, "getbuf: kvm_read(%08lx, %u) - %s : %s",
	       addr, len, what, kvm_geterr(kp->kd));
	return 0;
    }
    if ((size_t)got != (size_t)len) {
	/* A partial read leaves buf half-filled; never treat it as valid. */
	syslog(LOG_ERR, "getbuf: kvm_read(%08lx, %u) - %s : short read (%ld)",
	       addr, len, what, (long)got);
	return 0;
    }
    return 1;
}



/*
** Traverse the inpcb list until a match is found.
** Returns NULL if no match.
*/
static struct socket *
getlist(struct kainfo *kp,
	struct inpcbtable *tcbtablep, struct inpcbtable *ktcbtablep,
	struct in_addr *faddr, int fport, struct in_addr *laddr, int lport)
{
    struct inpcb *kpcbp, pcb;

    if (!tcbtablep)
	return NULL;
 
    for (kpcbp = tcbtablep->inpt_queue.cqh_first;
	 kpcbp != (struct inpcb *)ktcbtablep;
	 kpcbp = pcb.inp_queue.cqe_next) {
	if (!getbuf(kp, (long) kpcbp,
		    (char *)&pcb, sizeof(struct inpcb), "tcb"))
	    break;
	if (pcb.inp_faddr.s_addr == faddr->s_addr &&
	    pcb.inp_laddr.s_addr == laddr->s_addr &&
	    pcb.inp_fport        == fport &&
	    pcb.inp_lport        == lport )
	    return pcb.inp_socket;
    }
    return NULL;
}



#ifdef INET6
/*
** Walk the kernel's IPv6 TCP PCB list looking for the connection
** faddr:fport <-> laddr:lport.
**
** pcbp is a user-space copy of the list head whose in6p_prev the
** caller has set to the head's kernel address. Each next element is
** read into *pcbp in place, overwriting the previous one. The walk
** stops when in6p_next points back at the head.
**
** Returns the kernel socket pointer of the matching PCB, or NULL.
*/
static struct socket *
getlist6(struct kainfo *kp, struct in6pcb *pcbp,
	 struct in6_addr *faddr, int fport,
	 struct in6_addr *laddr, int lport)
{
    struct in6pcb *listhead;

    if (pcbp == NULL)
	return NULL;

    listhead = pcbp->in6p_prev;
    for (;;) {
	int same_faddr = memcmp(pcbp->in6p_faddr.s6_addr, faddr,
				sizeof(struct in6_addr)) == 0;
	int same_laddr = memcmp(pcbp->in6p_laddr.s6_addr, laddr,
				sizeof(struct in6_addr)) == 0;

	if (same_faddr && same_laddr &&
	    pcbp->in6p_fport == fport &&
	    pcbp->in6p_lport == lport)
	    return pcbp->in6p_socket;

	if (pcbp->in6p_next == listhead)
	    break;
	if (!getbuf(kp, (long)pcbp->in6p_next,
		    (char *)pcbp, sizeof(struct in6pcb), "tcblist"))
	    break;
    }

    return NULL;
}
#endif



/*
** ka_lookup gets called when a request thread wants to
** do a kernel lookup.
**
** The pointer returned from ka_init() is passed as the "vp" parameter.
** The local and remote address and port is available in the "kp"
** parameter.
**
** The function should set both the effective uid and real uid (if
** either one isn't available, return -1 in that return variable)
** variables in the "struct kernel" argument.
**
** This function should return a 1 if the lookup was successful,
** a zero if the connection wasn't found or a -1 in case an error
** occured (the call to ka_lookup() will be retried if -1 
** is returned a configurable number of times, but will fail
** immediately in case of a 0).
*/
int
ka_lookup(void *vp, struct kernel *ke)
{
    long addr;
    struct socket *sockp;
    int i, mib[2];
    struct ucred ucb;
#ifdef INET6
    union sockunion *fsin = &ke->remote;
    union sockunion *lsin = &ke->local;
#else
    struct in_addr *faddr = &ke->remote.sin_addr;
    int fport = ke->remote.sin_port;
    struct in_addr *laddr = &ke->local.sin_addr;
    int lport = ke->local.sin_port;
#endif
    struct kainfo *kp = vp;
    struct file *xfile;
    int nfile;
    struct inpcbtable tcbtable;
#ifdef INET6
    struct in6pcb tcb6;
#endif


    /* -------------------- FILE DESCRIPTOR TABLE -------------------- */
    if (!getbuf(kp, kp->nl[N_NFILE].n_value,
		(char *)&nfile, sizeof(nfile), "nfile"))
	return -1;
  
    if (!getbuf(kp, kp->nl[N_FILE].n_value,
		(char *)&addr, sizeof(addr), "&file"))
	return -1;

    {
	size_t siz;
	int rv;

	mib[0] = CTL_KERN;
	mib[1] = KERN_FILE;
	if ((rv = sysctl(mib, 2, NULL, &siz, NULL, 0)) == -1) {
	    syslog(LOG_ERR, "k_getuid: sysctl 1 (%d)", rv);
	    return -1;
	}
	xfile = (struct file *)malloc(siz);
	if (!xfile)
	    syslog(LOG_ERR, "k_getuid: malloc(%ld)", (u_long)siz);
	if ((rv = sysctl(mib, 2, xfile, &siz, NULL, 0)) == -1) {
	    syslog(LOG_ERR, "k_getuid: sysctl 2 (%d)", rv);
	    return -1;
	}
	xfile = (struct file *)((char *)xfile + sizeof(filehead));
    }
  
    /* -------------------- TCP PCB LIST -------------------- */
#ifdef INET6
    if (fsin->su_family == AF_INET) {
	if (!getbuf(kp, kp->nl[N_TCBTABLE].n_value,
		    (char*)&tcbtable, sizeof(tcbtable), "tcbtable"))
	    return -1;
	sockp = getlist(kp, &tcbtable,
			(struct inpcbtable *)kp->nl[N_TCBTABLE].n_value,
			&fsin->su_sin_addr, fsin->su_port,
			&lsin->su_sin_addr, lsin->su_port);
    } else if (IN6_IS_ADDR_V4MAPPED(&fsin->su_sin6_addr)) {
	if (!getbuf(kp, kp->nl[N_TCBTABLE].n_value,
		    (char*)&tcbtable, sizeof(tcbtable), "tcbtable"))
	    return -1;
	sockp = getlist(kp, &tcbtable,
			(struct inpcbtable *)kp->nl[N_TCBTABLE].n_value,
			(struct in_addr *)&fsin->su_sin6_addr.s6_addr32[3],
			fsin->su_port,
			(struct in_addr *)&lsin->su_sin6_addr.s6_addr32[3],
			lsin->su_port);
    } else {
	if (!getbuf(kp, kp->nl[N_TCB6].n_value,
		    (char*)&tcb6, sizeof(tcb6), "tcb6"))
	    return -1;
	tcb6.in6p_prev = (struct in6pcb *)kp->nl[N_TCB6].n_value;
	sockp = getlist6(kp, &tcb6,
			 &fsin->su_sin6_addr, fsin->su_port,
			 &lsin->su_sin6_addr, lsin->su_port);
    }
#else
    if (!getbuf(kp, kp->nl[N_TCBTABLE].n_value,
		(char*)&tcbtable, sizeof(tcbtable), "tcbtable"))
	return -1;
    sockp = getlist(kp, &tcbtable, (struct inpcbtable *)nl[N_TCBTABLE].n_value,
		    faddr, fport, laddr, lport);
#endif

    if (!sockp)
	return -1;

    /*
    ** Locate the file descriptor that has the socket in question
    ** open so that we can get the 'ucred' information
    */
    for (i = 0; i < nfile; i++) {
	if (xfile[i].f_count == 0)
	    continue;
	if (xfile[i].f_type == DTYPE_SOCKET &&
	    (struct socket *) xfile[i].f_data == sockp) {
	    if (!getbuf(kp, (long)xfile[i].f_cred,
			(char *)&ucb, sizeof(ucb), "ucb"))
		return -1;
	    ke->ruid = NO_UID;
	    ke->euid = ucb.cr_uid;
	    return 1;
	}
    }
    return -1;
}
