Apparently working version, but still ugly.
[yaid] / ident.c
diff --git a/ident.c b/ident.c
index ce4ba29..ae3cf13 100644 (file)
--- a/ident.c
+++ b/ident.c
 
 /*----- Header files ------------------------------------------------------*/
 
-#include <ctype.h>
-#include <errno.h>
-#include <stdio.h>
-#include <string.h>
-#include <string.h>
-
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <arpa/inet.h>
-
-#include <syslog.h>
-
-#include <mLib/dstr.h>
-
-/*----- Data structures ---------------------------------------------------*/
-
-union addr {
-  struct in_addr ipv4;
-  struct in6_addr ipv6;
-};
-
-struct socket {
-  union addr addr;
-  int port;
-};
-
-enum { L, R, NDIR };
-
-struct query {
-  int af;
-  struct socket s[NDIR];
-} query;
-
-#define RESPONSE(_)                                                    \
-  _(ERROR, U(error, unsigned))                                         \
-  _(UID, U(uid, uid_t))                                                        \
-  _(NAT, U(nat, struct socket))
-
-#define ERROR(_)                                                       \
-  _(INVPORT, "INVALID-PORT")                                           \
-  _(NOUSER, "NO-USER")                                                 \
-  _(HIDDEN, "HIDDEN-USER")                                             \
-  _(UNKNOWN, "UNKNOWN-ERROR")
-
-enum {
-#define DEFENUM(err, tok) E_##err,
-  ERROR(DEFENUM)
-#undef DEFENUM
-  E_LIMIT
-};
-
-enum {
-#define DEFENUM(what, branch) R_##what,
-  RESPONSE(DEFENUM)
-#undef DEFENUM
-  R_LIMIT
-};
-
-struct response {
-  unsigned what;
-  union {
-#define DEFBRANCH(WHAT, branch) branch
-#define U(memb, ty) ty memb;
-#define N
-    RESPONSE(DEFBRANCH)
-#undef U
-#undef N
-#undef DEFBRANCH
-  } u;
-};
+#include "yaid.h"
 
 /*----- Static variables --------------------------------------------------*/
 
