yaid.c: Reorder `proxy_query': don't use `c' after `conn_init'.
[yaid] / yaid.c
diff --git a/yaid.c b/yaid.c
index 385ce38..fc6ddd4 100644 (file)
--- a/yaid.c
+++ b/yaid.c
@@ -31,8 +31,7 @@
 /*----- Data structures ---------------------------------------------------*/
 
 struct listen {
-  int af;
-  const char *proto;
+  const struct addrops *ao;
   sel_file f;
 };
 
@@ -76,48 +75,6 @@ static int randfd;
 
 /*----- Main code ---------------------------------------------------------*/
 
-static void socket_to_sockaddr(int af, const struct socket *s,
-                              struct sockaddr *sa, size_t *ssz)
-{
-  sa->sa_family = af;
-  switch (af) {
-    case AF_INET: {
-      struct sockaddr_in *sin = (struct sockaddr_in *)sa;
-      sin->sin_addr = s->addr.ipv4;
-      sin->sin_port = htons(s->port);
-      *ssz = sizeof(*sin);
-    } break;
-    case AF_INET6: {
-      struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sa;
-      sin6->sin6_addr = s->addr.ipv6;
-      sin6->sin6_port = htons(s->port);
-      sin6->sin6_flowinfo = 0;
-      sin6->sin6_scope_id = 0;
-      *ssz = sizeof(*sin6);
-    } break;
-    default: abort();
-  }
-}
-
-static void sockaddr_to_addr(const struct sockaddr *sa, union addr *a)
-{
-  switch (sa->sa_family) {
-    case AF_INET: a->ipv4 = ((struct sockaddr_in *)sa)->sin_addr; break;
-    case AF_INET6: a->ipv6 = ((struct sockaddr_in6 *)sa)->sin6_addr; break;
-    default: abort();
-  }
-}
-
-static void dputsock(dstr *d, int af, const struct socket *s)
-{
-  char buf[ADDRLEN];
-
-  inet_ntop(af, &s->addr, buf, sizeof(buf));
-  if (!s->port || af != AF_INET6) dstr_puts(d, buf);
-  else { dstr_putc(d, '['); dstr_puts(d, buf); dstr_putc(d, ']'); }
-  if (s->port) dstr_putf(d, ":%d", s->port);
-}
-
 void logmsg(const struct query *q, int prio, const char *msg, ...)
 {
   va_list ap;
@@ -125,9 +82,9 @@ void logmsg(const struct query *q, int prio, const char *msg, ...)
 
   va_start(ap, msg);
   if (q) {
-    dputsock(&d, q->af, &q->s[L]);
+    dputsock(&d, q->ao, &q->s[L]);
     dstr_puts(&d, " <-> ");
-    dputsock(&d, q->af, &q->s[R]);
+    dputsock(&d, q->ao, &q->s[R]);
     dstr_puts(&d, ": ");
   }
   dstr_vputf(&d, msg, &ap);
@@ -208,6 +165,26 @@ static void disconnect_client(struct client *c)
   xfree(c);
 }
 
+static int fix_up_socket(int fd, const char *what)
+{
+  int yes = 1;
+
+  if (fdflags(fd, O_NONBLOCK, O_NONBLOCK, 0, 0)) {
+    logmsg(0, LOG_ERR, "failed to set %s connection nonblocking: %s",
+          what, strerror(errno));
+    return (-1);
+  }
+
+  if (setsockopt(fd, SOL_SOCKET, SO_OOBINLINE, &yes, sizeof(yes))) {
+    logmsg(0, LOG_ERR,
+          "failed to disable `out-of-band' data on %s connection: %s",
+          what, strerror(errno));
+    return (-1);
+  }
+
+  return (0);
+}
+
 static void done_client_write(int err, void *p)
 {
   struct client *c = p;
@@ -245,16 +222,24 @@ static void write_to_client(struct client *c, const char *fmt, ...)
   }
 }
 
