Pageant interface changes. You can now do `pageant -c command' to
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 549a1e5..08ff8d9 100644 (file)
--- a/ssh.c
+++ b/ssh.c
 #define SSH1_AUTH_TIS                             5    /* 0x5 */
 #define SSH1_AUTH_CCARD                           16   /* 0x10 */
 
+#define SSH1_PROTOFLAG_SCREEN_NUMBER              1    /* 0x1 */
+/* Mask for protoflags we will echo back to server if seen */
+#define SSH1_PROTOFLAGS_SUPPORTED                 0    /* 0x1 */
+
 #define SSH2_MSG_DISCONNECT                       1    /* 0x1 */
 #define SSH2_MSG_IGNORE                           2    /* 0x2 */
 #define SSH2_MSG_UNIMPLEMENTED                    3    /* 0x3 */
 #define SSH2_DISCONNECT_HOST_KEY_NOT_VERIFIABLE   9    /* 0x9 */
 #define SSH2_DISCONNECT_CONNECTION_LOST           10   /* 0xa */
 #define SSH2_DISCONNECT_BY_APPLICATION            11   /* 0xb */
+#define SSH2_DISCONNECT_TOO_MANY_CONNECTIONS      12   /* 0xc */
+#define SSH2_DISCONNECT_AUTH_CANCELLED_BY_USER    13   /* 0xd */
+#define SSH2_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE 14 /* 0xe */
+#define SSH2_DISCONNECT_ILLEGAL_USER_NAME         15   /* 0xf */
+
+static const char *const ssh2_disconnect_reasons[] = {
+    NULL,
+    "SSH_DISCONNECT_HOST_NOT_ALLOWED_TO_CONNECT",
+    "SSH_DISCONNECT_PROTOCOL_ERROR",
+    "SSH_DISCONNECT_KEY_EXCHANGE_FAILED",
+    "SSH_DISCONNECT_HOST_AUTHENTICATION_FAILED",
+    "SSH_DISCONNECT_MAC_ERROR",
+    "SSH_DISCONNECT_COMPRESSION_ERROR",
+    "SSH_DISCONNECT_SERVICE_NOT_AVAILABLE",
+    "SSH_DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED",
+    "SSH_DISCONNECT_HOST_KEY_NOT_VERIFIABLE",
+    "SSH_DISCONNECT_CONNECTION_LOST",
+    "SSH_DISCONNECT_BY_APPLICATION",
+    "SSH_DISCONNECT_TOO_MANY_CONNECTIONS",
+    "SSH_DISCONNECT_AUTH_CANCELLED_BY_USER",
+    "SSH_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE",
+    "SSH_DISCONNECT_ILLEGAL_USER_NAME",
+};
 
 #define SSH2_OPEN_ADMINISTRATIVELY_PROHIBITED     1    /* 0x1 */
 #define SSH2_OPEN_CONNECT_FAILED                  2    /* 0x2 */
@@ -256,6 +283,8 @@ static Socket s = NULL;
 
 static unsigned char session_key[32];
 static int ssh1_compressing;
+static int ssh1_remote_protoflags;     /* protocol flags read from the server in do_ssh1_login */
+static int ssh1_local_protoflags;      /* protocol flags we send back to the server */
 static int ssh_agentfwd_enabled;
 static int ssh_X11_fwd_enabled;
 static int ssh_remote_bugs;
@@ -347,6 +376,16 @@ static void c_write (char *buf, int len) {
     from_backend(1, buf, len);
 }
 
+static void c_write_untrusted(char *buf, int len) { /* echo server-supplied text, dropping control chars */
+    int i;
+    for (i = 0; i < len; i++) {
+        if (buf[i] == '\n')                /* LF is written as CRLF */
+            c_write("\r\n", 2);
+        else if ((buf[i] & 0x60) || (buf[i] == '\r'))  /* drops bytes 0x00-0x1F and 0x80-0x9F, except CR */
+            c_write(buf+i, 1);
+    }
+}
+
 static void c_write_str (char *buf) {
     c_write(buf, strlen(buf));
 }
@@ -432,11 +471,13 @@ next_packet:
       unsigned char *decompblk;
       int decomplen;
 #if 0
-       int i;
-       debug(("Packet payload pre-decompression:\n"));
-       for (i = -1; i < pktin.length; i++)
-           debug(("  %02x", (unsigned char)pktin.body[i]));
-       debug(("\r\n"));
+        {                               /* new block so i can be declared here under C89 */
+            int i;
+            debug(("Packet payload pre-decompression:\n"));
+            for (i = -1; i < pktin.length; i++)  /* from body[-1], as the decompression call does */
+                debug(("  %02x", (unsigned char)pktin.body[i]));
+            debug(("\r\n"));
+        }
 #endif
       zlib_decompress_block(pktin.body-1, pktin.length+1,
                             &decompblk, &decomplen);
@@ -453,10 +494,13 @@ next_packet:
       sfree(decompblk);
       pktin.length = decomplen-1;
 #if 0
-       debug(("Packet payload post-decompression:\n"));
-       for (i = -1; i < pktin.length; i++)
-           debug(("  %02x", (unsigned char)pktin.body[i]));
-       debug(("\r\n"));
+        {                               /* new block so i can be declared here under C89 */
+            int i;
+            debug(("Packet payload post-decompression:\n"));
+            for (i = -1; i < pktin.length; i++)  /* from body[-1], matching the dump above */
+                debug(("  %02x", (unsigned char)pktin.body[i]));
+            debug(("\r\n"));
+        }
 #endif
     }
 
@@ -489,6 +533,20 @@ next_packet:
       goto next_packet;
     }
 
+    if (pktin.type == SSH1_MSG_DISCONNECT) {
+       /* log reason text; clamp its length to the packet body and to buf */
+       char buf[256];
+       long avail = (pktin.length >= 4 ? pktin.length - 4 : 0);
+       long msglen = (avail > 0 ? (long)GET_32BIT(pktin.body) : 0), nowlen;
+       strcpy(buf, "Remote sent disconnect: ");
+       nowlen = strlen(buf);
+       if (msglen < 0 || msglen > avail) msglen = avail;
+       if (msglen > (long)sizeof(buf)-nowlen-1) msglen = sizeof(buf)-nowlen-1;
+       memcpy(buf+nowlen, pktin.body+4, msglen);
+       buf[nowlen+msglen] = '\0';
+       logevent(buf);
+    }
+
     crFinish(0);
 }
 
@@ -628,6 +686,28 @@ next_packet:
     if (pktin.type == SSH2_MSG_IGNORE || pktin.type == SSH2_MSG_DEBUG)
         goto next_packet;              /* FIXME: print DEBUG message */
 
+    if (pktin.type == SSH2_MSG_DISCONNECT) {
+       /* log reason code and text, never reading past the end of the packet */
+       char buf[256];
+       int reason = (pktin.length >= 14 ? GET_32BIT(pktin.data+6) : 0);
+       long avail = (pktin.length >= 14 ? pktin.length - 14 : 0);
+       long msglen = (avail > 0 ? (long)GET_32BIT(pktin.data+10) : 0), nowlen;
+       if (reason > 0 && reason < lenof(ssh2_disconnect_reasons)) {
+           sprintf(buf, "Received disconnect message (%s)",
+                   ssh2_disconnect_reasons[reason]);
+       } else {
+           sprintf(buf, "Received disconnect message (unknown type %d)", reason);
+       }
+       logevent(buf);
+       strcpy(buf, "Disconnection message text: ");
+       nowlen = strlen(buf);
+       if (msglen < 0 || msglen > avail) msglen = avail;
+       if (msglen > (long)sizeof(buf)-nowlen-1) msglen = sizeof(buf)-nowlen-1;
+       memcpy(buf+nowlen, pktin.data+14, msglen);
+       buf[nowlen+msglen] = '\0';
+       logevent(buf);
+    }
+
     crFinish(0);
 }
 
@@ -907,7 +987,7 @@ static void ssh2_pkt_addstring(char *data) {
 }
 static char *ssh2_mpint_fmt(Bignum b, int *len) {
     unsigned char *p;
-    int i, n = (ssh1_bignum_bitcount(b)+7)/8;
+    int i, n = (bignum_bitcount(b)+7)/8;  /* bit count rounded up to whole bytes */
     p = smalloc(n + 1);
     if (!p)
         fatalbox("out of memory");
@@ -1107,7 +1187,9 @@ static void ssh_detect_bugs(char *vstring) {
     char *imp;                         /* pointer to implementation part */
     imp = vstring;
     imp += strcspn(imp, "-");
+    if (*imp) imp++;                   /* step over the first '-' itself */
     imp += strcspn(imp, "-");
+    if (*imp) imp++;                   /* step over the second '-'; imp now points past it */
 
     ssh_remote_bugs = 0;
 
@@ -1135,10 +1217,11 @@ static void ssh_detect_bugs(char *vstring) {
 }
 
 static int do_ssh_init(unsigned char c) {
-    static char *vsp;
+    static int vslen;                  /* int, not char: a long version line must not wrap it */
     static char version[10];
-    static char vstring[80];
-    static char vlog[sizeof(vstring)+20];
+    static char *vstring;              /* server's version line, grown on demand */
+    static int vstrsize;               /* bytes currently allocated for vstring */
+    static char *vlog;
     static int i;
 
     crBegin;
@@ -1158,13 +1241,18 @@ static int do_ssh_init(unsigned char c) {
       crReturn(1);                   /* get another character */
     }
 
+    vstring = smalloc(16);             /* initial buffer; grown 16 bytes at a time below */
+    vstrsize = 16;
     strcpy(vstring, "SSH-");
-    vsp = vstring+4;
+    vslen = 4;                         /* strlen("SSH-") */
     i = 0;
     while (1) {
       crReturn(1);                   /* get another char */
-       if (vsp < vstring+sizeof(vstring)-1)
-           *vsp++ = c;
+       if (vslen >= vstrsize-1) {       /* always keep a spare byte for the terminating NUL */
+            vstrsize += 16;
+            vstring = srealloc(vstring, vstrsize);
+        }
+        vstring[vslen++] = c;
       if (i >= 0) {
           if (c == '-') {
               version[i] = '\0';
@@ -1179,11 +1267,13 @@ static int do_ssh_init(unsigned char c) {
     ssh_agentfwd_enabled = FALSE;
     rdpkt2_state.incoming_sequence = 0;
 
-    *vsp = 0;
+    vstring[vslen] = 0;
+    vlog = smalloc(20 + vslen);        /* "Server version: " is 16 chars, plus vstring and NUL */
     sprintf(vlog, "Server version: %s", vstring);
     ssh_detect_bugs(vstring);
     vlog[strcspn(vlog, "\r\n")] = '\0';
     logevent(vlog);
+    sfree(vlog);
 
     /*
      * Server version "1.99" means we can choose whether we use v1
@@ -1193,18 +1283,19 @@ static int do_ssh_init(unsigned char c) {
         /*
          * This is a v2 server. Begin v2 protocol.
          */
-        char *verstring = "SSH-2.0-PuTTY";
+        char verstring[80], vlog[100];  /* NB: this vlog shadows the function's static vlog */
+        sprintf(verstring, "SSH-2.0-%s", sshver);  /* NOTE(review): assumes sshver fits in 80 bytes */
         SHA_Init(&exhashbase);
         /*
          * Hash our version string and their version string.
          */
         sha_string(&exhashbase, verstring, strlen(verstring));
         sha_string(&exhashbase, vstring, strcspn(vstring, "\r\n"));
-        sprintf(vstring, "%s\n", verstring);
         sprintf(vlog, "We claim version: %s", verstring);
         logevent(vlog);
+        strcat(verstring, "\n");       /* newline is sent, but was not hashed or logged */
         logevent("Using SSH protocol version 2");
-        sk_write(s, vstring, strlen(vstring));
+        sk_write(s, verstring, strlen(verstring));
         ssh_protocol = ssh2_protocol;
         ssh_version = 2;
         s_rdpkt = ssh2_rdpkt;
@@ -1212,19 +1303,23 @@ static int do_ssh_init(unsigned char c) {
         /*
          * This is a v1 server. Begin v1 protocol.
          */
-        sprintf(vstring, "SSH-%s-PuTTY\n",
-                (ssh_versioncmp(version, "1.5") <= 0 ? version : "1.5"));
-        sprintf(vlog, "We claim version: %s", vstring);
-        vlog[strcspn(vlog, "\r\n")] = '\0';
+        char verstring[80], vlog[100];  /* NB: this vlog shadows the function's static vlog */
+        sprintf(verstring, "SSH-%s-%s",
+                (ssh_versioncmp(version, "1.5") <= 0 ? version : "1.5"),  /* never claim above 1.5 */
+                sshver);
+        sprintf(vlog, "We claim version: %s", verstring);
         logevent(vlog);
+        strcat(verstring, "\n");       /* newline is sent, but was not logged */
         logevent("Using SSH protocol version 1");
-        sk_write(s, vstring, strlen(vstring));
+        sk_write(s, verstring, strlen(verstring));
         ssh_protocol = ssh1_protocol;
         ssh_version = 1;
         s_rdpkt = ssh1_rdpkt;
     }
     ssh_state = SSH_STATE_BEFORE_SIZE;
 
+    sfree(vstring);                    /* done with the server's version line */
 
     crFinish(0);
 }
 
@@ -1270,21 +1365,20 @@ static void ssh_gotdata(unsigned char *data, int datalen)
     crFinishV;
 }
 
-static int ssh_receive(Socket skt, int urgent, char *data, int len) {
-    if (urgent==3) {
+static int ssh_closing (Plug plug, char *error_msg, int error_code, int calling_back) { /* error_msg is NULL on a normal remote close */
+    ssh_state = SSH_STATE_CLOSED;
+    if (s) sk_close(s);                /* guarded like ssh_receive, in case s is already gone */
+    s = NULL;
+    if (error_msg) {
         /* A socket error has occurred. */
-        ssh_state = SSH_STATE_CLOSED;
-        sk_close(s);
-        s = NULL;
-        connection_fatal(data);
-        return 0;
-    } else if (!len) {
-       /* Connection has closed. */
-       ssh_state = SSH_STATE_CLOSED;
-       sk_close(s);
-       s = NULL;
-       return 0;
+        connection_fatal (error_msg);
+    } else {
+       /* Otherwise, the remote side closed the connection normally. */
     }
+    return 0;
+}
+
+static int ssh_receive(Plug plug, int urgent, char *data, int len) {
     ssh_gotdata (data, len);
     if (ssh_state == SSH_STATE_CLOSED) {
         if (s) {
@@ -1303,6 +1397,11 @@ static int ssh_receive(Socket skt, int urgent, char *data, int len) {
  */
 static char *connect_to_host(char *host, int port, char **realhost)
 {
+    static struct plug_function_table fn_table = {  /* socket callbacks: closing, receive */
+       ssh_closing,
+       ssh_receive
+    }, *fn_table_ptr = &fn_table;       /* static: &fn_table_ptr is handed to sk_new() below */
+
     SockAddr addr;
     char *err;
 #ifdef FWHACK
@@ -1340,7 +1439,7 @@ static char *connect_to_host(char *host, int port, char **realhost)
     /*
      * Open socket.
      */
-    s = sk_new(addr, port, 0, 1, ssh_receive);
+    s = sk_new(addr, port, 0, 1, &fn_table_ptr);  /* pass the callback table in place of ssh_receive */
     if ( (err = sk_socket_error(s)) )
       return err;
 
@@ -1402,9 +1501,13 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
       logevent(logmsg);
     }
 
+    ssh1_remote_protoflags = GET_32BIT(pktin.body+8+i+j);  /* server's protocol flags */
     supported_ciphers_mask = GET_32BIT(pktin.body+12+i+j);
     supported_auths_mask = GET_32BIT(pktin.body+16+i+j);
 
+    ssh1_local_protoflags = ssh1_remote_protoflags & SSH1_PROTOFLAGS_SUPPORTED;  /* echo only known flags */
+    ssh1_local_protoflags |= SSH1_PROTOFLAG_SCREEN_NUMBER;  /* we always send the X11 screen number */
+
     MD5Init(&md5c);
     MD5Update(&md5c, keystr2, hostkey.bytes);
     MD5Update(&md5c, keystr1, servkey.bytes);
@@ -1466,6 +1569,11 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
     if ((supported_ciphers_mask & (1 << cipher_type)) == 0) {
       c_write_str("Selected cipher not supported, falling back to 3DES\r\n");
       cipher_type = SSH_CIPHER_3DES;
+       if ((supported_ciphers_mask & (1 << cipher_type)) == 0) {  /* server lacks even the 3DES fallback */
+           bombout(("Server violates SSH 1 protocol by "
+                    "not supporting 3DES encryption"));
+           crReturn(0);
+       }
     }
     switch (cipher_type) {
       case SSH_CIPHER_3DES: logevent("Using 3DES encryption"); break;
@@ -1478,7 +1586,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                 PKT_DATA, cookie, 8,
                 PKT_CHAR, (len*8) >> 8, PKT_CHAR, (len*8) & 0xFF,
                 PKT_DATA, rsabuf, len,
-                PKT_INT, 0,
+                PKT_INT, ssh1_local_protoflags,  /* our protocol flags, computed above */
                 PKT_END);
 
     logevent("Trying to enable encryption...");
@@ -1604,7 +1712,8 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
             request[4] = SSH1_AGENTC_REQUEST_RSA_IDENTITIES;
             agent_query(request, 5, &r, &responselen);
             response = (unsigned char *)r;
-            if (response) {
+            if (response && responselen >= 9 &&  /* bytes 0-3, type byte 4, key count 5-8 */
+                response[4] == SSH1_AGENT_RSA_IDENTITIES_ANSWER) {
                 p = response + 5;
                 nkeys = GET_32BIT(p); p += 4;
                 { char buf[64]; sprintf(buf, "Pageant has %d SSH1 keys", nkeys);
@@ -1644,7 +1753,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                         PUT_32BIT(agentreq, len);
                         q = agentreq + 4;
                         *q++ = SSH1_AGENTC_RSA_CHALLENGE;
-                        PUT_32BIT(q, ssh1_bignum_bitcount(key.modulus));
+                        PUT_32BIT(q, bignum_bitcount(key.modulus));  /* key size: bit count of the modulus */
                         q += 4;
                         q += ssh1_write_bignum(q, key.exponent);
                         q += ssh1_write_bignum(q, key.modulus);
@@ -1978,7 +2087,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
 }
 
 void sshfwd_close(struct ssh_channel *c) {
-    if (c) {
+    if (c && !c->closes) {             /* only if no close is yet recorded for this channel */
         if (ssh_version == 1) {
             send_packet(SSH1_MSG_CHANNEL_CLOSE, PKT_INT, c->remoteid, PKT_END);
         } else {
@@ -2037,10 +2146,16 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
         char proto[20], data[64];
         logevent("Requesting X11 forwarding");
         x11_invent_auth(proto, sizeof(proto), data, sizeof(data));
-        send_packet(SSH1_CMSG_X11_REQUEST_FORWARDING, 
-                   PKT_STR, proto, PKT_STR, data,
-                   PKT_INT, 0,
-                   PKT_END);
+        if (ssh1_local_protoflags & SSH1_PROTOFLAG_SCREEN_NUMBER) {  /* extra int only if we set the flag */
+            send_packet(SSH1_CMSG_X11_REQUEST_FORWARDING,
+                        PKT_STR, proto, PKT_STR, data,
+                        PKT_INT, 0,         /* X11 screen number */
+                        PKT_END);
+        } else {
+            send_packet(SSH1_CMSG_X11_REQUEST_FORWARDING,
+                        PKT_STR, proto, PKT_STR, data,
+                        PKT_END);
+        }
         do { crReturnV; } while (!ispkt);
         if (pktin.type != SSH1_SMSG_SUCCESS && pktin.type != SSH1_SMSG_FAILURE) {
             bombout(("Protocol confusion"));
@@ -2089,8 +2204,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
       zlib_decompress_init();
     }
 
-    if (*cfg.remote_cmd)
-        send_packet(SSH1_CMSG_EXEC_CMD, PKT_STR, cfg.remote_cmd, PKT_END);
+    if (*cfg.remote_cmd_ptr)           /* non-empty command: exec it; otherwise start a shell */
+        send_packet(SSH1_CMSG_EXEC_CMD, PKT_STR, cfg.remote_cmd_ptr, PKT_END);
     else
         send_packet(SSH1_CMSG_EXEC_SHELL, PKT_END);
     logevent("Started session");
@@ -2196,7 +2311,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
                 if (c) {
                     int closetype;
                     closetype = (pktin.type == SSH1_MSG_CHANNEL_CLOSE ? 1 : 2);
-                    send_packet(pktin.type, PKT_INT, c->remoteid, PKT_END);
+                    if (!(c->closes & closetype))  /* don't echo a close type already recorded */
+                        send_packet(pktin.type, PKT_INT, c->remoteid, PKT_END);
                    if ((c->closes == 0) && (c->type == CHAN_X11)) {
                        logevent("X11 connection closed");
                        assert(c->u.x11.s != NULL);
@@ -2865,7 +2981,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                      */
                     logevent("No username provided. Abandoning session.");
                     ssh_state = SSH_STATE_CLOSED;
-                    crReturn(1);
+                    crReturnV;                 /* do_ssh2_authconn is void: no return value */
                 }
             } else {
                 c_write_str("login as: ");
@@ -2939,7 +3055,21 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
            if (!gotit)
                crWaitUntilV(ispkt);
            while (pktin.type == SSH2_MSG_USERAUTH_BANNER) {
-               /* FIXME: should support this */
+                char *banner;
+                int size;
+                /*
+                 * Don't show the banner if we're operating in
+                 * non-verbose non-interactive mode. (It's probably
+                 * a script, which means nobody will read the
+                 * banner _anyway_, and moreover the printing of
+                 * the banner will screw up processing on the
+                 * output of (say) plink.)
+                 */
+                if (flags & (FLAG_VERBOSE | FLAG_INTERACTIVE)) {
+                    ssh2_pkt_getstring(&banner, &size);
+                    if (banner)
+                        c_write_untrusted(banner, size);  /* server-supplied text: control chars stripped */
+                }
                crWaitUntilV(ispkt);
            }
            if (pktin.type == SSH2_MSG_USERAUTH_SUCCESS) {
@@ -2951,6 +3081,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
            if (pktin.type != SSH2_MSG_USERAUTH_FAILURE) {
                bombout(("Strange packet received during authentication: type %d",
                         pktin.type));
+               crReturnV;                      /* stop rather than treat it as USERAUTH_FAILURE */
            }
 
            gotit = FALSE;
@@ -3031,7 +3162,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                request[4] = SSH2_AGENTC_REQUEST_IDENTITIES;
                agent_query(request, 5, &r, &responselen);
                response = (unsigned char *)r;
-               if (response) {
+               if (response && responselen >= 9 &&  /* bytes 0-3, type byte 4, key count 5-8 */
+                    response[4] == SSH2_AGENT_IDENTITIES_ANSWER) {
                    p = response + 5;
                    nkeys = GET_32BIT(p); p += 4;
                    { char buf[64]; sprintf(buf, "Pageant has %d SSH2 keys", nkeys);
@@ -3067,9 +3199,11 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                            continue;
                        }
 
-                       c_write_str("Authenticating with public key \"");
-                       c_write(commentp, commentlen);
-                       c_write_str("\" from agent\r\n");
+                        if (flags & FLAG_VERBOSE) {    /* announce the agent key only in verbose mode */
+                            c_write_str("Authenticating with public key \"");
+                            c_write(commentp, commentlen);
+                            c_write_str("\" from agent\r\n");
+                        }
 
                        /*
                         * Server is willing to accept the key.
@@ -3452,9 +3586,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
      * Potentially enable agent forwarding.
      */
     if (cfg.agentfwd && agent_exists()) {
-        char proto[20], data[64];
         logevent("Requesting OpenSSH-style agent forwarding");
-        x11_invent_auth(proto, sizeof(proto), data, sizeof(data));
         ssh2_pkt_init(SSH2_MSG_CHANNEL_REQUEST);
         ssh2_pkt_adduint32(mainchan->remoteid);
         ssh2_pkt_addstring("auth-agent-req@openssh.com");
@@ -3537,11 +3669,11 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
     if (cfg.ssh_subsys) {
         ssh2_pkt_addstring("subsystem");
         ssh2_pkt_addbool(1);           /* want reply */
-        ssh2_pkt_addstring(cfg.remote_cmd);
-    } else if (*cfg.remote_cmd) {
+        ssh2_pkt_addstring(cfg.remote_cmd_ptr);  /* subsystem name */
+    } else if (*cfg.remote_cmd_ptr) {  /* non-empty command: exec it */
         ssh2_pkt_addstring("exec");
         ssh2_pkt_addbool(1);           /* want reply */
-        ssh2_pkt_addstring(cfg.remote_cmd);
+        ssh2_pkt_addstring(cfg.remote_cmd_ptr);  /* command line to run */
     } else {
         ssh2_pkt_addstring("shell");
         ssh2_pkt_addbool(1);           /* want reply */
@@ -3728,7 +3860,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                 c = find234(ssh_channels, &i, ssh_channelfind);
                 if (!c)
                     continue;          /* nonexistent channel */
-                mainchan->v2.remwindow += ssh2_pkt_getuint32();
+                c->v2.remwindow += ssh2_pkt_getuint32();  /* credit the channel looked up, not mainchan */
                 try_send = TRUE;
            } else if (pktin.type == SSH2_MSG_CHANNEL_OPEN) {
                 char *type;