Revamp interface to verify_ssh_host_key() and askalg(). Each of them
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index e2e8077..e41e9a9 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -726,10 +726,23 @@ struct ssh_tag {
     Config cfg;
 
     /*
-     * Used to transfer data back from async agent callbacks.
+     * Used to transfer data back from async callbacks.
      */
     void *agent_response;
     int agent_response_len;
+    int user_response;
+
+    /*
+     * The SSH connection can be set as `frozen', meaning we are
+     * not currently accepting incoming data from the network. This
+     * is slightly more serious than setting the _socket_ as
+     * frozen, because we may already have had data passed to us
+     * from the network whose processing must be delayed until
+     * after the freeze is lifted, so we also need a bufchain to
+     * store that data.
+     */
+    int frozen;
+    bufchain queued_incoming_data;
 
     /*
      * Dispatch table for packet types that we may have to deal
@@ -2331,6 +2344,49 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
     crFinish(0);
 }
 
+static void ssh_process_incoming_data(Ssh ssh,
+                                     unsigned char **data, int *datalen)
+{
+    struct Packet *pkt = ssh->s_rdpkt(ssh, data, datalen);
+    if (!pkt)
+       return;                        /* nothing to dispatch yet */
+    ssh->protocol(ssh, NULL, 0, pkt);  /* hand the packet to the protocol */
+    ssh_free_packet(pkt);
+}
+
+static void ssh_queue_incoming_data(Ssh ssh, unsigned char **data, int *datalen)
+{
+    int n = *datalen;      /* queue it all; caller sees it as consumed */
+    bufchain_add(&ssh->queued_incoming_data, *data, n);
+    *datalen = 0;
+    *data += n;
+}
+
+static void ssh_process_queued_incoming_data(Ssh ssh)
+{
+    void *vdata;
+    unsigned char *data;
+    int len, origlen;
+
+    /* Replay queued input until drained, re-frozen, or the link closes. */
+    while (!ssh->frozen && ssh->state != SSH_STATE_CLOSED &&
+           bufchain_size(&ssh->queued_incoming_data)) {
+       bufchain_prefix(&ssh->queued_incoming_data, &vdata, &len);
+       data = vdata;
+       origlen = len;
+       while (!ssh->frozen && ssh->state != SSH_STATE_CLOSED && len > 0)
+           ssh_process_incoming_data(ssh, &data, &len);
+       if (origlen > len)      /* drop only the prefix actually consumed */
+           bufchain_consume(&ssh->queued_incoming_data, origlen - len);
+    }
+}
+
+static void ssh_set_frozen(Ssh ssh, int frozen)
+{   /* Freeze/thaw input at both the socket and the queued-data layer. */
+    if (ssh->s) sk_set_frozen(ssh->s, frozen);  /* socket may be gone (NULL) */
+    ssh->frozen = frozen;
+}
+
 static void ssh_gotdata(Ssh ssh, unsigned char *data, int datalen)
 {
     crBegin(ssh->ssh_gotdata_crstate);
@@ -2360,13 +2416,19 @@ static void ssh_gotdata(Ssh ssh, unsigned char *data, int datalen)
      */
     if (datalen == 0)
        crReturnV;
+
+    /*
+     * Process queued data if there is any.
+     */
+    ssh_process_queued_incoming_data(ssh);
+
     while (1) {
        while (datalen > 0) {
-           struct Packet *pktin = ssh->s_rdpkt(ssh, &data, &datalen);
-           if (pktin) {
-               ssh->protocol(ssh, NULL, 0, pktin);
-               ssh_free_packet(pktin);
-           }
+           if (ssh->frozen)
+               ssh_queue_incoming_data(ssh, &data, &datalen);
+
+           ssh_process_incoming_data(ssh, &data, &datalen);
+
            if (ssh->state == SSH_STATE_CLOSED)
                return;
        }
@@ -2554,9 +2616,9 @@ static void ssh1_throttle(Ssh ssh, int adjust)
     ssh->v1_throttle_count += adjust;
     assert(ssh->v1_throttle_count >= 0);
     if (ssh->v1_throttle_count && !old_count) {
-       sk_set_frozen(ssh->s, 1);
+       ssh_set_frozen(ssh, 1);
     } else if (!ssh->v1_throttle_count && old_count) {
-       sk_set_frozen(ssh->s, 0);
+       ssh_set_frozen(ssh, 0);
     }
 }
 
@@ -2680,6 +2742,24 @@ static void ssh_agent_callback(void *sshv, void *reply, int replylen)
        do_ssh2_authconn(ssh, NULL, -1, NULL);
 }
 