-static void reply(struct client *c, const char *ty, const char *msg)
+static void reply(struct client *c, const char *ty,
+                 const char *tok0, const char *tok1)
 {
-  write_to_client(c, "%u,%u:%s:%s\r\n",
-                 c->q.s[L].port, c->q.s[R].port, ty, msg);
+  write_to_client(c, "%u,%u:%s:%s%s%s\r\n",
+                 c->q.s[L].port, c->q.s[R].port, ty,
+                 tok0, tok1 ? ":" : "", tok1 ? tok1 : "");
 }
 
+const char *const errtok[] = {
+#define DEFTOK(err, tok) tok,
+  ERROR(DEFTOK)
+#undef DEFTOK
+};
+
 static void reply_error(struct client *c, unsigned err)
 {
   assert(err < E_LIMIT);
-  reply(c, "ERROR", errtok[err]);
+  reply(c, "ERROR", errtok[err], 0);
 }
 
 static void skipws(const char **pp)
@@ -315,15 +300,14 @@ static void proxy_line(char *line, size_t sz, void *p)
   if (strcmp(buf, "ERROR") == 0) {
     skipws(&q);
     logmsg(&px->c->q, LOG_ERR, "proxy error from %s: %s", px->nat, q);
-    reply(px->c, "ERROR", q);
+    reply(px->c, "ERROR", q, 0);
   } else if (strcmp(buf, "USERID") == 0) {
     if (idtoken(&q, buf, sizeof(buf))) goto syntax;
     skipws(&q); if (*q != ':') goto syntax; q++;
     skipws(&q);
     logmsg(&px->c->q, LOG_ERR, "user `%s'; proxy = %s, os = %s",
           q, px->nat, buf);
-    write_to_client(px->c, "%u,%u:USERID:%s:%s\r\n",
-                   px->c->q.s[L].port, px->c->q.s[R].port, buf, q);
+    reply(px->c, "USERID", buf, q);
   } else
     goto syntax;
   goto done;
@@ -358,7 +342,7 @@ static void proxy_connected(int fd, void *p)
   if (fd < 0) {
     logmsg(&px->c->q, LOG_ERR,
           "failed to make %s proxy connection to %s: %s",
-          px->c->l->proto, px->nat, strerror(errno));
+          px->c->l->ao->name, px->nat, strerror(errno));
     reply_error(px->c, E_UNKNOWN);
     cancel_proxy(px);
     return;
@@ -380,38 +364,31 @@ static void proxy_query(struct client *c)
   struct sockaddr_storage ss;
   size_t ssz;
   struct proxy *px;
-  int o;
   int fd;
 
   px = xmalloc(sizeof(*px));
-  inet_ntop(c->q.af, &c->q.u.nat.addr, px->nat, sizeof(px->nat));
+  inet_ntop(c->q.ao->af, &c->q.u.nat.addr, px->nat, sizeof(px->nat));
 
-  if ((fd = socket(c->q.af, SOCK_STREAM, 0)) < 0) {
+  if ((fd = socket(c->q.ao->af, SOCK_STREAM, 0)) < 0) {
     logmsg(&c->q, LOG_ERR, "failed to make %s socket for proxy: %s",
-          c->l->proto, strerror(errno));
+          c->l->ao->name, strerror(errno));
     goto err_0;
   }
-
-  if ((o = fcntl(fd, F_GETFL)) < 0 ||
-      fcntl(fd, F_SETFL, o | O_NONBLOCK)) {
-    logmsg(&c->q, LOG_ERR, "failed to set %s proxy socket nonblocking: %s",
-          c->l->proto, strerror(errno));
-    goto err_1;
-  }
+  if (fix_up_socket(fd, "proxy")) goto err_1;
 
   s = c->q.u.nat;
   s.port = 113;
-  socket_to_sockaddr(c->q.af, &s, (struct sockaddr *)&ss, &ssz);
+  c->l->ao->socket_to_sockaddr(&s, &ss, &ssz);
   selbuf_disable(&c->b);
