Rewrite of Unix sk_newlistener() which should fix any possible
[u/mdw/putty] / unix / uxnet.c
index 1e7d51b..fd86af0 100644 (file)
@@ -69,7 +69,7 @@ struct SockAddr_tag {
      * in this SockAddr structure.
      */
     int family;
-#ifdef IPV6
+#ifndef NO_IPV6
     struct addrinfo *ai;              /* Address IPv6 style. */
 #else
     unsigned long address;            /* Address IPv4 style. */
@@ -125,10 +125,10 @@ const char *error_string(int error)
     return strerror(error);
 }
 
-SockAddr sk_namelookup(const char *host, char **canonicalname)
+SockAddr sk_namelookup(const char *host, char **canonicalname, int address_family)
 {
     SockAddr ret = snew(struct SockAddr_tag);
-#ifdef IPV6
+#ifndef NO_IPV6
     struct addrinfo hints;
     int err;
 #else
@@ -143,16 +143,18 @@ SockAddr sk_namelookup(const char *host, char **canonicalname)
     *realhost = '\0';
     ret->error = NULL;
 
-#ifdef IPV6
+#ifndef NO_IPV6
     hints.ai_flags = AI_CANONNAME;
-    hints.ai_family = AF_UNSPEC;
-    hints.ai_socktype = 0;
+    hints.ai_family = (address_family == ADDRTYPE_IPV4 ? AF_INET :
+                      address_family == ADDRTYPE_IPV6 ? AF_INET6 :
+                      AF_UNSPEC);
+    hints.ai_socktype = SOCK_STREAM;
     hints.ai_protocol = 0;
     hints.ai_addrlen = 0;
     hints.ai_addr = NULL;
     hints.ai_canonname = NULL;
     hints.ai_next = NULL;
-    err = getaddrinfo(host, NULL, NULL, &ret->ai);
+    err = getaddrinfo(host, NULL, &hints, &ret->ai);
     if (err != 0) {
        ret->error = gai_strerror(err);
        return ret;
@@ -164,7 +166,7 @@ SockAddr sk_namelookup(const char *host, char **canonicalname)
     else
        strncat(realhost, host, sizeof(realhost) - 1);
 #else
-    if ((a = inet_addr(host)) == (unsigned long) INADDR_NONE) {
+    if ((a = inet_addr(host)) == (unsigned long)(in_addr_t)(-1)) {
        /*
         * Otherwise use the IPv4-only gethostbyname... (NOTE:
         * we don't use gethostbyname as a fallback!)
@@ -209,6 +211,9 @@ SockAddr sk_nonamelookup(const char *host)
     ret->family = AF_UNSPEC;
     strncpy(ret->hostname, host, lenof(ret->hostname));
     ret->hostname[lenof(ret->hostname)-1] = '\0';
+#ifndef NO_IPV6
+    ret->ai = NULL;
+#endif
     return ret;
 }
 
@@ -219,7 +224,7 @@ void sk_getaddr(SockAddr addr, char *buf, int buflen)
        strncpy(buf, addr->hostname, buflen);
        buf[buflen-1] = '\0';
     } else {
-#ifdef IPV6
+#ifndef NO_IPV6
        if (getnameinfo(addr->ai->ai_addr, addr->ai->ai_addrlen, buf, buflen,
                        NULL, 0, NI_NUMERICHOST) != 0) {
            buf[0] = '\0';
@@ -246,7 +251,7 @@ int sk_address_is_local(SockAddr addr)
     if (addr->family == AF_UNSPEC)
        return 0;                      /* we don't know; assume not */
     else {
-#ifdef IPV6
+#ifndef NO_IPV6
        if (addr->family == AF_INET)
            return ipv4_is_loopback(
                ((struct sockaddr_in *)addr->ai->ai_addr)->sin_addr);
@@ -267,7 +272,7 @@ int sk_address_is_local(SockAddr addr)
 int sk_addrtype(SockAddr addr)
 {
     return (addr->family == AF_INET ? ADDRTYPE_IPV4 :
-#ifdef IPV6
+#ifndef NO_IPV6
            addr->family == AF_INET6 ? ADDRTYPE_IPV6 :
 #endif
            ADDRTYPE_NAME);
@@ -276,7 +281,7 @@ int sk_addrtype(SockAddr addr)
 void sk_addrcopy(SockAddr addr, char *buf)
 {
 
-#ifdef IPV6
+#ifndef NO_IPV6
     if (addr->family == AF_INET)
        memcpy(buf, &((struct sockaddr_in *)addr->ai->ai_addr)->sin_addr,
               sizeof(struct in_addr));
@@ -297,7 +302,7 @@ void sk_addrcopy(SockAddr addr, char *buf)
 void sk_addr_free(SockAddr addr)
 {
 
-#ifdef IPV6
+#ifndef NO_IPV6
     if (addr->ai != NULL)
        freeaddrinfo(addr->ai);
 #endif
@@ -378,10 +383,10 @@ Socket sk_register(OSSocket sockfd, Plug plug)
 }
 
 Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
-             int nodelay, Plug plug)
+             int nodelay, int keepalive, Plug plug)
 {
     int s;
-#ifdef IPV6
+#ifndef NO_IPV6
     struct sockaddr_in6 a6;
 #endif
     struct sockaddr_in a;
@@ -433,6 +438,11 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
        setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (void *) &b, sizeof(b));
     }
 
+    if (keepalive) {
+       int b = TRUE;
+       setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, (void *) &b, sizeof(b));
+    }
+
     /*
      * Bind to local address.
      */
@@ -443,7 +453,7 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
 
     /* BSD IP stacks need sockaddr_in zeroed before filling in */
     memset(&a,'\0',sizeof(struct sockaddr_in));
-#ifdef IPV6
+#ifndef NO_IPV6
     memset(&a6,'\0',sizeof(struct sockaddr_in6));
 #endif
 
@@ -454,7 +464,7 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
        while (1) {
            int retcode;
 
-#ifdef IPV6
+#ifndef NO_IPV6
            if (addr->family == AF_INET6) {
                /* XXX use getaddrinfo to get a local address? */
                a6.sin6_family = AF_INET6;
@@ -496,7 +506,7 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
      * Connect to remote address.
      */
     switch(addr->family) {
-#ifdef IPV6
+#ifndef NO_IPV6
       case AF_INET:
        /* XXX would be better to have got getaddrinfo() to fill in the port. */
        ((struct sockaddr_in *)addr->ai->ai_addr)->sin_port =
@@ -559,18 +569,17 @@ Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
     return (Socket) ret;
 }
 
-Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only)
+Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only, int address_family)
 {
     int s;
-#ifdef IPV6
-#if 0
-    struct sockaddr_in6 a6;
-#endif
+#ifndef NO_IPV6
     struct addrinfo hints, *ai;
     char portstr[6];
+    struct sockaddr_in6 a6;
 #endif
+    struct sockaddr *addr;
+    int addrlen;
     struct sockaddr_in a;
-    int err;
     Actual_Socket ret;
     int retcode;
     int on = 1;
@@ -593,10 +602,31 @@ Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only)
     ret->listener = 1;
 
     /*
+     * Translate address_family from platform-independent constants
+     * into local reality.
+     */
+    address_family = (address_family == ADDRTYPE_IPV4 ? AF_INET :
+                     address_family == ADDRTYPE_IPV6 ? AF_INET6 : AF_UNSPEC);
+
+#ifndef NO_IPV6
+    /* Let's default to IPv6.
+     * If the stack doesn't support IPv6, we will fall back to IPv4. */
+    if (address_family == AF_UNSPEC) address_family = AF_INET6;
+#else
+    /* No other choice, default to IPv4 */
+    if (address_family == AF_UNSPEC)  address_family = AF_INET;
+#endif
+
+    /*
      * Open socket.
      */
-    s = socket(AF_INET, SOCK_STREAM, 0);
-    ret->s = s;
+    s = socket(address_family, SOCK_STREAM, 0);
+
+    /* If the host doesn't support IPv6 try fallback to IPv4. */
+    if (s < 0 && address_family == AF_INET6) {
+       address_family = AF_INET;
+       s = socket(address_family, SOCK_STREAM, 0);
+    }
 
     if (s < 0) {
        ret->error = error_string(errno);
@@ -607,77 +637,70 @@ Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only)
 
     setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char *)&on, sizeof(on));
 
-    /* BSD IP stacks need sockaddr_in zeroed before filling in */
-    memset(&a,'\0',sizeof(struct sockaddr_in));
-#ifdef IPV6
-#if 0
-    memset(&a6,'\0',sizeof(struct sockaddr_in6));
-#endif
-    hints.ai_flags = AI_NUMERICHOST;
-    hints.ai_family = AF_UNSPEC;
-    hints.ai_socktype = 0;
-    hints.ai_protocol = 0;
-    hints.ai_addrlen = 0;
-    hints.ai_addr = NULL;
-    hints.ai_canonname = NULL;
-    hints.ai_next = NULL;
-    sprintf(portstr, "%d", port);
-    if (srcaddr != NULL && getaddrinfo(srcaddr, portstr, &hints, &ai) == 0)
-       retcode = bind(s, ai->ai_addr, ai->ai_addrlen);
-    else
-#if 0
-    {
-       /*
-        * FIXME: Need two listening sockets, in principle, one for v4
-        * and one for v6
-        */
-       if (local_host_only)
-           a6.sin6_addr = in6addr_loopback;
-       else
-           a6.sin6_addr = in6addr_any;
-       a6.sin6_port = htons(port);
-    } else
-#endif
+    retcode = -1;
+    addr = NULL; addrlen = -1;         /* placate optimiser */
+
+    if (srcaddr != NULL) {
+#ifndef NO_IPV6
+        hints.ai_flags = AI_NUMERICHOST;
+        hints.ai_family = address_family;
+        hints.ai_socktype = 0;
+        hints.ai_protocol = 0;
+        hints.ai_addrlen = 0;
+        hints.ai_addr = NULL;
+        hints.ai_canonname = NULL;
+        hints.ai_next = NULL;
+        sprintf(portstr, "%d", port);
+        retcode = getaddrinfo(srcaddr, portstr, &hints, &ai);
+        addr = ai->ai_addr;
+        addrlen = ai->ai_addrlen;
+#else
+        memset(&a,'\0',sizeof(struct sockaddr_in));
+        a.sin_family = AF_INET;
+        a.sin_port = htons(port);
+        a.sin_addr.s_addr = inet_addr(srcaddr);
+        if (a.sin_addr.s_addr != (in_addr_t)(-1)) {
+            /* Override localhost_only with specified listen addr. */
+            ret->localhost_only = ipv4_is_loopback(a.sin_addr);
+            got_addr = 1;
+        }
+        addr = (struct sockaddr *)a;
+        addrlen = sizeof(a);
+        retcode = 0;
 #endif
-    {
-       int got_addr = 0;
-       a.sin_family = AF_INET;
-
-       /*
-        * Bind to source address. First try an explicitly
-        * specified one...
-        */
-       if (srcaddr) {
-           a.sin_addr.s_addr = inet_addr(srcaddr);
-           if (a.sin_addr.s_addr != INADDR_NONE) {
-               /* Override localhost_only with specified listen addr. */
-               ret->localhost_only = ipv4_is_loopback(a.sin_addr);
-               got_addr = 1;
-           }
-       }
+    }
 
-       /*
-        * ... and failing that, go with one of the standard ones.
-        */
-       if (!got_addr) {
+    if (retcode != 0) {
+#ifndef NO_IPV6
+        if (address_family == AF_INET6) {
+            memset(&a6,'\0',sizeof(struct sockaddr_in6));
+            a6.sin6_family = AF_INET6;
+            a6.sin6_port = htons(port);
+            if (local_host_only)
+                a6.sin6_addr = in6addr_loopback;
+            else
+                a6.sin6_addr = in6addr_any;
+            addr = (struct sockaddr *)&a6;
+            addrlen = sizeof(a6);
+        } else
+#endif
+        {
+            memset(&a,'\0',sizeof(struct sockaddr_in));
+            a.sin_family = AF_INET;
+            a.sin_port = htons(port);
            if (local_host_only)
                a.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
            else
                a.sin_addr.s_addr = htonl(INADDR_ANY);
-       }
-
-       a.sin_port = htons((short)port);
-       retcode = bind(s, (struct sockaddr *) &a, sizeof(a));
-    }
-
-    if (retcode >= 0) {
-       err = 0;
-    } else {
-       err = errno;
+            addr = (struct sockaddr *)&a;
+            addrlen = sizeof(a);
+        }
     }
 
-    if (err) {
-       ret->error = error_string(err);
+    retcode = bind(s, addr, addrlen);
+    if (retcode < 0) {
+        close(s);
+       ret->error = error_string(errno);
        return (Socket) ret;
     }
 
@@ -688,6 +711,8 @@ Socket sk_newlistener(char *srcaddr, int port, Plug plug, int local_host_only)
        return (Socket) ret;
     }
 
+    ret->s = s;
+
     uxsel_tell(ret);
     add234(sktree, ret);