-static const char *errtok[] = {
+const char *const errtok[] = {
 #define DEFTOK(err, tok) tok,
   ERROR(DEFTOK)
 #undef DEFTOK
 };
 
 static int parseaddr4(char **pp, union addr *a)
-  { a->ipv4.s_addr = strtoul(*pp, (char **)pp, 16); return (0); }
+  { a->ipv4.s_addr = strtoul(*pp, pp, 16); return (0); }
 
 static int addreq4(const union addr *a, const union addr *aa)
   { return a->ipv4.s_addr == aa->ipv4.s_addr; }
 
+static int parseaddr6(char **pp, union addr *a)
+{
+  int i, j;
+  unsigned long y;
+  char *p = *pp;
+  unsigned x;
+
+  for (i = 0; i < 4; i++) {
+    y = 0;
+    for (j = 0; j < 8; j++) {
+      if ('0' <= *p && *p <= '9') x = *p - '0';
+      else if ('a' <= *p && *p <= 'f') x = *p - 'a'+ 10;
+      else if ('A' <= *p && *p <= 'F') x = *p - 'A'+ 10;
+      else return (-1);
+      y = (y << 4) | x;
+      p++;
+    }
+    a->ipv6.s6_addr32[i] = y;
+  }
+  *pp = p;
+  return (0);
+}
+
+static int addreq6(const union addr *a, const union addr *b)
+  { return !memcmp(a->ipv6.s6_addr, b->ipv6.s6_addr, 16); }
+
 static const struct addrfamily {
   int af;
   const char *procfile;
@@ -119,43 +75,82 @@ static const struct addrfamily {
   int (*addreq)(const union addr *a, const union addr *aa);
 } addrfamilytab[] = {
   { AF_INET, "/proc/net/tcp", parseaddr4, addreq4 },
-  { AF_INET6, "/proc/net/tcp6", /*parseaddr6*/ },
+  { AF_INET6, "/proc/net/tcp6", parseaddr6, addreq6 },
   { -1 }
 };
 
 /*----- Main code ---------------------------------------------------------*/
 
-static void dputsock(dstr *d, int af, const struct socket *s)
+static int sockeq(const struct addrfamily *af,
+                 const struct socket *sa, const struct socket *sb)
+  { return (af->addreq(&sa->addr, &sb->addr) && sa->port == sb->port); }
+
+int get_default_gw(int af, union addr *a)
 {
-  char buf[INET6_ADDRSTRLEN];
+  int fd;
+  char buf[32768];
+  struct nlmsghdr *nlmsg;
+  struct rtgenmsg *rtgen;
+  const struct rtattr *rta;
+  const struct rtmsg *rtm;
+  ssize_t n, nn;
+  int rc = 0;
+  static unsigned long seq = 0x48b4aec4;
+
+  if ((fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE)) < 0)
+    die(1, "failed to create netlink socket: %s", strerror(errno));
+
+  nlmsg = (struct nlmsghdr *)buf;
+  assert(NLMSG_SPACE(sizeof(*rtgen)) < sizeof(buf));
+  nlmsg->nlmsg_len = NLMSG_LENGTH(sizeof(*rtgen));
+  nlmsg->nlmsg_type = RTM_GETROUTE;
+  nlmsg->nlmsg_flags = NLM_F_REQUEST | NLM_F_ROOT;
+  nlmsg->nlmsg_seq = ++seq;
+  nlmsg->nlmsg_pid = 0;
+
+  rtgen = (struct rtgenmsg *)NLMSG_DATA(nlmsg);
+  rtgen->rtgen_family = af;
+
+  if (write(fd, nlmsg, nlmsg->nlmsg_len) < 0)
+    die(1, "failed to send RTM_GETROUTE request: %s", strerror(errno));
 
-  inet_ntop(af, &s->addr, buf, sizeof(buf));
-  if (af != AF_INET6) dstr_puts(d, buf);
-  else { dstr_putc(d, '['); dstr_puts(d, buf); dstr_putc(d, ']'); }
-  dstr_putf(d, ":%d", s->port);
-}
+  for (;;) {
+    if ((n = read(fd, buf, sizeof(buf))) < 0)
+      die(1, "failed to read RTM_GETROUTE response: %s", strerror(errno));
+    nlmsg = (struct nlmsghdr *)buf;
+    if (nlmsg->nlmsg_seq != seq) continue;
+    assert(nlmsg->nlmsg_flags & NLM_F_MULTI);
+
+    for (; NLMSG_OK(nlmsg, n); nlmsg = NLMSG_NEXT(nlmsg, n)) {
+      if (nlmsg->nlmsg_type == NLMSG_DONE) goto done;
+      if (nlmsg->nlmsg_type != RTM_NEWROUTE) continue;
+      rtm = (const struct rtmsg *)NLMSG_DATA(nlmsg);
+
+      if (rtm->rtm_family != af ||
+         rtm->rtm_dst_len > 0 ||
+         rtm->rtm_src_len > 0 ||
+         rtm->rtm_type != RTN_UNICAST ||
+         rtm->rtm_scope != RT_SCOPE_UNIVERSE ||
+         rtm->rtm_tos != 0)
+       continue;
 
-static void logmsg(const struct query *q, int prio, const char *msg, ...)
-{
-  va_list ap;
-  dstr d = DSTR_INIT;
+      for (rta = RTM_RTA(rtm), nn = RTM_PAYLOAD(nlmsg);
+          RTA_OK(rta, nn); rta = RTA_NEXT(rta, nn)) {
+       if (rta->rta_type == RTA_GATEWAY) {
+         assert(RTA_PAYLOAD(rta) <= sizeof(*a));
+         memcpy(a, RTA_DATA(rta), RTA_PAYLOAD(rta));
+         rc = 1;
+       }
+      }
+    }
+  }
 
-  va_start(ap, msg);
-  dputsock(&d, q->af, &q->s[L]);
-  dstr_puts(&d, " <-> ");
-  dputsock(&d, q->af, &q->s[R]);
-  dstr_puts(&d, ": ");
-  dstr_vputf(&d, msg, &ap);
-  va_end(ap);
-  fprintf(stderr, "yaid: %s\n", d.buf);
-  dstr_destroy(&d);
+done:
+  close(fd);
+  return (rc);
 }
 
-static int sockeq(const struct addrfamily *af,
-                 const struct socket *sa, const struct socket *sb)
-  { return (af->addreq(&sa->addr, &sb->addr) && sa->port == sb->port); }
-
-void identify(const struct query *q, struct response *r)
+void identify(struct query *q)
 {
   const struct addrfamily *af;
   FILE *fp = 0;
@@ -163,6 +158,7 @@ void identify(const struct query *q, struct response *r)
   char *p, *pp;
   struct socket s[4];
   int i;
+  int gwp = 0;
   unsigned fl;
 #define F_SADDR 1u
 #define F_SPORT 2u
@@ -180,6 +176,10 @@ void identify(const struct query *q, struct response *r)
   goto err_unk;
 found_af:;
 
+  if (get_default_gw(q->af, &s[0].addr) &&
+      af->addreq(&s[0].addr, &q->s[R].addr))
+    gwp = 1;
+
   if ((fp = fopen(af->procfile, "r")) == 0) {
     logmsg(q, LOG_ERR, "failed to open `%s' for reading: %s",
           af->procfile, strerror(errno));
@@ -242,13 +242,13 @@ found_af:;
       if (af->parseaddr(&p, &s[0].addr)) goto next_row;
       if (*p != ':') break; p++;
       s[0].port = strtoul(p, 0, 16);
-      /* FIXME: accept forwarded queries from NAT */
-      if (!sockeq(af, &q->s[i], &s[0])) goto next_row;
-      else continue;
+      if (!sockeq(af, &q->s[i], &s[0]) &&
+         (i != R || !gwp || q->s[R].port != s[0].port))
+       goto next_row;
     }
     if (uid != -1) {
-      r->what = R_UID;
-      r->u.uid = uid;
+      q->resp = R_UID;
+      q->u.uid = uid;
       goto done;
     }
   next_row:;
@@ -331,8 +331,8 @@ found_af:;
       if (!sockeq(af, &s[i^1], &s[i^2]) ||
          !sockeq(af, &s[i^1], &q->s[R]))
        continue;
-      r->what = R_NAT;
-      r->u.nat = s[i^3];
+      q->resp = R_NAT;
+      q->u.nat = s[i^3];
       goto done;
     }
 
@@ -341,53 +341,20 @@ found_af:;
             strerror(errno));
       goto err_unk;
     }