+  c->px = px; px->c = c;
+  px->fd = -1;
   if (conn_init(&px->cn, &sel, fd, (struct sockaddr *)&ss, ssz,
                proxy_connected, px)) {
     logmsg(&c->q, LOG_ERR, "failed to make %s proxy connection to %s: %s",
-          c->l->proto, px->nat, strerror(errno));
+          c->l->ao->name, px->nat, strerror(errno));
     goto err_2;
   }
 
-  c->px = px; px->c = c;
-  px->fd = -1;
   return;
 
 err_2:
@@ -537,13 +514,13 @@ match:
   switch (pol->act.act) {
     case A_NAME:
       logmsg(&c->q, LOG_INFO, "user `%s' (%d)", pw->pw_name, c->q.u.uid);
-      reply(c, "USERID:UNIX", pw->pw_name);
+      reply(c, "USERID", "UNIX", pw->pw_name);
       break;
     case A_TOKEN:
       user_token(buf);
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); token = %s",
             pw->pw_name, c->q.u.uid, buf);
-      reply(c, "USERID:OTHER", buf);
+      reply(c, "USERID", "OTHER", buf);
       break;
     case A_DENY:
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); denying",
@@ -557,7 +534,7 @@ match:
     case A_LIE:
       logmsg(&c->q, LOG_INFO, "user `%s' (%d); lie = `%s'",
             pw->pw_name, c->q.u.uid, pol->act.u.lie);
-      reply(c, "USERID:UNIX", pol->act.u.lie);
+      reply(c, "USERID", "UNIX", pol->act.u.lie);
       break;
     default:
       abort();
@@ -582,28 +559,29 @@ static void accept_client(int fd, unsigned mode, void *p)
   if ((sk = accept(fd, (struct sockaddr *)&ssr, &ssz)) < 0) {
     if (errno != EAGAIN && errno == EWOULDBLOCK) {
       logmsg(0, LOG_ERR, "failed to accept incoming %s connection: %s",
-            l->proto, strerror(errno));
+            l->ao->name, strerror(errno));
     }
     return;
   }
+  if (fix_up_socket(sk, "incoming client")) { close(sk); return; }
 
   c = xmalloc(sizeof(*c));
   c->l = l;
-  c->q.af = l->af;
-  sockaddr_to_addr((struct sockaddr *)&ssr, &c->q.s[R].addr);
+  c->q.ao = l->ao;
+  l->ao->sockaddr_to_addr(&ssr, &c->q.s[R].addr);
   ssz = sizeof(ssl);
   if (getsockname(sk, (struct sockaddr *)&ssl, &ssz)) {
     logmsg(0, LOG_ERR,
           "failed to read local address for incoming %s connection: %s",
-          l->proto, strerror(errno));
+          l->ao->name, strerror(errno));
     close(sk);
     xfree(c);
     return;
   }
-  sockaddr_to_addr((struct sockaddr *)&ssl, &c->q.s[L].addr);
+  l->ao->sockaddr_to_addr(&ssl, &c->q.s[L].addr);
   c->q.s[L].port = c->q.s[R].port = 0;
 
-  /* logmsg(&c->q, LOG_INFO, "accepted %s connection", l->proto); */
+  /* logmsg(&c->q, LOG_INFO, "accepted %s connection", l->ao->name); */
 
   selbuf_init(&c->b, &sel, sk, client_line, c);
   selbuf_setsize(&c->b, 1024);
@@ -612,67 +590,57 @@ static void accept_client(int fd, unsigned mode, void *p)
   init_writebuf(&c->wb, sk, done_client_write, c);
 }
 
