General spring-cleaning. Most of the code is pretty nice now.
authorMark Wooding <mdw@distorted.org.uk>
Sat, 20 Oct 2012 17:30:21 +0000 (18:30 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Wed, 24 Oct 2012 09:24:25 +0000 (10:24 +0100)
Makefile.am
addr.c
linux.c
policy.c
yaid.c
yaid.h

index ef046fb..a89b137 100644 (file)
@@ -40,6 +40,7 @@ yaid_SOURCES           =
 EXTRA_yaid_SOURCES      =
 yaid_LDADD              = $(mLib_LIBS)
 
+yaid_SOURCES           += yaid.h
 yaid_SOURCES           += yaid.c
 yaid_SOURCES           += addr.c
 yaid_SOURCES           += policy.c
diff --git a/addr.c b/addr.c
index e72a7fd..bfbbabf 100644 (file)
--- a/addr.c
+++ b/addr.c
@@ -114,10 +114,12 @@ static const union addr any_ipv6 = { .ipv6 = IN6ADDR_ANY_INIT };
 
 /*----- General utilities -------------------------------------------------*/
 
+/* Answer whether the sockets SA and SB are equal. */
 int sockeq(const struct addrops *ao,
           const struct socket *sa, const struct socket *sb)
   { return (ao->addreq(&sa->addr, &sb->addr) && sa->port == sb->port); }
 
+/* Write a textual description of S to the string D. */
 void dputsock(dstr *d, const struct addrops *ao, const struct socket *s)
 {
   char buf[ADDRLEN];
diff --git a/linux.c b/linux.c
index 9d83a90..d7d8475 100644 (file)
--- a/linux.c
+++ b/linux.c
 
 #include "yaid.h"
 
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
 /*----- Static variables --------------------------------------------------*/
 
-static FILE *natfp;
+static FILE *natfp;                    /* File handle for NAT table */
 
 /*----- Address-type operations -------------------------------------------*/
 
@@ -56,6 +59,7 @@ static int parseaddr_ipv6(char **pp, union addr *a)
   char *p = *pp;
   unsigned x;
 
+  /* The format is byteswapped in a really annoying way. */
   for (i = 0; i < 4; i++) {
     y = 0;
     for (j = 0; j < 8; j++) {
@@ -81,6 +85,9 @@ ADDRTYPES(DEFOPSYS)
 
 /*----- Main code ---------------------------------------------------------*/
 
+/* Store in A the default gateway address for the given address family.
+ * Return zero on success, or nonzero on error.
+ */
 static int get_default_gw(int af, union addr *a)
 {
   int fd;
@@ -93,9 +100,13 @@ static int get_default_gw(int af, union addr *a)
   int rc = 0;
   static unsigned long seq = 0x48b4aec4;
 
+  /* Open a netlink socket for interrogating the kernel. */
   if ((fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE)) < 0)
     die(1, "failed to create netlink socket: %s", strerror(errno));
 
+  /* We want to read the routing table.  There doesn't seem to be a good way
+   * to do this without just crawling through the whole thing.
+   */
   nlmsg = (struct nlmsghdr *)buf;
   assert(NLMSG_SPACE(sizeof(*rtgen)) < sizeof(buf));
   nlmsg->nlmsg_len = NLMSG_LENGTH(sizeof(*rtgen));
@@ -110,28 +121,45 @@ static int get_default_gw(int af, union addr *a)
   if (write(fd, nlmsg, nlmsg->nlmsg_len) < 0)
     die(1, "failed to send RTM_GETROUTE request: %s", strerror(errno));
 
+  /* Now we try to parse the answer. */
   for (;;) {
+
+    /* Not finished yet, so read another chunk of answer. */
     if ((n = read(fd, buf, sizeof(buf))) < 0)
       die(1, "failed to read RTM_GETROUTE response: %s", strerror(errno));
+
+    /* Start at the beginning of the response. */
     nlmsg = (struct nlmsghdr *)buf;
+
+    /* Make sure this looks plausible.  The precise rules don't appear to be
+     * documented, so it seems advisable to fail messily if my understanding
+     * is wrong.
+     */
     if (nlmsg->nlmsg_seq != seq) continue;
     assert(nlmsg->nlmsg_flags & NLM_F_MULTI);
 
+    /* Work through all of the individual routes. */
     for (; NLMSG_OK(nlmsg, n); nlmsg = NLMSG_NEXT(nlmsg, n)) {
       if (nlmsg->nlmsg_type == NLMSG_DONE) goto done;
       if (nlmsg->nlmsg_type != RTM_NEWROUTE) continue;
       rtm = (const struct rtmsg *)NLMSG_DATA(nlmsg);
 
-      if (rtm->rtm_family != af ||
-         rtm->rtm_dst_len > 0 ||
-         rtm->rtm_src_len > 0 ||
-         rtm->rtm_type != RTN_UNICAST ||
-         rtm->rtm_scope != RT_SCOPE_UNIVERSE ||
-         rtm->rtm_tos != 0)
+      /* If this record doesn't look interesting then skip it. */
+      if (rtm->rtm_family != af ||     /* wrong address family */
+         rtm->rtm_dst_len > 0 ||       /* specific destination */
+         rtm->rtm_src_len > 0 ||       /* specific source  */
+         rtm->rtm_type != RTN_UNICAST || /* not for unicast */
+         rtm->rtm_scope != RT_SCOPE_UNIVERSE || /* wrong scope */
+         rtm->rtm_tos != 0)            /* specific type of service */
        continue;
 
+      /* Trundle through the attributes and find the gateway address. */
       for (rta = RTM_RTA(rtm), nn = RTM_PAYLOAD(nlmsg);
           RTA_OK(rta, nn); rta = RTA_NEXT(rta, nn)) {
+
+       /* Got one.  We're all done.  Except that we should carry on reading
+        * to the end, or something bad will happen.
+        */
        if (rta->rta_type == RTA_GATEWAY) {
          assert(RTA_PAYLOAD(rta) <= sizeof(*a));
          memcpy(a, RTA_DATA(rta), RTA_PAYLOAD(rta));
@@ -146,6 +174,10 @@ done:
   return (rc);
 }
 
+/* Find out who is responsible for the connection described in the query Q.
+ * Write the answer to Q.  Errors are logged and reported via the query
+ * structure.
+ */
 void identify(struct query *q)
 {
   FILE *fp = 0;
@@ -165,22 +197,32 @@ void identify(struct query *q)
   enum { LOC, REM, ST, UID, NFIELD };
   int f, ff[NFIELD];
 
+  /* If we have a default gateway, and it matches the remote address then
+   * this may be a proxy connection from our NAT, so remember this, and don't
+   * inspect the remote addresses in the TCP tables.
+   */
   if (get_default_gw(q->ao->af, &s[0].addr) &&
       q->ao->addreq(&s[0].addr, &q->s[R].addr))
     gwp = 1;
 
+  /* Open the relevant TCP connection table. */
   if ((fp = fopen(q->ao->sys->procfile, "r")) == 0) {
     logmsg(q, LOG_ERR, "failed to open `%s' for reading: %s",
           q->ao->sys->procfile, strerror(errno));
     goto err_unk;
   }
 
+  /* Initially, PP points into a string containing whitespace-separated
+   * fields.  Point P to the next field, null-terminate it, and advance PP
+   * so that we can read the next field in the next call.
+   */
 #define NEXTFIELD do {                                                 \
   for (p = pp; isspace((unsigned char)*p); p++);                       \
   for (pp = p; *pp && !isspace((unsigned char)*pp); pp++);             \
   if (*pp) *pp++ = 0;                                                  \
 } while (0)
 
+  /* Read the header line from the file. */
   if (dstr_putline(&d, fp) == EOF) {
     logmsg(q, LOG_ERR, "failed to read header line from `%s': %s",
           q->ao->sys->procfile,
@@ -188,6 +230,13 @@ void identify(struct query *q)
     goto err_unk;
   }
 
+  /* Now scan the header line to identify which columns the various
+   * interesting fields are in.  Store these in the map `ff'.  Problems:
+   * `tx_queue rx_queue' and `tr tm->when' are both really single columns in
+   * disguise; and the remote address column has a different heading
+   * depending on which address family we're using.  Rather than dispatch,
+   * just recognize both of them.
+   */
   for (i = 0; i < NFIELD; i++) ff[i] = -1;
   pp = d.buf;
   for (f = 0;; f++) {
@@ -205,6 +254,8 @@ void identify(struct query *q)
             strcmp(p, "tm->when") == 0)
       f--;
   }
+
+  /* Make sure that we found all of the fields we actually want. */
   for (i = 0; i < NFIELD; i++) {
     if (ff[i] < 0) {
       logmsg(q, LOG_ERR, "failed to find required fields in `%s'",
@@ -213,11 +264,20 @@ void identify(struct query *q)
     }
   }
 
+  /* Work through the lines in the file. */
   for (;;) {
+
+    /* Read a line, and prepare to scan the fields. */
     DRESET(&d);
     if (dstr_putline(&d, fp) == EOF) break;
     pp = d.buf;
     uid = -1;
+
+    /* Work through the fields.  If an address field fails to match then we
+     * skip this record.  If the state field isn't 1 (`ESTABLISHED') then
+     * skip the record.  If it's the UID, then remember it: if we get all the
+     * way to the end then we've won.
+     */
     for (f = 0;; f++) {
       NEXTFIELD; if (!*p) break;
       if (f == ff[LOC]) { i = L; goto compare; }
@@ -229,13 +289,25 @@ void identify(struct query *q)
       continue;
 
     compare:
+      /* Compare an address (in the current field) with the local or remote
+       * address in the query, as indicated by `i'.  The address field looks
+       * like `ADDR:PORT', where the ADDR is in some mad format which
+       * `sys->parseaddr' knows how to unpick.  If the remote address in the
+       * query is our gateway then don't check the remote address in the
+       * field (but do check the port number).
+       */
       if (q->ao->sys->parseaddr(&p, &s[0].addr)) goto next_row;
       if (*p != ':') break; p++;
       s[0].port = strtoul(p, 0, 16);
-      if (!sockeq(q->ao, &q->s[i], &s[0]) &&
-         (i != R || !gwp || q->s[R].port != s[0].port))
+      if ((i == R && gwp) ?
+           q->s[R].port != s[0].port :
+           !sockeq(q->ao, &q->s[i], &s[0]))
        goto next_row;
     }
+
+    /* We got to the end, and everything matched.  If we found a UID then
+     * we're done.
+     */
     if (uid != -1) {
       q->resp = R_UID;
       q->u.uid = uid;
@@ -244,25 +316,43 @@ void identify(struct query *q)
   next_row:;
   }
 
+  /* We got to the end of the file and didn't find anything. */
   if (ferror(fp)) {
     logmsg(q, LOG_ERR, "failed to read connection table `%s': %s",
           q->ao->sys->procfile, strerror(errno));
     goto err_unk;
   }
 
+  /* If we opened the NAT table file, and we're using IPv4, then check to see
+   * whether we should proxy the connection.  At least the addresses in this
+   * file aren't crazy.
+   */
   if (natfp) {
+
+    /* Start again from the beginning. */
     rewind(natfp);
 
+    /* Read a line at a time. */
     for (;;) {
+
+      /* Read the line. */
       DRESET(&d);
       if (dstr_putline(&d, natfp) == EOF) break;
       pp = d.buf;
 
+      /* Check that this is for the right protocol. */
       NEXTFIELD; if (!*p) break;
       if (strcmp(p, q->ao->sys->nfl3name)) continue;
       NEXTFIELD; if (!*p) break;
       NEXTFIELD; if (!*p) break;
       if (strcmp(p, "tcp") != 0) continue;
+
+      /* Parse the other fields.  Each line has two src/dst pairs, for the
+       * outgoing and incoming directions.  Depending on exactly what kind of
+       * NAT is in use, either the outgoing source or the incoming
+       * destination might be the client we're after.  Collect all of the
+       * addresses and sort out the mess later.
+       */
       i = 0;
       fl = 0;
       for (;;) {
@@ -291,6 +381,7 @@ void identify(struct query *q)
 
 #ifdef notdef
       {
+       /* Print the record we found. */
        dstr dd = DSTR_INIT;
        dstr_putf(&dd, "%sestab ", (fl & F_ESTAB) ? " " : "!");
        dputsock(&dd, q->ao, &s[0]);
@@ -305,15 +396,28 @@ void identify(struct query *q)
       }
 #endif
 
+      /* If the connection isn't ESTABLISHED then skip it. */
       if (!(fl & F_ESTAB)) continue;
 
+      /* Now we try to piece together what's going on.  One of these
+       * addresses will be us.  So let's just try to find it.
+       */
       for (i = 0; i < 4; i++)
        if (sockeq(q->ao, &s[i], &q->s[L])) goto found_local;
       continue;
+
     found_local:
+      /* So address `i' is us.  In that case, we expect the other address in
+       * the same direction, and the same address in the opposite direction,
+       * to match each other and be the remote address in the query.
+       */
       if (!sockeq(q->ao, &s[i^1], &s[i^2]) ||
          !sockeq(q->ao, &s[i^1], &q->s[R]))
        continue;
+
+      /* We win.  The remaining address must be the client host.  We should
+       * proxy this query.
+       */
       q->resp = R_NAT;
       q->u.nat = s[i^3];
       goto done;
@@ -329,18 +433,26 @@ void identify(struct query *q)
 
 #undef NEXTFIELD
 
+  /* We didn't find a match anywhere.  How unfortunate. */
   logmsg(q, LOG_NOTICE, "connection not found");
   q->resp = R_ERROR;
   q->u.error = E_NOUSER;
   goto done;
+
 err_unk:
+  /* Something went wrong and the protocol can't express what.  We should
+   * have logged what the problem actually was.
+   */
   q->resp = R_ERROR;
   q->u.error = E_UNKNOWN;
+
 done:
+  /* All done. */
   dstr_destroy(&d);
   if (fp) fclose(fp);
 }
 
+/* Initialize the system-specific code. */
 void init_sys(void)
 {
   if ((natfp = fopen("/proc/net/nf_conntrack", "r")) == 0 &&
index 80e2bfd..06475f3 100644 (file)
--- a/policy.c
+++ b/policy.c
 
 #include "yaid.h"
 
-/*----- Main code ---------------------------------------------------------*/
+/*----- Memory management -------------------------------------------------*/
 
-/* syntax: addrpat portpat addrpar portpat policy
- *
- * local address/port first, then remote
- * addrpat ::= addr [/ len]
- * portpat ::= num | num - num | *
- * policy ::= user policy* | token | name | deny | hide |
+/* Initialize a policy structure.  In this state, it doesn't actually have
+ * any resources allocated (so can be simply discarded) but it's safe to free
+ * (using `free_policy').
  */
-
 void init_policy(struct policy *p) { p->act.act = A_LIMIT; }
 
+/* Free an action structure, resetting it to a safe state.  This function is
+ * idempotent.
+ */
 static void free_action(struct action *a)
 {
   switch (a->act) {
@@ -50,9 +49,14 @@ static void free_action(struct action *a)
   a->act = A_LIMIT;
 }
 
+/* Free a policy structure, resetting it to its freshly-initialized state.
+ * This function is idempotent.
+ */
 void free_policy(struct policy *p)
   { free_action(&p->act); }
 
+/*----- Diagnostics -------------------------------------------------------*/
+
 static void print_addrpat(const struct addrops *ao, const struct addrpat *ap)
 {
   char buf[ADDRLEN];
@@ -100,6 +104,7 @@ static void print_action(const struct action *act)
   }
 }
 
+/* Print a policy rule to standard output. */
 void print_policy(const struct policy *p)
 {
   print_sockpat(p->ao, &p->sp[L]); putchar(' ');
@@ -107,9 +112,13 @@ void print_policy(const struct policy *p)
   print_action(&p->act); putchar('\n');
 }
 
+/*----- Matching ----------------------------------------------------------*/
+
+/* Return true if the port matches the pattern. */
 static int match_portpat(const struct portpat *pp, unsigned port)
   { return (pp->lo <= port && port <= pp->hi); }
 
+/* Return true if the socket matches the pattern. */
 static int match_sockpat(const struct addrops *ao,
                         const struct sockpat *sp, const struct socket *s)
 {
@@ -117,6 +126,7 @@ static int match_sockpat(const struct addrops *ao,
          match_portpat(&sp->port, s->port));
 }
 
+/* Return true if the query matches the patterns in the policy rule. */
 int match_policy(const struct policy *p, const struct query *q)
 {
   return ((!p->ao || p->ao == q->ao) &&
@@ -124,6 +134,9 @@ int match_policy(const struct policy *p, const struct query *q)
          match_sockpat(q->ao, &p->sp[R], &q->s[R]));
 }
 
+/*----- Parsing -----------------------------------------------------------*/
+
+/* Advance FP to the next line. */
 static void nextline(FILE *fp)
 {
   for (;;) {
@@ -132,33 +145,54 @@ static void nextline(FILE *fp)
   }
 }
 
+/* Scan a whitespace-separated token from FP, writing it to BUF.  The token
+ * must fit in a buffer of size SZ, including a terminating null.  Return
+ * an appropriate T_* error code.
+ */
 static int scan(FILE *fp, char *buf, size_t sz)
 {
   int ch;
 
 skip_ws:
+  /* Before we start grabbing a token proper, find out what's in store. */
   ch = getc(fp);
   switch (ch) {
+
     case '\n':
     newline:
+      /* Found a newline.  Leave it where it is and report it. */
       ungetc(ch, fp);
       return (T_EOL);
+
     case EOF:
     eof:
+      /* Found end-of-file, or an I/O error.  Return an appropriate code. */
       return (ferror(fp) ? T_ERROR : T_EOF);
+
     case '#':
+      /* Found a comment.  Consume it, and continue appropriately: it must
+       * be terminated either by a newline or end-of-file.
+       */
       for (;;) {
        ch = getc(fp);
        if (ch == '\n') goto newline;
        else if (ch == EOF) goto eof;
       }
+
     default:
+      /* Whitespace means we just continue around.  Anything else and we
+       * start snarfing.
+       */
       if (isspace(ch)) goto skip_ws;
       break;
   }
 
   for (;;) {
+
+    /* If there's buffer space left, store the character. */
     if (sz) { *buf++ = ch; sz--; }
+
+    /* Get a new one, and find out what to do about it. */
     ch = getc(fp);
     switch (ch) {
       case '\n':
@@ -173,14 +207,17 @@ skip_ws:
   }
 
 done:
-  if (!sz)
-    return (T_ERROR);
-  else {
-    *buf++ = 0; sz--;
-    return (T_OK);
-  }
+  /* If there's no space for a terminating null then report an error. */
+  if (!sz) return (T_ERROR);
+
+  /* All done. */
+  *buf++ = 0; sz--;
+  return (T_OK);
 }
 
+/* Parse an action name, storing the code in *ACT.  Return an appropriate T_*
+ * code.
+ */
 static int parse_actname(FILE *fp, unsigned *act)
 {
   char buf[32];
@@ -193,6 +230,7 @@ static int parse_actname(FILE *fp, unsigned *act)
   return (T_ERROR);
 }
 
+/* Parse an action, returning a T_* code. */
 static int parse_action(FILE *fp, struct action *act)
 {
   char buf[32];
@@ -200,9 +238,14 @@ static int parse_action(FILE *fp, struct action *act)
   unsigned a;
   unsigned long m;
 
+  /* Collect the action name. */
   if ((t = parse_actname(fp, &a)) != 0) return (t);
+
+  /* Parse parameters, if there are any. */
   switch (a) {
+
     case A_USER:
+      /* `user ACTION ACTION ...': store permitted actions in a bitmask. */
       m = 0;
       for (;;) {
        if ((t = parse_actname(fp, &a)) != 0) break;
@@ -212,28 +255,40 @@ static int parse_action(FILE *fp, struct action *act)
       act->act = A_USER;
       act->u.user = m;
       break;
+
     case A_TOKEN:
     case A_NAME:
     case A_DENY:
     case A_HIDE:
+      /* Dull actions which don't accept parameters. */
       act->act = a;
       break;
+
     case A_LIE:
+      /* `lie NAME': store the string we're to report. */
       if ((t = scan(fp, buf, sizeof(buf))) != 0) return (t);
       act->act = a;
       act->u.lie = xstrdup(buf);
       break;
   }
+
+  /* Make sure we've reached the end of the line. */
   t = scan(fp, buf, sizeof(buf));
   if (t != T_EOF && t != T_EOL) {
     free_action(act);
     return (T_ERROR);
   }
+
+  /* Done. */
   return (0);
 }
 
-static int parse_sockpat(FILE *fp, const struct addrops **aop,
-                        struct sockpat *sp)
+/* Parse an address pattern, writing it to AP.  If the pattern has an
+ * identifiable address family, update *AOP to point to its operations table;
+ * if *AOP is already set to something different then report an error.
+ */
+static int parse_addrpat(FILE *fp, const struct addrops **aop,
+                        struct addrpat *ap)
 {
   char buf[64];
   int t;
@@ -241,47 +296,96 @@ static int parse_sockpat(FILE *fp, const struct addrops **aop,
   long n;
   char *delim;
 
+  /* Scan a token for the address pattern. */
   if ((t = scan(fp, buf, sizeof(buf))) != 0) return (t);
-  if (strcmp(buf, "*") == 0)
-    sp->addr.len = 0;
-  else {
-    if (strchr(buf, ':'))
-      ao = &addroptab[ADDR_IPV6];
-    else
-      ao = &addroptab[ADDR_IPV4];
-    if (!*aop) *aop = ao;
-    else if (*aop != ao) return (T_ERROR);
-    delim = strchr(buf, '/');
-    if (delim) *delim++ = 0;
-    if (!inet_pton(ao->af, buf, &sp->addr.addr)) return (T_ERROR);
-    if (!delim) n = ao->len;
-    else n = strtol(delim, 0, 10);
-    if (n < 0 || n > ao->len) return (T_ERROR);
-    sp->addr.len = n;
+
+  /* If this is a wildcard, then leave everything as it is. */
+  if (strcmp(buf, "*") == 0) {
+    ap->len = 0;
+    return (T_OK);
   }
 
+  /* Decide what kind of address this must be.  A bit grim, sorry. */
+  if (strchr(buf, ':'))
+    ao = &addroptab[ADDR_IPV6];
+  else
+    ao = &addroptab[ADDR_IPV4];
+
+  /* Update the caller's idea of the address family in use. */
+  if (!*aop) *aop = ao;
+  else if (*aop != ao) return (T_ERROR);
+
+  /* See whether there's a prefix length.  If so, clobber it. */
+  delim = strchr(buf, '/');
+  if (delim) *delim++ = 0;
+
+  /* Parse the address. */
+  if (!inet_pton(ao->af, buf, &ap->addr)) return (T_ERROR);
+
+  /* Parse the prefix length, or use the maximum one. */
+  if (!delim) n = ao->len;
+  else n = strtol(delim, 0, 10);
+  if (n < 0 || n > ao->len) return (T_ERROR);
+  ap->len = n;
+
+  /* Done. */
+  return (T_OK);
+}
+
+static int parse_portpat(FILE *fp, struct portpat *pp)
+{
+  char buf[64];
+  int t;
+  long n;
+  char *delim;
+
+  /* Parse a token for the pattern. */
   if ((t = scan(fp, buf, sizeof(buf))) != 0) return (T_ERROR);
+
+  /* If this is a wildcard, then we're done. */
   if (strcmp(buf, "*") == 0) {
-    sp->port.lo = 0;
-    sp->port.hi = 65535;
-  } else {
-    delim = strchr(buf, '-');
-    if (delim) *delim++ = 0;
-    n = strtol(buf, 0, 0);
-    if (n < 0 || n > 65535) return (T_ERROR);
-    sp->port.lo = n;
-    if (!delim)
-      sp->port.hi = n;
-    else {
-      n = strtol(delim, 0, 0);
-      if (n < 0 || n > 65535) return (T_ERROR);
-      sp->port.hi = n;
-    }
+    pp->lo = 0;
+    pp->hi = 65535;
+    return (T_OK);
   }
-  return (0);
+
+  /* Find a range delimiter. */
+  delim = strchr(buf, '-');
+  if (delim) *delim++ = 0;
+
+  /* Parse the only or low end of the range. */
+  n = strtol(buf, 0, 0);
+  if (n < 0 || n > 65535) return (T_ERROR);
+  pp->lo = n;
+
+  /* If there's no delimiter, then the high end is equal to the low end;
+   * otherwise, parse the high end.
+   */
+  if (!delim)
+    pp->hi = n;
+  else {
+    n = strtol(delim, 0, 0);
+    if (n < pp->lo || n > 65535) return (T_ERROR);
+    pp->hi = n;
+  }
+
+  /* Done. */
+  return (T_OK);
 }
 
-int parse_policy(FILE *fp, struct policy *p)
+/* Parse a socket pattern, writing it to SP. */
+static int parse_sockpat(FILE *fp, const struct addrops **aop,
+                        struct sockpat *sp)
+{
+  int t;
+
+  if ((t = parse_addrpat(fp, aop, &sp->addr)) != 0) return (t);
+  if ((t = parse_portpat(fp, &sp->port)) != 0) return (T_ERROR);
+  return (T_OK);
+}
+
+/* Parse a policy rule line, writing it to P. */
+static int parse_policy(FILE *fp, struct policy *p)
 {
   int t;
 
@@ -300,6 +404,9 @@ fail:
   return (t);
 }
 
+/* Open a policy file by NAME.  The description WHAT and query Q are used for
+ * formatting error messages for the log.
+ */
 int open_policy_file(struct policy_file *pf, const char *name,
                     const char *what, const struct query *q)
 {
@@ -318,6 +425,9 @@ int open_policy_file(struct policy_file *pf, const char *name,
   return (0);
 }
 
+/* Read a policy rule from the file, storing it in PF->p.  Return one of the
+ * T_* codes.
+ */
 int read_policy_file(struct policy_file *pf)
 {
   int t;
@@ -349,12 +459,18 @@ int read_policy_file(struct policy_file *pf)
   }
 }
 
+/* Close a policy file.  It doesn't matter whether the file was completely
+ * read.
+ */
 void close_policy_file(struct policy_file *pf)
 {
   fclose(pf->fp);
   free_policy(&pf->p);
 }
 
+/* Load a policy file, writing a vector of records into PV.  If the policy
+ * file has errors, then leave PV unchanged and return nonzero.
+ */
 int load_policy_file(const char *file, policy_v *pv)
 {
   struct policy_file pf;
diff --git a/yaid.c b/yaid.c
index fc6ddd4..cde2069 100644 (file)
--- a/yaid.c
+++ b/yaid.c
 
 /*----- Data structures ---------------------------------------------------*/
 
-struct listen {
-  const struct addrops *ao;
-  sel_file f;
-};
-
+/* A write buffer is the gadget which keeps track of our output and writes
+ * portions of it out as and when connections are ready for it.
+ */
 #define WRBUFSZ 1024
 struct writebuf {
-  size_t o, n;
-  sel_file wr;
-  void (*func)(int, void *);
-  void *p;
-  unsigned char buf[WRBUFSZ];
+  size_t o;                            /* Offset of remaining data */
+  size_t n;                            /* Length of remaining data */
+  sel_file wr;                         /* Write selector */
+  void (*func)(int /*err*/, void *);   /* Function to call on completion */
+  void *p;                             /* Context for `func' */
+  unsigned char buf[WRBUFSZ];          /* Output buffer */
 };
 
-struct proxy {
-  struct client *c;
-  int fd;
-  conn cn;
-  selbuf b;
-  struct writebuf wb;
-  char nat[ADDRLEN];
+/* Structure for a listening socket.  There's one of these for each address
+ * family we're looking after.
+ */
+struct listen {
+  const struct addrops *ao;            /* Address family operations */
+  sel_file f;                          /* Watch for incoming connections */
 };
 
+/* The main structure for a client. */
 struct client {
-  selbuf b;
-  int fd;
-  struct query q;
-  struct listen *l;
-  struct writebuf wb;
-  struct proxy *px;
+  int fd;                              /* The connection to the client */
+  selbuf b;                            /* Accumulate lines of input */
+  struct query q;                      /* The clients query and our reply */
+  struct listen *l;                    /* Back to the listener (and ops) */
+  struct writebuf wb;                  /* Write buffer for our reply */
+  struct proxy *px;                    /* Proxy if conn goes via NAT */
+};
+
+/* A proxy connection. */
+struct proxy {
+  int fd;                              /* Connection; -1 if in progress */
+  struct client *c;                    /* Back to the client */
+  conn cn;                             /* Nonblocking connection */
+  selbuf b;                            /* Accumulate the response line */
+  struct writebuf wb;                  /* Write buffer for query */
+  char nat[ADDRLEN];                   /* Server address, as text */
 };
 
 /*----- Static variables --------------------------------------------------*/
 
-static sel_state sel;
+static sel_state sel;                  /* I/O multiplexer state */
 
-static policy_v policy = DA_INIT;
-static fwatch polfw;
+static const struct policy default_policy = POLICY_INIT(A_NAME);
+static policy_v policy = DA_INIT;      /* Vector of global policy rules */
+static fwatch polfw;                   /* Watch policy file for changes */
 
-static unsigned char tokenbuf[4096];
-static size_t tokenptr = sizeof(tokenbuf);
-static int randfd;
+static unsigned char tokenbuf[4096];   /* Random-ish data for tokens */
+static size_t tokenptr = sizeof(tokenbuf); /* Current read position */
+static int randfd;                     /* File descriptor for random data */
 
-/*----- Main code ---------------------------------------------------------*/
+/*----- Ident protocol parsing --------------------------------------------*/
 
-void logmsg(const struct query *q, int prio, const char *msg, ...)
+/* Advance *PP over whitespace characters. */
+static void skipws(const char **pp)
+  { while (isspace((unsigned char )**pp)) (*pp)++; }
+
+/* Copy a token of no more than N bytes starting at *PP into Q, advancing *PP
+ * over it.
+ */
+static int idtoken(const char **pp, char *q, size_t n)
 {
-  va_list ap;
-  dstr d = DSTR_INIT;
+  const char *p = *pp;
 
-  va_start(ap, msg);
-  if (q) {
-    dputsock(&d, q->ao, &q->s[L]);
-    dstr_puts(&d, " <-> ");
-    dputsock(&d, q->ao, &q->s[R]);
-    dstr_puts(&d, ": ");
+  skipws(&p);
+  n--;
+  for (;;) {
+    if (*p == ':' || *p <= 32 || *p >= 127) break;
+    if (!n) return (-1);
+    *q++ = *p++;
+    n--;
   }
-  dstr_vputf(&d, msg, &ap);
-  va_end(ap);
-  fprintf(stderr, "yaid: %s\n", d.buf);
-  dstr_destroy(&d);
+  *q++ = 0;
+  *pp = p;
+  return (0);
+}
+
+/* Read an unsigned decimal number from *PP, and store it in *II.  Check that
+ * it's between MIN and MAX, and advance *PP over it.  Return zero for
+ * success, or nonzero if something goes wrong.
+ */
+static int unum(const char **pp, unsigned *ii, unsigned min, unsigned max)
+{
+  char *q;
+  unsigned long i;
+  int e;
+
+  skipws(pp);
+  if (!isdigit((unsigned char)**pp)) return (-1);
+  e = errno; errno = 0;
+  i = strtoul(*pp, &q, 10);
+  if (errno) return (-1);
+  *pp = q;
+  errno = e;
+  if (i < min || i > max) return (-1);
+  *ii = i;
+  return (0);
 }
 
+/*----- Asynchronous writing ----------------------------------------------*/
+
+/* Callback for actually writing stuff from a `writebuf'. */
 static void write_out(int fd, unsigned mode, void *p)
 {
   ssize_t n;
   struct writebuf *wb = p;
 
+  /* Try to write something. */
   if ((n = write(fd, wb->buf + wb->o, wb->n)) < 0) {
     if (errno == EAGAIN || errno == EWOULDBLOCK) return;
     wb->n = 0;
@@ -106,6 +148,8 @@ static void write_out(int fd, unsigned mode, void *p)
   }
   wb->o += n;
   wb->n -= n;
+
+  /* If there's nothing left then restore the buffer to its empty state. */
   if (!wb->n) {
     wb->o = 0;
     sel_rmfile(&wb->wr);
@@ -113,26 +157,47 @@ static void write_out(int fd, unsigned mode, void *p)
   }
 }
 
+/* Queue N bytes starting at P to be written. */
 static int queue_write(struct writebuf *wb, const void *p, size_t n)
 {
+  /* Maybe there's nothing to actually do. */
   if (!n) return (0);
+
+  /* Make sure it'll fit. */
   if (wb->n - wb->o + n > WRBUFSZ) return (-1);
+
+  /* If there's anything there already, then make sure it's at the start of
+   * the available space.
+   */
   if (wb->o) {
     memmove(wb->buf, wb->buf + wb->o, wb->n);
     wb->o = 0;
   }
-  memcpy(wb->buf + wb->n, p, n);
+
+  /* If there's nothing currently there, then we're not requesting write
+   * notifications, so set that up, and force an initial wake-up.
+   */
   if (!wb->n) {
     sel_addfile(&wb->wr);
     sel_force(&wb->wr);
   }
+
+  /* Copy the new material over. */
+  memcpy(wb->buf + wb->n, p, n);
   wb->n += n;
+
+  /* Done. */
   return (0);
 }
 
+/* Release resources allocated to WB. */
 static void free_writebuf(struct writebuf *wb)
   { if (wb->n) sel_rmfile(&wb->wr); }
 
+/* Initialize a writebuf in *WB, writing to file descriptor FD.  On
+ * completion, call FUNC, passing it P and an error indicator: either 0 for
+ * success or an `errno' value on failure.
+ */
 static void init_writebuf(struct writebuf *wb,
                          int fd, void (*func)(int, void *), void *p)
 {
@@ -142,29 +207,32 @@ static void init_writebuf(struct writebuf *wb,
   wb->n = wb->o = 0;
 }
 
-static void cancel_proxy(struct proxy *px)
-{
-  if (px->fd == -1)
-    conn_kill(&px->cn);
-  else {
-    close(px->fd);
-    selbuf_destroy(&px->b);
-    free_writebuf(&px->wb);
-  }
-  selbuf_enable(&px->c->b);
-  px->c->px = 0;
-  xfree(px);
-}
+/*----- General utilities -------------------------------------------------*/
 
-static void disconnect_client(struct client *c)
+/* Format and log MSG somewhere sensible, at the syslog(3) priority PRIO.
+ * Prefix it with a description of the query Q, if non-null.
+ */
+void logmsg(const struct query *q, int prio, const char *msg, ...)
 {
-  close(c->fd);
-  selbuf_destroy(&c->b);
-  free_writebuf(&c->wb);
-  if (c->px) cancel_proxy(c->px);
-  xfree(c);
+  va_list ap;
+  dstr d = DSTR_INIT;
+
+  va_start(ap, msg);
+  if (q) {
+    dputsock(&d, q->ao, &q->s[L]);
+    dstr_puts(&d, " <-> ");
+    dputsock(&d, q->ao, &q->s[R]);
+    dstr_puts(&d, ": ");
+  }
+  dstr_vputf(&d, msg, &ap);
+  va_end(ap);
+  fprintf(stderr, "yaid: %s\n", d.buf);
+  dstr_destroy(&d);
 }
 
+/* Fix up a socket FD so that it won't bite us.  Returns zero on success, or
+ * nonzero on error.
+ */
 static int fix_up_socket(int fd, const char *what)
 {
   int yes = 1;
@@ -185,6 +253,13 @@ static int fix_up_socket(int fd, const char *what)
   return (0);
 }
 
+/*----- Client output functions -------------------------------------------*/
+
+static void disconnect_client(struct client *c);
+
+/* Notification that output has been written.  If successful, re-enable the
+ * input buffer and prepare for another query.
+ */
 static void done_client_write(int err, void *p)
 {
   struct client *c = p;
@@ -197,6 +272,9 @@ static void done_client_write(int err, void *p)
   }
 }
 
+/* Format the message FMT and queue it to be sent to the client.  Client
+ * input will be disabled until the write completes.
+ */
 static void write_to_client(struct client *c, const char *fmt, ...)
 {
   va_list ap;
@@ -222,6 +300,11 @@ static void write_to_client(struct client *c, const char *fmt, ...)
   }
 }
 
+/* Format a reply to the client, with the form LPORT:RPORT:TY:TOK0[:TOK1].
+ * Typically, TY will be `ERROR' or `USERID'.  In the former case, TOK0 will
+ * be the error token and TOK1 will be null; in the latter case, TOK0 will be
+ * the operating system and TOK1 the user name.
+ */
 static void reply(struct client *c, const char *ty,
                  const char *tok0, const char *tok1)
 {
@@ -230,56 +313,43 @@ static void reply(struct client *c, const char *ty,
                  tok0, tok1 ? ":" : "", tok1 ? tok1 : "");
 }
 
+/* Mapping from error codes to their protocol tokens. */
 const char *const errtok[] = {
 #define DEFTOK(err, tok) tok,
   ERROR(DEFTOK)
 #undef DEFTOK
 };
 
+/* Report an error with code ERR to the client. */
 static void reply_error(struct client *c, unsigned err)
 {
   assert(err < E_LIMIT);
   reply(c, "ERROR", errtok[err], 0);
 }
 
-static void skipws(const char **pp)
-  { while (isspace((unsigned char )**pp)) (*pp)++; }
+/*----- NAT proxy functions -----------------------------------------------*/
 
-static int idtoken(const char **pp, char *q, size_t n)
+/* Cancel the proxy operation PX, closing the connection and releasing
+ * resources.  This is used for both normal and unexpected closures.
+ */
+static void cancel_proxy(struct proxy *px)
 {
-  const char *p = *pp;
-
-  skipws(&p);
-  n--;
-  for (;;) {
-    if (*p == ':' || *p <= 32 || *p >= 127) break;
-    if (!n) return (-1);
-    *q++ = *p++;
-    n--;
+  if (px->fd == -1)
+    conn_kill(&px->cn);
+  else {
+    close(px->fd);
+    selbuf_destroy(&px->b);
+    free_writebuf(&px->wb);
   }
-  *q++ = 0;
-  *pp = p;
-  return (0);
-}
-
-static int unum(const char **pp, unsigned *ii, unsigned min, unsigned max)
-{
-  char *q;
-  unsigned long i;
-  int e;
-
-  skipws(pp);
-  if (!isdigit((unsigned char)**pp)) return (-1);
-  e = errno; errno = 0;
-  i = strtoul(*pp, &q, 10);
-  if (errno) return (-1);
-  *pp = q;
-  errno = e;
-  if (i < min || i > max) return (-1);
-  *ii = i;
-  return (0);
+  selbuf_enable(&px->c->b);
+  px->c->px = 0;
+  xfree(px);
 }
 
+/* Notification that a line (presumably a reply) has been received from the
+ * server.  We should check it, log it, and propagate the answer back.
+ * Whatever happens, this proxy operation is now complete.
+ */
 static void proxy_line(char *line, size_t sz, void *p)
 {
   struct proxy *px = p;
@@ -287,38 +357,55 @@ static void proxy_line(char *line, size_t sz, void *p)
   const char *q = line;
   unsigned lp, rp;
 
+  /* Trim trailing space. */
   while (sz && isspace((unsigned char)line[sz - 1])) sz--;
-  printf("received proxy line from %s: %s\n", px->nat, line);
 
+  /* Parse the port numbers.  These should match the request. */
   if (unum(&q, &lp, 1, 65535)) goto syntax;
   skipws(&q); if (*q != ',') goto syntax; q++;
   if (unum(&q, &rp, 1, 65535)) goto syntax;
   skipws(&q); if (*q != ':') goto syntax; q++;
   if (lp != px->c->q.u.nat.port || rp != px->c->q.s[R].port) goto syntax;
+
+  /* Find out what kind of reply this is. */
   if (idtoken(&q, buf, sizeof(buf))) goto syntax;
   skipws(&q); if (*q != ':') goto syntax; q++;
+
   if (strcmp(buf, "ERROR") == 0) {
+
+    /* Report the error without interpreting it.  It might be meaningful to
+     * the client.
+     */
     skipws(&q);
     logmsg(&px->c->q, LOG_ERR, "proxy error from %s: %s", px->nat, q);
     reply(px->c, "ERROR", q, 0);
+
   } else if (strcmp(buf, "USERID") == 0) {
+
+    /* Parse out the operating system and user name, and pass them on. */
     if (idtoken(&q, buf, sizeof(buf))) goto syntax;
     skipws(&q); if (*q != ':') goto syntax; q++;
     skipws(&q);
     logmsg(&px->c->q, LOG_ERR, "user `%s'; proxy = %s, os = %s",
           q, px->nat, buf);
     reply(px->c, "USERID", buf, q);
+
   } else
     goto syntax;
   goto done;
 
 syntax:
+  /* We didn't understand the message from the client. */
   logmsg(&px->c->q, LOG_ERR, "failed to parse response from %s", px->nat);
   reply_error(px->c, E_UNKNOWN);
 done:
+  /* All finished, no matter what. */
   cancel_proxy(px);
 }
 
+/* Notification that we have written the query to the server.  Await a
+ * response if successful.
+ */
 static void done_proxy_write(int err, void *p)
 {
   struct proxy *px = p;
@@ -333,12 +420,16 @@ static void done_proxy_write(int err, void *p)
   selbuf_enable(&px->b);
 }
 
+/* Notification that the connection to the server is either established or
+ * failed.  In the former case, queue the right query.
+ */
 static void proxy_connected(int fd, void *p)
 {
   struct proxy *px = p;
   char buf[16];
   int n;
 
+  /* If the connection failed then report the problem and give up. */
   if (fd < 0) {
     logmsg(&px->c->q, LOG_ERR,
           "failed to make %s proxy connection to %s: %s",
@@ -348,16 +439,24 @@ static void proxy_connected(int fd, void *p)
     return;
   }
 
+  /* We're now ready to go, so set things up. */
   px->fd = fd;
   selbuf_init(&px->b, &sel, fd, proxy_line, px);
   selbuf_setsize(&px->b, 1024);
   selbuf_disable(&px->b);
   init_writebuf(&px->wb, fd, done_proxy_write, px);
 
+  /* Write the query.  This buffer is large enough because we've already
+   * range-checked the remote the port number and the local one came from the
+   * kernel, which we trust not to do anything stupid.
+   */
   n = sprintf(buf, "%u,%u\r\n", px->c->q.u.nat.port, px->c->q.s[R].port);
   queue_write(&px->wb, buf, n);
 }
 
+/* Proxy the query through to a client machine for which we're providing NAT
+ * disservice.
+ */
 static void proxy_query(struct client *c)
 {
   struct socket s;
@@ -366,9 +465,15 @@ static void proxy_query(struct client *c)
   struct proxy *px;
   int fd;
 
+  /* Allocate the context structure for the NAT. */
   px = xmalloc(sizeof(*px));
+
+  /* We'll use the client host's address in lots of log messages, so we may
+   * as well format it once and use it over and over.
+   */
   inet_ntop(c->q.ao->af, &c->q.u.nat.addr, px->nat, sizeof(px->nat));
 
+  /* Create the socket for the connection. */
   if ((fd = socket(c->q.ao->af, SOCK_STREAM, 0)) < 0) {
     logmsg(&c->q, LOG_ERR, "failed to make %s socket for proxy: %s",
           c->l->ao->name, strerror(errno));
@@ -376,6 +481,13 @@ static void proxy_query(struct client *c)
   }
   if (fix_up_socket(fd, "proxy")) goto err_1;
 
+  /* Set up the connection to the client host.  The connection interface is a
+   * bit broken: if the connection completes immediately, then the callback
+   * function is called synchronously, and that might decide to shut
+   * everything down.  So we must have fully initialized our context before
+   * calling `conn_init', and mustn't touch it again afterwards -- since the
+   * block may have been freed.
+   */
   s = c->q.u.nat;
   s.port = 113;
   c->l->ao->socket_to_sockaddr(&s, &ss, &ssz);
@@ -389,8 +501,10 @@ static void proxy_query(struct client *c)
     goto err_2;
   }
 
+  /* All ready to go. */
   return;
 
+  /* Tidy up after various kinds of failures. */
 err_2:
   selbuf_enable(&c->b);
 err_1:
@@ -400,35 +514,65 @@ err_0:
   reply_error(c, E_UNKNOWN);
 }
 
-static const struct policy default_policy = POLICY_INIT(A_NAME);
+/*----- Client connection functions ---------------------------------------*/
+
+/* Disconnect a client, freeing up any associated resources. */
+static void disconnect_client(struct client *c)
+{
+  close(c->fd);
+  selbuf_destroy(&c->b);
+  free_writebuf(&c->wb);
+  if (c->px) cancel_proxy(c->px);
+  xfree(c);
+}
 
+/* Write a pseudorandom token into the buffer at P, which must have space for
+ * at least TOKENSZ bytes.
+ */
+#define TOKENRANDSZ 8
+#define TOKENSZ ((4*TOKENRANDSZ + 5)/3)
 static void user_token(char *p)
 {
-  static const char tokmap[64] =
-    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.-";
   unsigned a = 0;
   unsigned b = 0;
   int i;
-#define TOKENSZ 8
+  static const char tokmap[64] =
+    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.-";
 
-  if (tokenptr + TOKENSZ >= sizeof(tokenbuf)) {
+  /* If there's not enough pseudorandom stuff lying around, then read more
+   * from the kernel.
+   */
+  if (tokenptr + TOKENRANDSZ >= sizeof(tokenbuf)) {
     if (read(randfd, tokenbuf, sizeof(tokenbuf)) < sizeof(tokenbuf))
       die(1, "unexpected short read or error from `/dev/urandom'");
     tokenptr = 0;
   }
 
-  for (i = 0; i < TOKENSZ; i++) {
+  /* Now encode the bytes using a slightly tweaked base-64 encoding.  Read
+   * bytes into the accumulator and write out characters while there's
+   * enough material.
+   */
+  for (i = 0; i < TOKENRANDSZ; i++) {
     a = (a << 8) | tokenbuf[tokenptr++]; b += 8;
     while (b >= 6) {
       b -= 6;
       *p++ = tokmap[(a >> b) & 0x3f];
     }
   }
+
+  /* If there's anything left in the accumulator then flush it out. */
   if (b)
     *p++ = tokmap[(a << (6 - b)) & 0x3f];
+
+  /* Null-terminate the token. */
   *p++ = 0;
 }
 
+/* Notification that a line has been received from the client.  Parse it,
+ * find out about the connection it's referring to, apply the relevant
+ * policy rules, and produce a response.  This is where almost everything
+ * interesting happens.
+ */
 static void client_line(char *line, size_t len, void *p)
 {
   struct client *c = p;
@@ -441,105 +585,162 @@ static void client_line(char *line, size_t len, void *p)
   char buf[16];
   int i;
 
+  /* If the connection has closed, then tidy stuff away. */
   c->q.s[L].port = c->q.s[R].port = 0;
   if (!line) {
     disconnect_client(c);
     return;
   }
 
+  /* See if the policy file has changed since we last looked.  If so, try to
+   * read the new version.
+   */
   if (fwatch_update(&polfw, "yaid.policy")) {
     logmsg(0, LOG_INFO, "reload master policy file `%s'", "yaid.policy");
     load_policy_file("yaid.policy", &policy);
   }
 
+  /* Read the local and remote port numbers into the query structure. */
   q = line;
   if (unum(&q, &c->q.s[L].port, 1, 65535)) goto bad;
   skipws(&q); if (*q != ',') goto bad; q++;
   if (unum(&q, &c->q.s[R].port, 1, 65535)) goto bad;
   skipws(&q); if (*q) goto bad;
 
+  /* Identify the connection.  Act on the result. */
   identify(&c->q);
   switch (c->q.resp) {
+
     case R_UID:
+      /* We found a user.  Track down the user's password entry, because
+       * we'll want that later.  Most of the processing for this case is
+       * below.
+       */
       if ((pw = getpwuid(c->q.u.uid)) == 0) {
        logmsg(&c->q, LOG_ERR, "no passwd entry for user %d", c->q.u.uid);
        reply_error(c, E_NOUSER);
        return;
       }
       break;
+
     case R_NAT:
+      /* We've acted as a NAT for this connection.  Proxy the query through
+       * to the actal client host.
+       */
       proxy_query(c);
       return;
+
     case R_ERROR:
-      /* Should already be logged. */
+      /* We failed to identify the connection for some reason.  We should
+       * already have logged an error, so there's not much to do here.
+       */
       reply_error(c, c->q.u.error);
       return;
+
     default:
+      /* Something happened that we don't understand. */
       abort();
   }
 
+  /* Search the table of policy rules to find a match. */
   for (i = 0; i < DA_LEN(&policy); i++) {
     pol = &DA(&policy)[i];
     if (!match_policy(pol, &c->q)) continue;
-    if (pol->act.act != A_USER)
-      goto match;
+
+    /* If this is something simple, then apply the resulting policy rule. */
+    if (pol->act.act != A_USER) goto match;
+
+    /* The global policy has decided to let the user have a say, so we must
+     * parse the user file.
+     */
     DRESET(&d);
     dstr_putf(&d, "%s/.yaid.policy", pw->pw_dir);
     if (open_policy_file(&pf, d.buf, "user policy file", &c->q))
       continue;
     while (!read_policy_file(&pf)) {
+
+      /* Give up after 100 lines.  If the user's policy is that complicated,
+       * something's gone very wrong.  Or there's too much commentary or
+       * something.
+       */
       if (pf.lno > 100) {
        logmsg(&c->q, LOG_ERR, "%s:%d: user policy file too long",
               pf.name, pf.lno);
        break;
       }
+
+      /* If this isn't a match, go around for the next rule. */
       if (!match_policy(&pf.p, &c->q)) continue;
+
+      /* Check that the user is allowed to request this action.  If not, see
+       * if there's a more acceptable action later on.
+       */
       if (!(pol->act.u.user & (1 << pf.p.act.act))) {
        logmsg(&c->q, LOG_ERR,
               "%s:%d: user action forbidden by global policy",
               pf.name, pf.lno);
        continue;
       }
+
+      /* We've found a match, so grab it, close the file, and say we're
+       * done.
+       */
       upol = pf.p; pol = &upol;
       init_policy(&pf.p);
       close_policy_file(&pf);
+      DDESTROY(&d);
       goto match;
     }
     close_policy_file(&pf);
+    DDESTROY(&d);
   }
+
+  /* No match: apply the built-in default policy. */
   pol = &default_policy;
 
 match:
-  DDESTROY(&d);
   switch (pol->act.act) {
+
     case A_NAME:
+      /* Report the actual user's name. */
       logmsg(&c->q, LOG_INFO, "user `%s' (%d)", pw->pw_name, c->q.u.uid);
       reply(c, "USERID", "UNIX", pw->pw_name);
       break;
+
     case A_TOKEN:
+      /* Report an arbitrary token which we can look up in our log file. */
       user_token(buf);
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); token = %s",
             pw->pw_name, c->q.u.uid, buf);
       reply(c, "USERID", "OTHER", buf);
       break;
+
     case A_DENY:
+      /* Deny that there's anyone there at all. */
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); denying",
             pw->pw_name, c->q.u.uid);
       break;
+
     case A_HIDE:
+      /* Report the user as being hidden. */
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); hiding",
             pw->pw_name, c->q.u.uid);
       reply_error(c, E_HIDDEN);
       break;
+
     case A_LIE:
+      /* Tell an egregious lie about who the user is. */
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); lie = `%s'",
             pw->pw_name, c->q.u.uid, pol->act.u.lie);
       reply(c, "USERID", "UNIX", pol->act.u.lie);
       break;
+
     default:
+      /* Something has gone very wrong. */
       abort();
   }
 
+  /* All done. */
   free_policy(&upol);
   return;
 
@@ -548,6 +749,7 @@ bad:
   disconnect_client(c);
 }
 
+/* Notification that a new client has connected.  Prepare to read a query. */
 static void accept_client(int fd, unsigned mode, void *p)
 {
   struct listen *l = p;
@@ -556,6 +758,7 @@ static void accept_client(int fd, unsigned mode, void *p)
   size_t ssz = sizeof(ssr);
   int sk;
 
+  /* Accept the new connection. */
   if ((sk = accept(fd, (struct sockaddr *)&ssr, &ssz)) < 0) {
     if (errno != EAGAIN && errno == EWOULDBLOCK) {
       logmsg(0, LOG_ERR, "failed to accept incoming %s connection: %s",
@@ -565,9 +768,12 @@ static void accept_client(int fd, unsigned mode, void *p)
   }
   if (fix_up_socket(sk, "incoming client")) { close(sk); return; }
 
+  /* Build a client block and fill it in. */
   c = xmalloc(sizeof(*c));
   c->l = l;
   c->q.ao = l->ao;
+
+  /* Collect the local and remote addresses. */
   l->ao->sockaddr_to_addr(&ssr, &c->q.s[R].addr);
   ssz = sizeof(ssl);
   if (getsockname(sk, (struct sockaddr *)&ssl, &ssz)) {
@@ -581,8 +787,7 @@ static void accept_client(int fd, unsigned mode, void *p)
   l->ao->sockaddr_to_addr(&ssl, &c->q.s[L].addr);
   c->q.s[L].port = c->q.s[R].port = 0;
 
-  /* logmsg(&c->q, LOG_INFO, "accepted %s connection", l->ao->name); */
-
+  /* Set stuff up for reading the query and sending responses. */
   selbuf_init(&c->b, &sel, sk, client_line, c);
   selbuf_setsize(&c->b, 1024);
   c->fd = sk;
@@ -590,6 +795,11 @@ static void accept_client(int fd, unsigned mode, void *p)
   init_writebuf(&c->wb, sk, done_client_write, c);
 }
 
+/*----- Main code ---------------------------------------------------------*/
+
+/* Set up a listening socket for the address family described by AO,
+ * listening on PORT.
+ */
 static int make_listening_socket(const struct addrops *ao, int port)
 {
   int fd;
@@ -599,35 +809,48 @@ static int make_listening_socket(const struct addrops *ao, int port)
   struct listen *l;
   size_t ssz;
 
+  /* Make the socket. */
   if ((fd = socket(ao->af, SOCK_STREAM, 0)) < 0) {
     if (errno == EAFNOSUPPORT) return (-1);
     die(1, "failed to create %s listening socket: %s",
        ao->name, strerror(errno));
   }
-  setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
+
+  /* Build the appropriate local address. */
   s.addr = *ao->any;
   s.port = port;
   ao->socket_to_sockaddr(&s, &ss, &ssz);
+
+  /* Perform any initialization specific to the address type. */
   if (ao->init_listen_socket(fd)) {
     die(1, "failed to initialize %s listening socket: %s",
        ao->name, strerror(errno));
   }
+
+  /* Bind to the address. */
+  setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
   if (bind(fd, (struct sockaddr *)&ss, ssz)) {
     die(1, "failed to bind %s listening socket: %s",
        ao->name, strerror(errno));
   }
+
+  /* Avoid unpleasant race conditions. */
   if (fdflags(fd, O_NONBLOCK, O_NONBLOCK, 0, 0)) {
     die(1, "failed to set %s listening socket nonblocking: %s",
        ao->name, strerror(errno));
   }
+
+  /* Prepare to listen. */
   if (listen(fd, 5))
     die(1, "failed to listen for %s: %s", ao->name, strerror(errno));
 
+  /* Make a record of all of this. */
   l = xmalloc(sizeof(*l));
   l->ao = ao;
   sel_initfile(&sel, &l->f, fd, SEL_READ, accept_client, l);
   sel_addfile(&l->f);
 
+  /* Done. */
   return (0);
 }
 
diff --git a/yaid.h b/yaid.h
index 80aa98b..126c201 100644 (file)
--- a/yaid.h
+++ b/yaid.h
 #include <mLib/sel.h>
 #include <mLib/selbuf.h>
 
-/*----- System specifics --------------------------------------------------*/
+/*----- Address family handling -------------------------------------------*/
 
-#define SYS_UNDEF 0
-#define SYS_LINUX 1
-
-#if SYS == SYS_LINUX
-#  include <linux/netlink.h>
-#  include <linux/rtnetlink.h>
-#else
-#  error "Unsupported operating system: sorry.  Patches welcome!"
-#endif
+/* The maximum length of an address formatted as a text string, including the
+ * terminating null byte.
+ */
+#define ADDRLEN 64
 
-/*----- Data structures ---------------------------------------------------*/
+/* A list of address types. */
+#define ADDRTYPES(_)                                                   \
+  _(ipv4, IPV4)                                                                \
+  _(ipv6, IPV6)
 
-#define ADDRLEN 64
+/* Address types for the various families, in the form acceptable to
+ * inet_ntop(3) and inet_pton(3). */
+#define TYPE_IPV4 struct in_addr
+#define TYPE_IPV6 struct in6_addr
 
+/* A union of address types. */
 union addr {
-  struct in_addr ipv4;
-  struct in6_addr ipv6;
+#define UMEMB(ty, TY) TYPE_##TY ty;
+  ADDRTYPES(UMEMB)
+#undef UMEMB
 };
 
+/* A socket holds an address and a port number. */
 struct socket {
-  union addr addr;
-  unsigned port;
+  union addr addr;                     /* The address */
+  unsigned port;                       /* The port, in /host/ byte order */
 };
 
+/* An address pattern consists of an address and a prefix length: the
+ * pattern matches an address if they agree in the first LEN bits.
+ */
 struct addrpat {
-  unsigned len;
-  union addr addr;
+  union addr addr;                     /* The base address */
+  unsigned len;                                /* The prefix length */
 };
 
+/* A port pattern matches a port if the port is within the stated (inclusive)
+ * bounds.
+ */
 struct portpat {
   unsigned lo, hi;
 };
 
+/* A socket pattern consists simply of an address pattern and a port pattern:
+ * it matches a socket componentwise.
+ */
 struct sockpat {
   struct addrpat addr;
   struct portpat port;
 };
 
-#define ADDRTYPES(_)                                                   \
-  _(ipv4, IPV4)                                                                \
-  _(ipv6, IPV6)
-
+/* The table of address-type operations.  Each address family has one of
+ * these, so that most of the program doesn't need to worry about these
+ * details.
+ */
 struct addrops {
-  int af;
-  const char *name;
-  unsigned len;
-  const union addr *any;
-  const struct addrops_sys *sys;
+  int af;                              /* The AF_* constant */
+  const char *name;                    /* Name of the protocol, for logs */
+  unsigned len;                                /* Length of an address, in bits */
+  const union addr *any;               /* A wildcard address */
+  const struct addrops_sys *sys;       /* Pointer to system-specific ops */
+
   int (*addreq)(const union addr *, const union addr *);
+       /* Return nonzero if the two addresses are equal. */
+
   int (*match_addrpat)(const struct addrpat *, const union addr *);
+       /* Return nonzero if the pattern matches the address. */
+
   void (*socket_to_sockaddr)(const struct socket *s, void *, size_t *);
+       /* Convert a socket structure to a `struct sockaddr', and return the
+        * size of the latter.
+        */
+
   void (*sockaddr_to_addr)(const void *, union addr *);
+       /* Extract the address from a `struct sockaddr'. */
+
   int (*init_listen_socket)(int);
+       /* Perform any necessary extra operations on a socket which is going
+        * to be used to listen for incoming connections.
+        */
 };
 
+/* A handy constant for each address family.  These are more useful than the
+ * AF_* constants in that they form a dense sequence.
+ */
 enum {
 #define DEFADDR(ty, TY) ADDR_##TY,
   ADDRTYPES(DEFADDR)
@@ -131,25 +161,53 @@ enum {
   ADDR_LIMIT
 };
 
+/* The table of address operations, indexed by the ADDR_* constants defined
+ * just above.
+ */
 extern const struct addrops addroptab[];
-#define OPS_SYS(ty, TY)                                        \
+
+/* System-specific operations, provided by the system-specific code for its
+ * own purposes.
+ */
+#define OPS_SYS(ty, TY)                                                        \
   extern const struct addrops_sys addrops_sys_##ty;
 ADDRTYPES(OPS_SYS)
 #undef OPS_SYS
 
+/* Answer whether the sockets SA and SB are equal. */
+extern int sockeq(const struct addrops */*ao*/,
+                 const struct socket */*sa*/, const struct socket */*sb*/);
+
+/* Write a textual description of S to the string D. */
+extern void dputsock(dstr */*d*/, const struct addrops */*ao*/,
+                    const struct socket */*s*/);
+
+/*----- Queries and responses ---------------------------------------------*/
+
+/* Constants for describing the `L'ocal and `R'emote ends of a connection. */
 enum { L, R, NDIR };
 
+/* Response types, and the data needed to represent any associated data.  A
+ * U(MEMB, TYPE) constructs a union member; an N means no associated data.
+ */
 #define RESPONSE(_)                                                    \
   _(ERROR, U(error, unsigned))                                         \
   _(UID, U(uid, uid_t))                                                        \
   _(NAT, U(nat, struct socket))
 
+enum {
+#define DEFENUM(what, branch) R_##what,
+  RESPONSE(DEFENUM)
+#undef DEFENUM
+  R_LIMIT
+};
+
+/* Protocol error tokens. */
 #define ERROR(_)                                                       \
   _(INVPORT, "INVALID-PORT")                                           \
   _(NOUSER, "NO-USER")                                                 \
   _(HIDDEN, "HIDDEN-USER")                                             \
   _(UNKNOWN, "UNKNOWN-ERROR")
-extern const char *const errtok[];
 
 enum {
 #define DEFENUM(err, tok) E_##err,
@@ -158,18 +216,16 @@ enum {
   E_LIMIT
 };
 
-enum {
-#define DEFENUM(what, branch) R_##what,
-  RESPONSE(DEFENUM)
-#undef DEFENUM
-  R_LIMIT
-};
+extern const char *const errtok[];
 
+/* The query structure keeps together the parameters to the client's query
+ * and our response to it.
+ */
 struct query {
-  const struct addrops *ao;
-  struct socket s[NDIR];
-  unsigned resp;
-  union {
+  const struct addrops *ao;            /* Address family operations */
+  struct socket s[NDIR];               /* The local and remote ends */
+  unsigned resp;                       /* Our response type */
+  union {                              /* A union of response data */
 #define DEFBRANCH(WHAT, branch) branch
 #define U(memb, ty) ty memb;
 #define N
@@ -180,13 +236,28 @@ struct query {
   } u;
 } query;
 
-enum {
-  T_OK,
-  T_EOL,
-  T_EOF,
-  T_ERROR
-};
+/*----- Common utility functions ------------------------------------------*/
+
+/* Format and log MSG somewhere sensible, at the syslog(3) priority PRIO.
+ * Prefix it with a description of the query Q, if non-null.
+ */
+extern void logmsg(const struct query */*q*/,
+                  int /*prio*/, const char */*msg*/, ...);
 
+/*----- System-specific connection identification code --------------------*/
+
+/* Find out who is responsible for the connection described in the query Q.
+ * Write the answer to Q.  Errors are logged and reported via the query
+ * structure.
+ */
+extern void identify(struct query */*q*/);
+
+/* Initialize the system-specific code. */
+extern void init_sys(void);
+
+/*----- Policy management -------------------------------------------------*/
+
+/* The possible policy actions and their names. */
 #define ACTIONS(_)                                                     \
   _(USER, "user")                                                      \
   _(TOKEN, "token")                                                    \
@@ -202,54 +273,85 @@ enum {
   A_LIMIT
 };
 
+/* A policy action. */
 struct action {
   unsigned act;
   union {
-    unsigned user;
-    char *lie;
+    unsigned user;                     /* Bitmask of permitted actions */
+    char *lie;                         /* The user name to impersonate */
   } u;
 };
 
+/* A policy rule: if the query matches the pattern, then perform the
+ * action.
+ */
 struct policy {
   const struct addrops *ao;
   struct sockpat sp[NDIR];
   struct action act;
 };
-#define POLICY_INIT(a) { 0, { { { 0 } } }, { a } }
+#define POLICY_INIT(a) { .act.act = a }
+DA_DECL(policy_v, struct policy);
 
-struct policy_file {
-  FILE *fp;
-  const struct query *q;
-  const char *name;
-  const char *what;
-  int err;
-  int lno;
-  struct policy p;
-};
+/* Initialize a policy structure.  In this state, it doesn't actually have
+ * any resources allocated (so can be simply discarded) but it's safe to free
+ * (using `free_policy').
+ */
+extern void init_policy(struct policy */*p*/);
 
-DA_DECL(policy_v, struct policy);
+/* Free a policy structure, resetting it to its freshly-initialized state.
+ * This function is idempotent.
+ */
+extern void free_policy(struct policy */*p*/);
+
+/* Print a policy rule to standard output. */
+extern void print_policy(const struct policy */*p*/);
+
+/* Return true if the query matches the patterns in the policy rule. */
+extern int match_policy(const struct policy */*p*/,
+                       const struct query */*q*/);
+
+/*----- Parsing policy files ----------------------------------------------*/
 
-/*----- Functions provided ------------------------------------------------*/
+/* Possible results from a parse. */
+enum {
+  T_OK,                                        /* Successful: results returned */
+  T_EOL,                               /* End-of-line found immediately */
+  T_EOF,                               /* End-of-file found immediately */
+  T_ERROR                              /* Some kind of error occurred */
+};
 
-int sockeq(const struct addrops *ao,
-          const struct socket *sa, const struct socket *sb);
-void dputsock(dstr *d, const struct addrops *ao, const struct socket *s);
+/* A context for parsing a policy file. */
+struct policy_file {
+  FILE *fp;                            /* The file to read from */
+  const struct query *q;               /* A query to use for logging */
+  const char *name;                    /* The name of the file */
+  const char *what;                    /* A description of the file */
+  int err;                             /* Have there been any errors? */
+  int lno;                             /* The current line number */
+  struct policy p;                     /* Parsed policy rule goes here */
+};
 
-void logmsg(const struct query *q, int prio, const char *msg, ...);
+/* Open a policy file by NAME.  The description WHAT and query Q are used for
+ * formatting error messages for the log.
+ */
+extern int open_policy_file(struct policy_file */*pf*/, const char */*name*/,
+                           const char */*what*/, const struct query */*q*/);
 
-void identify(struct query *q);
-void init_sys(void);
+/* Read a policy rule from the file, storing it in PF->p.  Return one of the
+ * T_* codes.
+ */
+extern int read_policy_file(struct policy_file */*pf*/);
+
+/* Close a policy file.  It doesn't matter whether the file was completely
+ * read.
+ */
+extern void close_policy_file(struct policy_file */*pf*/);
 
-void init_policy(struct policy *p);
-void free_policy(struct policy *p);
-void print_policy(const struct policy *p);
-int match_policy(const struct policy *p, const struct query *q);
-int parse_policy(FILE *fp, struct policy *p);
-int open_policy_file(struct policy_file *pf, const char *name,
-                    const char *what, const struct query *q);
-int read_policy_file(struct policy_file *pf);
-void close_policy_file(struct policy_file *pf);
-int load_policy_file(const char *file, policy_v *pv);
+/* Load a policy file, writing a vector of records into PV.  If the policy
+ * file has errors, then leave PV unchanged and return nonzero.
+ */
+extern int load_policy_file(const char */*file*/, policy_v */*pv*/);
 
 /*----- That's all, folks -------------------------------------------------*/