+static void ssh_dialog_callback(void *sshv, int ret)
+{
+    Ssh ssh = (Ssh) sshv;
+
+    /* Stash the answer where the waiting coroutine reads it; resume it. */
+    ssh->user_response = ret;
+    if (ssh->version != 1)
+       do_ssh2_transport(ssh, NULL, -1, NULL);
+    else
+       do_ssh1_login(ssh, NULL, -1, NULL);
+
+    /*
+     * Resuming may have lifted the freeze; replay any network data
+     * that was queued while the user was deciding.
+     */
+    ssh_process_queued_incoming_data(ssh);
+}
+
 static void ssh_agentf_callback(void *cv, void *reply, int replylen)
 {
     struct ssh_channel *c = (struct ssh_channel *)cv;
@@ -2741,6 +2821,7 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
        Bignum challenge;
        char *commentp;
        int commentlen;
+        int dlgret;
     };
     crState(do_ssh1_login_state);
 
@@ -2828,10 +2909,30 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
            fatalbox("Out of memory");
        rsastr_fmt(keystr, &hostkey);
        rsa_fingerprint(fingerprint, sizeof(fingerprint), &hostkey);
-       verify_ssh_host_key(ssh->frontend,
-                           ssh->savedhost, ssh->savedport, "rsa", keystr,
-                           fingerprint);
+
+        ssh_set_frozen(ssh, 1);
+       s->dlgret = verify_ssh_host_key(ssh->frontend,
+                                        ssh->savedhost, ssh->savedport,
+                                        "rsa", keystr, fingerprint,
+                                        ssh_dialog_callback, ssh);
        sfree(keystr);
+        if (s->dlgret < 0) {
+            do {
+                crReturn(0);
+                if (pktin) {
+                    bombout(("Unexpected data from server while waiting"
+                             " for user host key response"));
+                    crStop(0);
+                }
+            } while (pktin || inlen > 0);
+            s->dlgret = ssh->user_response;
+        }
+        ssh_set_frozen(ssh, 0);
+
+        if (s->dlgret == 0) {
+            ssh->close_expected = TRUE;
+            ssh_closing((Plug)ssh, NULL, 0, 0);
+        }
     }
 
     for (i = 0; i < 32; i++) {
@@ -2893,9 +2994,25 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
 
        /* Warn about chosen cipher if necessary. */
        if (warn) {
-            sk_set_frozen(ssh->s, 1);
-           askalg(ssh->frontend, "cipher", cipher_string);
-            sk_set_frozen(ssh->s, 0);
+            ssh_set_frozen(ssh, 1);
+           s->dlgret = askalg(ssh->frontend, "cipher", cipher_string,
+                              ssh_dialog_callback, ssh);
+           if (s->dlgret < 0) {
+               do {
+                   crReturn(0);
+                   if (pktin) {
+                       bombout(("Unexpected data from server while waiting"
+                                " for user response"));
+                       crStop(0);
+                   }
+               } while (pktin || inlen > 0);
+               s->dlgret = ssh->user_response;
+           }
+            ssh_set_frozen(ssh, 0);
+           if (s->dlgret == 0) {
+               ssh->close_expected = TRUE;
+               ssh_closing((Plug)ssh, NULL, 0, 0);
+           }
         }
     }
 
@@ -4732,6 +4849,8 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
        const struct ssh_compress *preferred_comp;
        int got_session_id, activated_authconn;
        struct Packet *pktout;
+        int dlgret;
+       int guessok;
     };
     crState(do_ssh2_transport_state);
 
@@ -4945,7 +5064,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
      */
     {
        char *str;
-       int i, j, len, guessok;
+       int i, j, len;
 
        if (pktin->type != SSH2_MSG_KEXINIT) {
            bombout(("expected key exchange packet from server"));
@@ -4971,10 +5090,26 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
            }
            if (ssh->kex) {
                if (s->warn) {
-                    sk_set_frozen(ssh->s, 1);
-                   askalg(ssh->frontend, "key-exchange algorithm",
-                          ssh->kex->name);
-                    sk_set_frozen(ssh->s, 0);
+                   ssh_set_frozen(ssh, 1);
+                   s->dlgret = askalg(ssh->frontend, "key-exchange algorithm",
+                                      ssh->kex->name,
+                                      ssh_dialog_callback, ssh);
+                   if (s->dlgret < 0) {
+                       do {
+                           crReturn(0);
+                           if (pktin) {
+                               bombout(("Unexpected data from server while"
+                                        " waiting for user response"));
+                               crStop(0);
+                           }
+                       } while (pktin || inlen > 0);
+                       s->dlgret = ssh->user_response;
+                   }
+                   ssh_set_frozen(ssh, 0);
+                   if (s->dlgret == 0) {
+                       ssh->close_expected = TRUE;
+                       ssh_closing((Plug)ssh, NULL, 0, 0);
+                   }
                 }
                break;
            }
@@ -4989,7 +5124,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
         * the first algorithm in our list, even if it's still the algorithm
         * we end up using.
         */
-       guessok =
+       s->guessok =
            first_in_commasep_string(s->preferred_kex[0]->name, str, len);
        ssh_pkt_getstring(pktin, &str, &len);    /* host key algorithms */
        for (i = 0; i < lenof(hostkey_algs); i++) {
@@ -4998,7 +5133,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
                break;
            }
        }
-       guessok = guessok &&
+       s->guessok = s->guessok &&
            first_in_commasep_string(hostkey_algs[0]->name, str, len);
        ssh_pkt_getstring(pktin, &str, &len);    /* client->server cipher */
        s->warn = 0;
