Probable support for first_kex_packet_follows in KEXINIT. Not significantly
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index ac9d47e..0bd89a4 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -550,9 +550,9 @@ struct Packet {
     struct logblank_t *blanks;
 };
 
-static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen,
+static void ssh1_protocol(Ssh ssh, void *vin, int inlen,
                          struct Packet *pktin);
-static void ssh2_protocol(Ssh ssh, unsigned char *in, int inlen,
+static void ssh2_protocol(Ssh ssh, void *vin, int inlen,
                          struct Packet *pktin);
 static void ssh1_protocol_setup(Ssh ssh);
 static void ssh2_protocol_setup(Ssh ssh);
@@ -568,7 +568,7 @@ static unsigned long ssh_pkt_getuint32(struct Packet *pkt);
 static int ssh2_pkt_getbool(struct Packet *pkt);
 static void ssh_pkt_getstring(struct Packet *pkt, char **p, int *length);
 static void ssh2_timer(void *ctx, long now);
-static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
+static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
                             struct Packet *pktin);
 
 struct rdpkt1_state_tag {
@@ -689,7 +689,7 @@ struct ssh_tag {
     int overall_bufsize;
     int throttled_all;
     int v1_stdout_throttling;
-    int v2_outgoing_sequence;
+    unsigned long v2_outgoing_sequence;
 
     int ssh1_rdpkt_crstate;
     int ssh2_rdpkt_crstate;
@@ -711,7 +711,7 @@ struct ssh_tag {
     /* ssh1 and ssh2 use this for different things, but both use it */
     int protocol_initial_phase_done;
 
-    void (*protocol) (Ssh ssh, unsigned char *in, int inlen,
+    void (*protocol) (Ssh ssh, void *vin, int inlen,
                      struct Packet *pkt);
     struct Packet *(*s_rdpkt) (Ssh ssh, unsigned char **data, int *datalen);
 
@@ -4279,7 +4279,7 @@ static void ssh1_msg_channel_data(Ssh ssh, struct Packet *pktin)
     /* Data sent down one of our channels. */
     int i = ssh_pkt_getuint32(pktin);
     char *p;
-    unsigned int len;
+    int len;
     struct ssh_channel *c;
 
     ssh_pkt_getstring(pktin, &p, &len);
@@ -4599,9 +4599,10 @@ static void ssh1_protocol_setup(Ssh ssh)
     ssh->packet_dispatch[SSH1_MSG_DEBUG] = ssh1_msg_debug;
 }
 
-static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen,
+static void ssh1_protocol(Ssh ssh, void *vin, int inlen,
                          struct Packet *pktin)
 {
+    unsigned char *in=(unsigned char*)vin;
     if (ssh->state == SSH_STATE_CLOSED)
        return;
 
@@ -4652,6 +4653,28 @@ static int in_commasep_string(char *needle, char *haystack, int haylen)
 }
 
 /*
+ * Similar routine for checking whether we have the first string in a list.
+ */
+static int first_in_commasep_string(char *needle, char *haystack, int haylen)
+{
+    int needlen;
+    if (!needle || !haystack)         /* protect against null pointers */
+       return 0;
+    needlen = strlen(needle);
+    /*
+     * Is it at the start of the string?
+     */
+    if (haylen >= needlen &&       /* haystack is long enough */
+       !memcmp(needle, haystack, needlen) &&   /* initial match */
+       (haylen == needlen || haystack[needlen] == ',')
+       /* either , or EOS follows */
+       )
+       return 1;
+    return 0;
+}
+
+
+/*
  * SSH2 key creation method.
  */
 static void ssh2_mkkey(Ssh ssh, Bignum K, unsigned char *H,
@@ -4679,9 +4702,10 @@ static void ssh2_mkkey(Ssh ssh, Bignum K, unsigned char *H,
 /*
  * Handle the SSH2 transport layer.
  */
-static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
+static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
                             struct Packet *pktin)
 {
+    unsigned char *in = (unsigned char *)vin;
     struct do_ssh2_transport_state {
        int nbits, pbits, warn;
        Bignum p, g, e, f, K;
@@ -4918,7 +4942,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
      */
     {
        char *str;
-       int i, j, len;
+       int i, j, len, guessok;
 
        if (pktin->type != SSH2_MSG_KEXINIT) {
            bombout(("expected key exchange packet from server"));
@@ -4957,6 +4981,13 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
                     str ? str : "(null)"));
            crStop(0);
        }
+       /*
+        * Note that the server's guess is considered wrong if it doesn't match
+        * the first algorithm in our list, even if it's still the algorithm
+        * we end up using.
+        */
+       guessok =
+           first_in_commasep_string(s->preferred_kex[0]->name, str, len);
        ssh_pkt_getstring(pktin, &str, &len);    /* host key algorithms */
        for (i = 0; i < lenof(hostkey_algs); i++) {
            if (in_commasep_string(hostkey_algs[i]->name, str, len)) {
@@ -4964,6 +4995,8 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
                break;
            }
        }
+       guessok = guessok &&
+           first_in_commasep_string(hostkey_algs[0]->name, str, len);
        ssh_pkt_getstring(pktin, &str, &len);    /* client->server cipher */
        s->warn = 0;
        for (i = 0; i < s->n_preferred_ciphers; i++) {
@@ -5056,6 +5089,10 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
                break;
            }
        }
+       ssh_pkt_getstring(pktin, &str, &len);  /* client->server language */
+       ssh_pkt_getstring(pktin, &str, &len);  /* server->client language */
+       if (ssh2_pkt_getbool(pktin) && !guessok) /* first_kex_packet_follows */
+           crWaitUntil(pktin);                /* Ignore packet */
     }
 
     /*
@@ -5478,7 +5515,7 @@ static void ssh2_msg_channel_window_adjust(Ssh ssh, struct Packet *pktin)
 static void ssh2_msg_channel_data(Ssh ssh, struct Packet *pktin)
 {
     char *data;
-    unsigned int length;
+    int length;
     unsigned i = ssh_pkt_getuint32(pktin);
     struct ssh_channel *c;
     c = find234(ssh->channels, &i, ssh_channelfind);
@@ -5797,7 +5834,7 @@ static void ssh2_msg_channel_request(Ssh ssh, struct Packet *pktin)
        if (q >= 0 && q+4 <= len) { \
            q = q + 4 + GET_32BIT(p+q); \
            if (q >= 0 && q+4 <= len && \
-                   (q = q + 4 + GET_32BIT(p+q)) && q == len) \
+                   ((q = q + 4 + GET_32BIT(p+q))!= 0) && q == len) \
                result = TRUE; \
        } \
     } while(0)
@@ -6789,6 +6826,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
            } else if (s->method == AUTH_KEYBOARD_INTERACTIVE) {
                if (s->curr_prompt == 0) {
                    s->pktout = ssh2_pkt_init(SSH2_MSG_USERAUTH_INFO_RESPONSE);
+                   s->pktout->forcepad = 256;
                    ssh2_pkt_adduint32(s->pktout, s->num_prompts);
                }
                if (s->need_pw) {      /* only add pw if we just got one! */
@@ -7335,9 +7373,10 @@ static void ssh2_timer(void *ctx, long now)
     }
 }
 
-static void ssh2_protocol(Ssh ssh, unsigned char *in, int inlen,
+static void ssh2_protocol(Ssh ssh, void *vin, int inlen,
                          struct Packet *pktin)
 {
+    unsigned char *in = (unsigned char *)vin;
     if (ssh->state == SSH_STATE_CLOSED)
        return;