+    logmsg(q, LOG_ERR, "connection not found");
   }
 
 #undef NEXTFIELD
 
 err_nouser:
-  r->what = R_ERROR;
-  r->u.error = E_NOUSER;
+  q->resp = R_ERROR;
+  q->u.error = E_NOUSER;
   goto done;
 err_unk:
-  r->what = R_ERROR;
-  r->u.error = E_UNKNOWN;
+  q->resp = R_ERROR;
+  q->u.error = E_UNKNOWN;
 done:
   dstr_destroy(&d);
 }
 
 /*----- That's all, folks -------------------------------------------------*/
-
-int main(int argc, char *argv[])
-{
-  struct query q;
-  struct response r;
-  char buf[INET6_ADDRSTRLEN];
-
-  q.af = AF_INET;
-  inet_pton(AF_INET, argv[1], &q.s[L].addr.ipv4);
-  q.s[L].port = atoi(argv[2]);
-  inet_pton(AF_INET, argv[3], &q.s[R].addr.ipv4);
-  q.s[R].port = atoi(argv[4]);
-
-  identify(&q, &r);
-
-  switch (r.what) {
-    case R_UID:
-      printf("uid %d\n", r.u.uid);
-      break;
-    case R_ERROR:
-      if (r.u.error < E_LIMIT) printf("error %s\n", errtok[r.u.error]);
-      else printf("error E%u\n", r.u.error);
-      break;
-    case R_NAT:
-      inet_ntop(q.af, &r.u.nat.addr, buf, sizeof(buf));
-      printf("nat -> %s:%d\n", buf, r.u.nat.port);
-      break;
-    default:
-      printf("unknown response\n");
-      break;
-  }
-
-  return (0);
-}