@@ -5016,10 +5151,27 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
            }
            if (s->cscipher_tobe) {
                if (s->warn) {
-                    sk_set_frozen(ssh->s, 1);
-                   askalg(ssh->frontend, "client-to-server cipher",
-                          s->cscipher_tobe->name);
-                    sk_set_frozen(ssh->s, 0);
+                   ssh_set_frozen(ssh, 1);
+                   s->dlgret = askalg(ssh->frontend,
+                                      "client-to-server cipher",
+                                      s->cscipher_tobe->name,
+                                      ssh_dialog_callback, ssh);
+                   if (s->dlgret < 0) {
+                       do {
+                           crReturn(0);
+                           if (pktin) {
+                               bombout(("Unexpected data from server while"
+                                        " waiting for user response"));
+                               crStop(0);
+                           }
+                       } while (pktin || inlen > 0);
+                       s->dlgret = ssh->user_response;
+                   }
+                   ssh_set_frozen(ssh, 0);
+                   if (s->dlgret == 0) {
+                       ssh->close_expected = TRUE;
+                       ssh_closing((Plug)ssh, NULL, 0, 0);
+                   }
                 }
                break;
            }
@@ -5046,10 +5198,27 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
            }
            if (s->sccipher_tobe) {
                if (s->warn) {
-                    sk_set_frozen(ssh->s, 1);
-                   askalg(ssh->frontend, "server-to-client cipher",
-                          s->sccipher_tobe->name);
-                    sk_set_frozen(ssh->s, 0);
+                   ssh_set_frozen(ssh, 1);
+                   s->dlgret = askalg(ssh->frontend,
+                                      "server-to-client cipher",
+                                      s->sccipher_tobe->name,
+                                      ssh_dialog_callback, ssh);
+                   if (s->dlgret < 0) {
+                       do {
+                           crReturn(0);
+                           if (pktin) {
+                               bombout(("Unexpected data from server while"
+                                        " waiting for user response"));
+                               crStop(0);
+                           }
+                       } while (pktin || inlen > 0);
+                       s->dlgret = ssh->user_response;
+                   }
+                   ssh_set_frozen(ssh, 0);
+                   if (s->dlgret == 0) {
+                       ssh->close_expected = TRUE;
+                       ssh_closing((Plug)ssh, NULL, 0, 0);
+                   }
                 }
                break;
            }
@@ -5094,7 +5263,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
        }
        ssh_pkt_getstring(pktin, &str, &len);  /* client->server language */
        ssh_pkt_getstring(pktin, &str, &len);  /* server->client language */
-       if (ssh2_pkt_getbool(pktin) && !guessok) /* first_kex_packet_follows */
+       if (ssh2_pkt_getbool(pktin) && !s->guessok) /* first_kex_packet_follows */
            crWaitUntil(pktin);                /* Ignore packet */
     }
 
@@ -5218,11 +5387,29 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
      */
     s->keystr = ssh->hostkey->fmtkey(s->hkey);
     s->fingerprint = ssh->hostkey->fingerprint(s->hkey);
-    sk_set_frozen(ssh->s, 1);
-    verify_ssh_host_key(ssh->frontend,
-                       ssh->savedhost, ssh->savedport, ssh->hostkey->keytype,
-                       s->keystr, s->fingerprint);
-    sk_set_frozen(ssh->s, 0);
+    ssh_set_frozen(ssh, 1);
+    s->dlgret = verify_ssh_host_key(ssh->frontend,
+                                    ssh->savedhost, ssh->savedport,
+                                    ssh->hostkey->keytype, s->keystr,
+                                   s->fingerprint,
+                                    ssh_dialog_callback, ssh);
+    if (s->dlgret < 0) {
+        do {
+            crReturn(0);
+            if (pktin) {
+                bombout(("Unexpected data from server while waiting"
+                         " for user host key response"));
+                    crStop(0);
+            }
+        } while (pktin || inlen > 0);
+        s->dlgret = ssh->user_response;
+    }
+    ssh_set_frozen(ssh, 0);
+    if (s->dlgret == 0) {
+        ssh->close_expected = TRUE;
+        ssh_closing((Plug)ssh, NULL, 0, 0);
+        crStop(0);
+    }
     if (!s->got_session_id) {     /* don't bother logging this in rekeys */
        logevent("Host key fingerprint is:");
        logevent(s->fingerprint);
@@ -7477,6 +7664,8 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh->queueing = FALSE;
     ssh->qhead = ssh->qtail = NULL;
     ssh->deferred_rekey_reason = NULL;
+    bufchain_init(&ssh->queued_incoming_data);
+    ssh->frozen = FALSE;
 
     *backend_handle = ssh;
 
@@ -7603,6 +7792,7 @@ static void ssh_free(void *handle)
     expire_timer_context(ssh);
     if (ssh->pinger)
        pinger_free(ssh->pinger);
+    bufchain_clear(&ssh->queued_incoming_data);
     sfree(ssh);
 
     random_unref();