Support for falling back through the list of addresses returned from
[u/mdw/putty] / unix / uxnet.c
index fd86af0..0312094 100644 (file)
@@ -48,6 +48,9 @@ struct Socket_tag {
     int oobinline;
     int pending_error;                /* in case send() returns error */
     int listener;
+    int nodelay, keepalive;            /* for connect()-type sockets */
+    int privport, port;                /* and again */
+    SockAddr addr;
 };
 
 /*
@@ -70,9 +73,11 @@ struct SockAddr_tag {
      */
     int family;
 #ifndef NO_IPV6
-    struct addrinfo *ai;              /* Address IPv6 style. */
+    struct addrinfo *ais;             /* Addresses IPv6 style. */
+    struct addrinfo *ai;              /* steps along the linked list */
 #else
-    unsigned long address;            /* Address IPv4 style. */
+    unsigned long *addresses;         /* Addresses IPv4 style. */
+    int naddresses, curraddr;
 #endif
     char hostname[512];                       /* Store an unresolved host name. */
 };
@@ -120,11 +125,6 @@ void sk_cleanup(void)
     }
 }
 
-const char *error_string(int error)
-{
-    return strerror(error);
-}
-
 SockAddr sk_namelookup(const char *host, char **canonicalname, int address_family)
 {
     SockAddr ret = snew(struct SockAddr_tag);
@@ -134,6 +134,7 @@ SockAddr sk_namelookup(const char *host, char **canonicalname, int address_famil
 #else
     unsigned long a;
     struct hostent *h = NULL;
+    int n;
 #endif
     char realhost[8192];
 
@@ -154,7 +155,8 @@ SockAddr sk_namelookup(const char *host, char **canonicalname, int address_famil
     hints.ai_addr = NULL;
     hints.ai_canonname = NULL;
     hints.ai_next = NULL;
-    err = getaddrinfo(host, NULL, &hints, &ret->ai);
+    err = getaddrinfo(host, NULL, &hints, &ret->ais);
+    ret->ai = ret->ais;
     if (err != 0) {
        ret->error = gai_strerror(err);
        return ret;
@@ -172,22 +174,28 @@ SockAddr sk_namelookup(const char *host, char **canonicalname, int address_famil
         * we don't use gethostbyname as a fallback!)
         */
        if (ret->family == 0) {
-               /*debug(("Resolving \"%s\" with gethostbyname() (IPv4 only)...\n", host)); */
-               if ( (h = gethostbyname(host)) )
-                       ret->family = AF_INET;
+           /*debug(("Resolving \"%s\" with gethostbyname() (IPv4 only)...\n", host)); */
+           if ( (h = gethostbyname(host)) )
+               ret->family = AF_INET;
        }
        if (ret->family == 0) {
-               ret->error = (h_errno == HOST_NOT_FOUND ||
-                   h_errno == NO_DATA ||
-                   h_errno == NO_ADDRESS ? "Host does not exist" :
-                   h_errno == TRY_AGAIN ?
-                   "Temporary name service failure" :
-                   "gethostbyname: unknown error");
-               return ret;
+           ret->error = (h_errno == HOST_NOT_FOUND ||
+                         h_errno == NO_DATA ||
+                         h_errno == NO_ADDRESS ? "Host does not exist" :
+                         h_errno == TRY_AGAIN ?
+                         "Temporary name service failure" :
+                         "gethostbyname: unknown error");
+           return ret;
        }
-       memcpy(&a, h->h_addr, sizeof(a));
        /* This way we are always sure the h->h_name is valid :) */
        strncpy(realhost, h->h_name, sizeof(realhost));
+       for (n = 0; h->h_addr_list[n]; n++);
+       ret->addresses = snewn(n, unsigned long);
+       ret->naddresses = n;
+       for (n = 0; n < ret->naddresses; n++) {
+           memcpy(&a, h->h_addr_list[n], sizeof(a));
+           ret->addresses[n] = ntohl(a);
+       }
     } else {
        /*
         * This must be a numeric IPv4 address because it caused a
@@ -195,8 +203,11 @@ SockAddr sk_namelookup(const char *host, char **canonicalname, int address_famil
         */
        ret->family = AF_INET;
        strncpy(realhost, host, sizeof(realhost));
+       ret->addresses = snew(unsigned long);
+       ret->naddresses = 1;
+       ret->addresses[0] = ntohl(a);
+       ret->curraddr = 0;
     }
-    ret->address = ntohl(a);
 #endif
     realhost[lenof(realhost)-1] = '\0';
     *canonicalname = snewn(1+strlen(realhost), char);
@@ -212,11 +223,32 @@ SockAddr sk_nonamelookup(const char *host)
     strncpy(ret->hostname, host, lenof(ret->hostname));
     ret->hostname[lenof(ret->hostname)-1] = '\0';
 #ifndef NO_IPV6
-    ret->ai = NULL;
+    ret->ais = NULL;
+#else
+    ret->addresses = NULL;
 #endif
     return ret;
 }
 
+static int sk_nextaddr(SockAddr addr)
+{
+#ifndef NO_IPV6
+    if (addr->ai->ai_next) {
+       addr->ai = addr->ai->ai_next;
+       addr->family = addr->ai->ai_family;
+       return TRUE;
+    } else
+       return FALSE;
+#else
+    if (addr->curraddr+1 < addr->naddresses) {
+       addr->curraddr++;
+       return TRUE;
+    } else {
+       return FALSE;
+    }
+#endif    
+}
+
 void sk_getaddr(SockAddr addr, char *buf, int buflen)
 {
 
@@ -233,7 +265,7 @@ void sk_getaddr(SockAddr addr, char *buf, int buflen)
 #else
        struct in_addr a;
        assert(addr->family == AF_INET);
-       a.s_addr = htonl(addr->address);
+       a.s_addr = htonl(addr->addresses[addr->curraddr]);
        strncpy(buf, inet_ntoa(a), buflen);
        buf[buflen-1] = '\0';
 #endif
@@ -263,7 +295,7 @@ int sk_address_is_local(SockAddr addr)
 #else
        struct in_addr a;
        assert(addr->family == AF_INET);
-       a.s_addr = htonl(addr->address);
+       a.s_addr = htonl(addr->addresses[addr->curraddr]);
        return ipv4_is_loopback(a);
 #endif
     }
@@ -294,7 +326,7 @@ void sk_addrcopy(SockAddr addr, char *buf)
     struct in_addr a;
 
     assert(addr->family == AF_INET);
-    a.s_addr = htonl(addr->address);
+    a.s_addr = htonl(addr->addresses[addr->curraddr]);
     memcpy(buf, (char*) &a.s_addr, 4);
 #endif
 }
@@ -303,8 +335,10 @@ void sk_addr_free(SockAddr addr)
 {
 
 #ifndef NO_IPV6
-    if (addr->ai != NULL)
-       freeaddrinfo(addr->ai);
+    if (addr->ais != NULL)
+       freeaddrinfo(addr->ais);
+#else
+    sfree(addr->addresses);
 #endif
     sfree(addr);
 }
@@ -366,11 +400,12 @@ Socket sk_register(OSSocket sockfd, Plug plug)
     ret->pending_error = 0;
     ret->oobpending = FALSE;
     ret->listener = 0;
+    ret->addr = NULL;
 
     ret->s = sockfd;
 
     if (ret->s < 0) {
-       ret->error = error_string(errno);
+       ret->error = strerror(errno);
        return (Socket) ret;
     }
 
@@ -382,8 +417,7 @@ Socket sk_register(OSSocket sockfd, Plug plug)
     return (Socket) ret;
 }
 
-Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
-             int nodelay, int keepalive, Plug plug)
+static int try_connect(Actual_Socket sock)
 {
     int s;
 #ifndef NO_IPV6
@@ -392,53 +426,38 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
     struct sockaddr_in a;
     struct sockaddr_un au;
     const struct sockaddr *sa;
-    int err;
-    Actual_Socket ret;
+    int err = 0;
     short localport;
     int fl, salen;
 
-    /*
-     * Create Socket structure.
-     */
-    ret = snew(struct Socket_tag);
-    ret->fn = &tcp_fn_table;
-    ret->error = NULL;
-    ret->plug = plug;
-    bufchain_init(&ret->output_data);
-    ret->connected = 0;                       /* to start with */
-    ret->writable = 0;                /* to start with */
-    ret->sending_oob = 0;
-    ret->frozen = 0;
-    ret->frozen_readable = 0;
-    ret->localhost_only = 0;          /* unused, but best init anyway */
-    ret->pending_error = 0;
-    ret->oobpending = FALSE;
-    ret->listener = 0;
+    if (sock->s >= 0)
+        close(sock->s);
+
+    plug_log(sock->plug, 0, sock->addr, sock->port, NULL, 0);
 
     /*
      * Open socket.
      */
-    assert(addr->family != AF_UNSPEC);
-    s = socket(addr->family, SOCK_STREAM, 0);
-    ret->s = s;
+    assert(sock->addr->family != AF_UNSPEC);
+    s = socket(sock->addr->family, SOCK_STREAM, 0);
+    sock->s = s;
 
     if (s < 0) {
-       ret->error = error_string(errno);
-       return (Socket) ret;
+       err = errno;
+       goto ret;
     }
 
-    ret->oobinline = oobinline;
-    if (oobinline) {
+    if (sock->oobinline) {
        int b = TRUE;
        setsockopt(s, SOL_SOCKET, SO_OOBINLINE, (void *) &b, sizeof(b));
     }
 
-    if (nodelay) {
+    if (sock->nodelay) {
        int b = TRUE;
        setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (void *) &b, sizeof(b));
     }
 
-    if (keepalive) {
+    if (sock->keepalive) {
        int b = TRUE;
        setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, (void *) &b, sizeof(b));
     }
@@ -446,7 +465,7 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
     /*
      * Bind to local address.
      */
-    if (privport)
+    if (sock->privport)
        localport = 1023;              /* count from 1023 downwards */
     else
        localport = 0;                 /* just use port 0 (ie kernel picks) */
@@ -459,13 +478,13 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
 
     /* We don't try to bind to a local address for UNIX domain sockets.  (Why
      * do we bother doing the bind when localport == 0 anyway?) */
-    if(addr->family != AF_UNIX) {
+    if(sock->addr->family != AF_UNIX) {
        /* Loop round trying to bind */
        while (1) {
            int retcode;
 
 #ifndef NO_IPV6
-           if (addr->family == AF_INET6) {
+           if (sock->addr->family == AF_INET6) {
                /* XXX use getaddrinfo to get a local address? */
                a6.sin6_family = AF_INET6;
                a6.sin6_addr = in6addr_any;
@@ -474,7 +493,7 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
            } else
 #endif
            {
-               assert(addr->family == AF_INET);
+               assert(sock->addr->family == AF_INET);
                a.sin_family = AF_INET;
                a.sin_addr.s_addr = htonl(INADDR_ANY);
                a.sin_port = htons(localport);
@@ -496,45 +515,43 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
              break;                   /* we might have got to the end */
        }
        
-       if (err) {
-           ret->error = error_string(err);
-           return (Socket) ret;
-       }
+       if (err)
+           goto ret;
     }
 
     /*
      * Connect to remote address.
      */
-    switch(addr->family) {
+    switch(sock->addr->family) {
 #ifndef NO_IPV6
       case AF_INET:
        /* XXX would be better to have got getaddrinfo() to fill in the port. */
-       ((struct sockaddr_in *)addr->ai->ai_addr)->sin_port =
-           htons(port);
-       sa = (const struct sockaddr *)addr->ai->ai_addr;
-       salen = addr->ai->ai_addrlen;
+       ((struct sockaddr_in *)sock->addr->ai->ai_addr)->sin_port =
+           htons(sock->port);
+       sa = (const struct sockaddr *)sock->addr->ai->ai_addr;
+       salen = sock->addr->ai->ai_addrlen;
        break;
       case AF_INET6:
-       ((struct sockaddr_in *)addr->ai->ai_addr)->sin_port =
-           htons(port);
-       sa = (const struct sockaddr *)addr->ai->ai_addr;
-       salen = addr->ai->ai_addrlen;
+       ((struct sockaddr_in *)sock->addr->ai->ai_addr)->sin_port =
+           htons(sock->port);
+       sa = (const struct sockaddr *)sock->addr->ai->ai_addr;
+       salen = sock->addr->ai->ai_addrlen;
        break;
 #else
       case AF_INET:
        a.sin_family = AF_INET;
-       a.sin_addr.s_addr = htonl(addr->address);
-       a.sin_port = htons((short) port);
+       a.sin_addr.s_addr = htonl(sock->addr->addresses[sock->addr->curraddr]);
+       a.sin_port = htons((short) sock->port);
        sa = (const struct sockaddr *)&a;
        salen = sizeof a;
        break;
 #endif
       case AF_UNIX:
-       assert(port == 0);      /* to catch confused people */
-       assert(strlen(addr->hostname) < sizeof au.sun_path);
+       assert(sock->port == 0);       /* to catch confused people */
+       assert(strlen(sock->addr->hostname) < sizeof au.sun_path);
        memset(&au, 0, sizeof au);
        au.sun_family = AF_UNIX;
-       strcpy(au.sun_path, addr->hostname);
+       strcpy(au.sun_path, sock->addr->hostname);
        sa = (const struct sockaddr *)&au;
        salen = sizeof au;
        break;
@@ -549,22 +566,65 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
 
     if ((connect(s, sa, salen)) < 0) {
        if ( errno != EINPROGRESS ) {
-           ret->error = error_string(errno);
-           return (Socket) ret;
+           err = errno;
+           goto ret;
        }
     } else {
        /*
         * If we _don't_ get EWOULDBLOCK, the connect has completed
         * and we should set the socket as connected and writable.
         */
-       ret->connected = 1;
-       ret->writable = 1;
+       sock->connected = 1;
+       sock->writable = 1;
     }
 
-    uxsel_tell(ret);
-    add234(sktree, ret);
+    uxsel_tell(sock);
+    add234(sktree, sock);
+
+    ret:
+    if (err)
+       plug_log(sock->plug, 1, sock->addr, sock->port, strerror(err), err);
+    return err;
+}
+
+Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
+             int nodelay, int keepalive, Plug plug)
+{
+    Actual_Socket ret;
+    int err;
+
+    /*
+     * Create Socket structure.
+     */
+    ret = snew(struct Socket_tag);
+    ret->fn = &tcp_fn_table;
+    ret->error = NULL;
+    ret->plug = plug;
+    bufchain_init(&ret->output_data);
+    ret->connected = 0;                       /* to start with */
+    ret->writable = 0;                /* to start with */
+    ret->sending_oob = 0;
+    ret->frozen = 0;
+    ret->frozen_readable = 0;
+    ret->localhost_only = 0;          /* unused, but best init anyway */
+    ret->pending_error = 0;
+    ret->oobpending = FALSE;
+    ret->listener = 0;
+    ret->addr = addr;
+    ret->s = -1;
+    ret->oobinline = oobinline;
+    ret->nodelay = nodelay;
+    ret->keepalive = keepalive;
+    ret->privport = privport;
+    ret->port = port;
+
+    err = 0;
+    do {
+        err = try_connect(ret);
+    } while (err && sk_nextaddr(ret->addr));
 
-    sk_addr_free(addr);
+    if (err)
+        ret->error = strerror(err);
 
     return (Socket) ret;
 }
@@ -600,6 +660,7 @@ Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only, i
     ret->pending_error = 0;
     ret->oobpending = FALSE;
     ret->listener = 1;
+    ret->addr = NULL;
 
     /*
      * Translate address_family from platform-independent constants
@@ -629,7 +690,7 @@ Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only, i
     }
 
     if (s < 0) {
-       ret->error = error_string(errno);
+       ret->error = strerror(errno);
        return (Socket) ret;
     }
 
@@ -700,14 +761,13 @@ Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only, i
     retcode = bind(s, addr, addrlen);
     if (retcode < 0) {
         close(s);
-       ret->error = error_string(errno);
+       ret->error = strerror(errno);
        return (Socket) ret;
     }
 
-
     if (listen(s, SOMAXCONN) < 0) {
         close(s);
-       ret->error = error_string(errno);
+       ret->error = strerror(errno);
        return (Socket) ret;
     }
 
@@ -726,6 +786,8 @@ static void sk_tcp_close(Socket sock)
     uxsel_del(s->s);
     del234(sktree, s);
     close(s->s);
+    if (s->addr)
+        sk_addr_free(s->addr);
     sfree(s);
 }
 
@@ -814,8 +876,8 @@ void try_send(Actual_Socket s)
            } else {
                /* We're inside the Unix frontend here, so we know
                 * that the frontend handle is unnecessary. */
-               logevent(NULL, error_string(err));
-               fatalbox("%s", error_string(err));
+               logevent(NULL, strerror(err));
+               fatalbox("%s", strerror(err));
            }
        } else {
            if (s->sending_oob) {
@@ -912,12 +974,21 @@ static int net_select_result(int fd, int event)
            noise_ultralight(ret);
            if (ret <= 0) {
                const char *str = (ret == 0 ? "Internal networking trouble" :
-                                  error_string(errno));
+                                  strerror(errno));
                /* We're inside the Unix frontend here, so we know
                 * that the frontend handle is unnecessary. */
                logevent(NULL, str);
                fatalbox("%s", str);
            } else {
+                /*
+                 * Receiving actual data on a socket means we can
+                 * stop falling back through the candidate
+                 * addresses to connect to.
+                 */
+                if (s->addr) {
+                    sk_addr_free(s->addr);
+                    s->addr = NULL;
+                }
                return plug_receive(s->plug, 2, buf, ret);
            }
            break;
@@ -988,10 +1059,33 @@ static int net_select_result(int fd, int event)
            }
        }
        if (ret < 0) {
-           return plug_closing(s->plug, error_string(errno), errno, 0);
+            /*
+             * An error at this point _might_ be an error reported
+             * by a non-blocking connect(). So before we return a
+             * panic status to the user, let's just see whether
+             * that's the case.
+             */
+            int err = errno;
+           if (s->addr) {
+               plug_log(s->plug, 1, s->addr, s->port, strerror(err), err);
+               while (s->addr && sk_nextaddr(s->addr)) {
+                   err = try_connect(s);
+               }
+           }
+            if (err != 0)
+                return plug_closing(s->plug, strerror(err), err, 0);
        } else if (0 == ret) {
            return plug_closing(s->plug, NULL, 0, 0);
        } else {
+            /*
+             * Receiving actual data on a socket means we can
+             * stop falling back through the candidate
+             * addresses to connect to.
+             */
+            if (s->addr) {
+                sk_addr_free(s->addr);
+                s->addr = NULL;
+            }
            return plug_receive(s->plug, atmark ? 0 : 1, buf, ret);
        }
        break;
@@ -1048,7 +1142,7 @@ void net_pending_errors(void)
                 * An error has occurred on this socket. Pass it to the
                 * plug.
                 */
-               plug_closing(s->plug, error_string(s->pending_error),
+               plug_closing(s->plug, strerror(s->pending_error),
                             s->pending_error, 0);
                break;
            }
@@ -1140,5 +1234,10 @@ SockAddr platform_get_x11_unix_address(int displaynum, char **canonicalname)
        ret->error = "X11 UNIX name too long";
     else
        *canonicalname = dupstr(ret->hostname);
+#ifndef NO_IPV6
+    ret->ais = NULL;
+#else
+    ret->addresses = NULL;
+#endif
     return ret;
 }