Sebastian Kuschel reports that pfd_closing can be called for a socket
[u/mdw/putty] / x11fwd.c
index 50300ad..895d8a9 100644 (file)
--- a/x11fwd.c
+++ b/x11fwd.c
@@ -50,8 +50,25 @@ static int xdmseen_cmp(void *a, void *b)
            memcmp(sa->clientid, sb->clientid, sizeof(sa->clientid));
 }
 
-struct X11Display *x11_setup_display(char *display, int authtype,
-                                    const Config *cfg)
+/* Do-nothing "plug" implementation, used by x11_setup_display() when it
+ * creates a trial connection (and then immediately closes it).
+ * XXX: bit out of place here, could in principle live in a platform-
+ *      independent network.c or something */
+static void dummy_plug_log(Plug p, int type, SockAddr addr, int port,
+                          const char *error_msg, int error_code) { }
+static int dummy_plug_closing
+     (Plug p, const char *error_msg, int error_code, int calling_back)
+{ return 1; }
+static int dummy_plug_receive(Plug p, int urgent, char *data, int len)
+{ return 1; }
+static void dummy_plug_sent(Plug p, int bufsize) { }
+static int dummy_plug_accepting(Plug p, OSSocket sock) { return 1; }
+static const struct plug_function_table dummy_plug = {
+    dummy_plug_log, dummy_plug_closing, dummy_plug_receive,
+    dummy_plug_sent, dummy_plug_accepting
+};
+
+struct X11Display *x11_setup_display(char *display, int authtype, Conf *conf)
 {
     struct X11Display *disp = snew(struct X11Display);
     char *localcopy;
@@ -87,6 +104,7 @@ struct X11Display *x11_setup_display(char *display, int authtype,
        disp->hostname = NULL;
        disp->displaynum = -1;
        disp->screennum = 0;
+       disp->addr = NULL;
     } else {
        char *colon, *dot, *slash;
        char *protocol, *hostname;
@@ -125,13 +143,16 @@ struct X11Display *x11_setup_display(char *display, int authtype,
        if (protocol)
            disp->unixdomain = (!strcmp(protocol, "local") ||
                                !strcmp(protocol, "unix"));
-       else
+       else if (!*hostname || !strcmp(hostname, "unix"))
            disp->unixdomain = platform_uses_x11_unix_by_default;
+       else
+           disp->unixdomain = FALSE;
 
        if (!disp->hostname && !disp->unixdomain)
            disp->hostname = dupstr("localhost");
 
        disp->unixsocketpath = NULL;
+       disp->addr = NULL;
 
        sfree(localcopy);
     }
@@ -139,30 +160,59 @@ struct X11Display *x11_setup_display(char *display, int authtype,
     /*
      * Look up the display hostname, if we need to.
      */
-    if (disp->unixdomain) {
-       disp->addr = platform_get_x11_unix_address(disp->unixsocketpath,
-                                                  disp->displaynum);
-       if (disp->unixsocketpath)
-           disp->realhost = dupstr(disp->unixsocketpath);
-       else
-           disp->realhost = dupprintf("unix:%d", disp->displaynum);
-       disp->port = 0;
-    } else {
+    if (!disp->unixdomain) {
        const char *err;
 
        disp->port = 6000 + disp->displaynum;
        disp->addr = name_lookup(disp->hostname, disp->port,
-                                &disp->realhost, cfg, ADDRTYPE_UNSPEC);
+                                &disp->realhost, conf, ADDRTYPE_UNSPEC);
     
        if ((err = sk_addr_error(disp->addr)) != NULL) {
            sk_addr_free(disp->addr);
            sfree(disp->hostname);
            sfree(disp->unixsocketpath);
+           sfree(disp);
            return NULL;               /* FIXME: report an error */
        }
     }
 
     /*
+     * Try upgrading an IP-style localhost display to a Unix-socket
+     * display (as the standard X connection libraries do).
+     */
+    if (!disp->unixdomain && sk_address_is_local(disp->addr)) {
+       SockAddr ux = platform_get_x11_unix_address(NULL, disp->displaynum);
+       const char *err = sk_addr_error(ux);
+       if (!err) {
+           /* Create trial connection to see if there is a useful Unix-domain
+            * socket */
+           const struct plug_function_table *dummy = &dummy_plug;
+           Socket s = sk_new(sk_addr_dup(ux), 0, 0, 0, 0, 0, (Plug)&dummy);
+           err = sk_socket_error(s);
+           sk_close(s);
+       }
+       if (err) {
+           sk_addr_free(ux);
+       } else {
+           sk_addr_free(disp->addr);
+           disp->unixdomain = TRUE;
+           disp->addr = ux;
+           /* Fill in the rest in a moment */
+       }
+    }
+
+    if (disp->unixdomain) {
+       if (!disp->addr)
+           disp->addr = platform_get_x11_unix_address(disp->unixsocketpath,
+                                                      disp->displaynum);
+       if (disp->unixsocketpath)
+           disp->realhost = dupstr(disp->unixsocketpath);
+       else
+           disp->realhost = dupprintf("unix:%d", disp->displaynum);
+       disp->port = 0;
+    }
+
+    /*
      * Invent the remote authorisation details.
      */
     if (authtype == X11_MIT) {
@@ -199,7 +249,7 @@ struct X11Display *x11_setup_display(char *display, int authtype,
     disp->localauthproto = X11_NO_AUTH;
     disp->localauthdata = NULL;
     disp->localauthdatalen = 0;
-    platform_get_x11_auth(disp, cfg);
+    platform_get_x11_auth(disp, conf);
 
     return disp;
 }
@@ -215,10 +265,10 @@ void x11_free_display(struct X11Display *disp)
     sfree(disp->hostname);
     sfree(disp->unixsocketpath);
     if (disp->localauthdata)
-       memset(disp->localauthdata, 0, disp->localauthdatalen);
+       smemclr(disp->localauthdata, disp->localauthdatalen);
     sfree(disp->localauthdata);
     if (disp->remoteauthdata)
-       memset(disp->remoteauthdata, 0, disp->remoteauthdatalen);
+       smemclr(disp->remoteauthdata, disp->remoteauthdatalen);
     sfree(disp->remoteauthdata);
     sfree(disp->remoteauthprotoname);
     sfree(disp->remoteauthdatastring);
@@ -293,16 +343,43 @@ void x11_get_auth_from_authfile(struct X11Display *disp,
     char *buf, *ptr, *str[4];
     int len[4];
     int family, protocol;
+    int ideal_match = FALSE;
+    char *ourhostname;
+
+    /*
+     * Normally we should look for precisely the details specified in
+     * `disp'. However, there's an oddity when the display is local:
+     * displays like "localhost:0" usually have their details stored
+     * in a Unix-domain-socket record (even if there isn't actually a
+     * real Unix-domain socket available, as with OpenSSH's proxy X11
+     * server).
+     *
+     * This is apparently a fudge to get round the meaninglessness of
+     * "localhost" in a shared-home-directory context -- xauth entries
+     * for Unix-domain sockets already disambiguate this by storing
+     * the *local* hostname in the conveniently-blank hostname field,
+     * but IP "localhost" records couldn't do this. So, typically, an
+     * IP "localhost" entry in the auth database isn't present and if
+     * it were it would be ignored.
+     *
+     * However, we don't entirely trust that (say) Windows X servers
+     * won't rely on a straight "localhost" entry, bad idea though
+     * that is; so if we can't find a Unix-domain-socket entry we'll
+     * fall back to an IP-based entry if we can find one.
+     */
+    int localhost = !disp->unixdomain && sk_address_is_local(disp->addr);
 
     authfp = fopen(authfilename, "rb");
     if (!authfp)
        return;
 
+    ourhostname = get_hostname();
+
     /* Records in .Xauthority contain four strings of up to 64K each */
     buf = snewn(65537 * 4, char);
 
-    while (1) {
-       int c, i, j;
+    while (!ideal_match) {
+       int c, i, j, match = FALSE;
        
 #define GET do { c = fgetc(authfp); if (c == EOF) goto done; c = (unsigned char)c; } while (0)
        /* Expect a big-endian 2-byte number giving address family */
@@ -368,41 +445,54 @@ void x11_get_auth_from_authfile(struct X11Display *disp,
            continue;  /* don't recognise this protocol, look for another */
 
        switch (family) {
-         case 0:
+         case 0:   /* IPv4 */
            if (!disp->unixdomain &&
                sk_addrtype(disp->addr) == ADDRTYPE_IPV4) {
                char buf[4];
                sk_addrcopy(disp->addr, buf);
-               if (len[0] == 4 && !memcmp(str[0], buf, 4))
-                   goto found;
+               if (len[0] == 4 && !memcmp(str[0], buf, 4)) {
+                   match = TRUE;
+                   /* If this is a "localhost" entry, note it down
+                    * but carry on looking for a Unix-domain entry. */
+                   ideal_match = !localhost;
+               }
            }
            break;
-         case 6:
+         case 6:   /* IPv6 */
            if (!disp->unixdomain &&
                sk_addrtype(disp->addr) == ADDRTYPE_IPV6) {
                char buf[16];
                sk_addrcopy(disp->addr, buf);
-               if (len[0] == 16 && !memcmp(str[0], buf, 16))
-                   goto found;
+               if (len[0] == 16 && !memcmp(str[0], buf, 16)) {
+                   match = TRUE;
+                   ideal_match = !localhost;
+               }
            }
            break;
-         case 256:
-           if (disp->unixdomain && !strcmp(disp->hostname, str[0]))
-               goto found;
+         case 256: /* Unix-domain / localhost */
+           if ((disp->unixdomain || localhost)
+               && ourhostname && !strcmp(ourhostname, str[0]))
+               /* A matching Unix-domain socket is always the best
+                * match. */
+               match = ideal_match = TRUE;
            break;
        }
-    }
 
-    found:
-    disp->localauthproto = protocol;
-    disp->localauthdata = snewn(len[3], unsigned char);
-    memcpy(disp->localauthdata, str[3], len[3]);
-    disp->localauthdatalen = len[3];
+       if (match) {
+           /* Current best guess -- may be overridden if !ideal_match */
+           disp->localauthproto = protocol;
+           sfree(disp->localauthdata); /* free previous guess, if any */
+           disp->localauthdata = snewn(len[3], unsigned char);
+           memcpy(disp->localauthdata, str[3], len[3]);
+           disp->localauthdatalen = len[3];
+       }
+    }
 
     done:
     fclose(authfp);
-    memset(buf, 0, 65537 * 4);
+    smemclr(buf, 65537 * 4);
     sfree(buf);
+    sfree(ourhostname);
 }
 
 static void x11_log(Plug p, int type, SockAddr addr, int port,
@@ -416,13 +506,20 @@ static int x11_closing(Plug plug, const char *error_msg, int error_code,
 {
     struct X11Private *pr = (struct X11Private *) plug;
 
-    /*
-     * We have no way to communicate down the forwarded connection,
-     * so if an error occurred on the socket, we just ignore it
-     * and treat it like a proper close.
-     */
-    sshfwd_close(pr->c);
-    x11_close(pr->s);
+    if (error_msg) {
+        /*
+         * Socket error. Slam the connection instantly shut.
+         */
+        sshfwd_unclean_close(pr->c);
+    } else {
+        /*
+         * Ordinary EOF received on socket. Send an EOF on the SSH
+         * channel.
+         */
+        if (pr->c)
+            sshfwd_write_eof(pr->c);
+    }
+
     return 1;
 }
 
@@ -470,8 +567,7 @@ int x11_get_screen_number(char *display)
  * also, fills the SocketsStructure
  */
 extern const char *x11_init(Socket *s, struct X11Display *disp, void *c,
-                           const char *peeraddr, int peerport,
-                           const Config *cfg)
+                           const char *peeraddr, int peerport, Conf *conf)
 {
     static const struct plug_function_table fn_table = {
        x11_log,
@@ -498,7 +594,7 @@ extern const char *x11_init(Socket *s, struct X11Display *disp, void *c,
 
     pr->s = *s = new_connection(sk_addr_dup(disp->addr),
                                disp->realhost, disp->port,
-                               0, 1, 0, 0, (Plug) pr, cfg);
+                               0, 1, 0, 0, (Plug) pr, conf);
     if ((err = sk_socket_error(*s)) != NULL) {
        sfree(pr);
        return err;
@@ -624,7 +720,7 @@ int x11_send(Socket s, char *data, int len)
            int msglen, msgsize;
            unsigned char *reply;
 
-           message = dupprintf("PuTTY X11 proxy: %s", err);
+           message = dupprintf("%s X11 proxy: %s", appname, err);
            msglen = strlen(message);
            reply = snewn(8 + msglen+1 + 4, unsigned char); /* include zero */
            msgsize = (msglen + 3) & ~3;
@@ -635,8 +731,7 @@ int x11_send(Socket s, char *data, int len)
            memset(reply + 8, 0, msgsize);
            memcpy(reply + 8, message, msglen);
            sshfwd_write(pr->c, (char *)reply, 8 + msgsize);
-           sshfwd_close(pr->c);
-           x11_close(s);
+           sshfwd_write_eof(pr->c);
            sfree(reply);
            sfree(message);
            return 0;
@@ -701,3 +796,8 @@ int x11_send(Socket s, char *data, int len)
 
     return sk_write(s, data, len);
 }
+
+void x11_send_eof(Socket s)
+{
+    sk_write_eof(s);
+}