-static int make_listening_socket(int af, int port, const char *proto)
+static int make_listening_socket(const struct addrops *ao, int port)
 {
   int fd;
-  int o;
+  int yes = 1;
+  struct socket s;
   struct sockaddr_storage ss;
   struct listen *l;
   size_t ssz;
 
-  if ((fd = socket(af, SOCK_STREAM, 0)) < 0) {
+  if ((fd = socket(ao->af, SOCK_STREAM, 0)) < 0) {
+    if (errno == EAFNOSUPPORT) return (-1);
     die(1, "failed to create %s listening socket: %s",
-       proto, strerror(errno));
+       ao->name, strerror(errno));
   }
-  o = 1; setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &o, sizeof(o));
-  ss.ss_family = af;
-  switch (af) {
-    case AF_INET: {
-      struct sockaddr_in *sin = (struct sockaddr_in *)&ss;
-      sin->sin_addr.s_addr = INADDR_ANY;
-      sin->sin_port = htons(port);
-      ssz = sizeof(*sin);
-    } break;
-    case AF_INET6: {
-      struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&ss;
-      o = 1; setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &o, sizeof(o));
-      sin6->sin6_family = AF_INET6;
-      sin6->sin6_addr = in6addr_any;
-      sin6->sin6_scope_id = 0;
-      sin6->sin6_flowinfo = 0;
-      ssz = sizeof(*sin6);
-    } break;
-    default:
-      abort();
+  setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
+  s.addr = *ao->any;
+  s.port = port;
+  ao->socket_to_sockaddr(&s, &ss, &ssz);
+  if (ao->init_listen_socket(fd)) {
+    die(1, "failed to initialize %s listening socket: %s",
+       ao->name, strerror(errno));
+  }
+  if (bind(fd, (struct sockaddr *)&ss, ssz)) {
+    die(1, "failed to bind %s listening socket: %s",
+       ao->name, strerror(errno));
   }
-  if (bind(fd, (struct sockaddr *)&ss, ssz))
-    die(1, "failed to bind %s listening socket: %s", proto, strerror(errno));
-  if ((o = fcntl(fd, F_GETFL)) < 0 ||
-      fcntl(fd, F_SETFL, o | O_NONBLOCK)) {
+  if (fdflags(fd, O_NONBLOCK, O_NONBLOCK, 0, 0)) {
     die(1, "failed to set %s listening socket nonblocking: %s",
-       proto, strerror(errno));
+       ao->name, strerror(errno));
   }
   if (listen(fd, 5))
-    die(1, "failed to listen for %s: %s", proto, strerror(errno));
+    die(1, "failed to listen for %s: %s", ao->name, strerror(errno));
 
   l = xmalloc(sizeof(*l));
-  l->af = af;
-  l->proto = proto;
+  l->ao = ao;
   sel_initfile(&sel, &l->f, fd, SEL_READ, accept_client, l);
   sel_addfile(&l->f);
 
-  return (fd);
+  return (0);
 }
 
 int main(int argc, char *argv[])
 {
   int port = 113;
-  char buf[ADDRLEN];
-  union addr a;
+  const struct addrops *ao;
+  int any = 0;
 
   ego(argv[0]);
 
   fwatch_init(&polfw, "yaid.policy");
+  init_sys();
   if (load_policy_file("yaid.policy", &policy))
     exit(1);
   { int i;
@@ -685,14 +653,11 @@ int main(int argc, char *argv[])
        strerror(errno));
   }
 
-  if (get_default_gw(AF_INET, &a))
-    printf("ipv4 gw = %s\n", inet_ntop(AF_INET, &a, buf, sizeof(buf)));
-  if (get_default_gw(AF_INET6, &a))
-    printf("ipv6 gw = %s\n", inet_ntop(AF_INET6, &a, buf, sizeof(buf)));
-
   sel_init(&sel);
-  make_listening_socket(AF_INET, port, "IPv4");
-  make_listening_socket(AF_INET6, port, "IPv6");
+  for (ao = addroptab; ao->name; ao++)
+    if (!make_listening_socket(ao, port)) any = 1;
+  if (!any)
+    die(1, "no IP protocols supported");
 
   for (;;)
     if (sel_select(&sel)) die(1, "select failed: %s", strerror(errno));