Fix a couple of signedness compiler warnings, presumably due to me
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 137e460..5196b1b 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -83,6 +83,9 @@
 #define SSH2_MSG_KEX_DH_GEX_GROUP                 31   /* 0x1f */
 #define SSH2_MSG_KEX_DH_GEX_INIT                  32   /* 0x20 */
 #define SSH2_MSG_KEX_DH_GEX_REPLY                 33   /* 0x21 */
+#define SSH2_MSG_KEXRSA_PUBKEY                    30    /* 0x1e */
+#define SSH2_MSG_KEXRSA_SECRET                    31    /* 0x1f */
+#define SSH2_MSG_KEXRSA_DONE                      32    /* 0x20 */
 #define SSH2_MSG_USERAUTH_REQUEST                 50   /* 0x32 */
 #define SSH2_MSG_USERAUTH_FAILURE                 51   /* 0x33 */
 #define SSH2_MSG_USERAUTH_SUCCESS                 52   /* 0x34 */
  */
 #define SSH2_PKTCTX_DHGROUP          0x0001
 #define SSH2_PKTCTX_DHGEX            0x0002
+#define SSH2_PKTCTX_RSAKEX           0x0004
 #define SSH2_PKTCTX_KEX_MASK         0x000F
 #define SSH2_PKTCTX_PUBLICKEY        0x0010
 #define SSH2_PKTCTX_PASSWORD         0x0020
@@ -339,6 +343,9 @@ static char *ssh2_pkt_type(int pkt_ctx, int type)
     translatec(SSH2_MSG_KEX_DH_GEX_GROUP, SSH2_PKTCTX_DHGEX);
     translatec(SSH2_MSG_KEX_DH_GEX_INIT, SSH2_PKTCTX_DHGEX);
     translatec(SSH2_MSG_KEX_DH_GEX_REPLY, SSH2_PKTCTX_DHGEX);
+    translatec(SSH2_MSG_KEXRSA_PUBKEY, SSH2_PKTCTX_RSAKEX);
+    translatec(SSH2_MSG_KEXRSA_SECRET, SSH2_PKTCTX_RSAKEX);
+    translatec(SSH2_MSG_KEXRSA_DONE, SSH2_PKTCTX_RSAKEX);
     translate(SSH2_MSG_USERAUTH_REQUEST);
     translate(SSH2_MSG_USERAUTH_FAILURE);
     translate(SSH2_MSG_USERAUTH_SUCCESS);
@@ -1417,6 +1424,7 @@ static int s_wrpkt_prepare(Ssh ssh, struct Packet *pkt, int *offset_p)
        zlib_compress_block(ssh->cs_comp_ctx,
                            pkt->data + 12, pkt->length - 12,
                            &compblk, &complen);
+       ssh_pkt_ensure(pkt, complen + 2);   /* just in case it's got bigger */
        memcpy(pkt->data + 12, compblk, complen);
        sfree(compblk);
        pkt->length = complen + 12;
@@ -1864,6 +1872,7 @@ static void ssh2_pkt_defer_noqueue(Ssh ssh, struct Packet *pkt, int noignore)
         * get encrypted with a known IV.
         */
        struct Packet *ipkt = ssh2_pkt_init(SSH2_MSG_IGNORE);
+       ssh2_pkt_addstring_start(ipkt);
        ssh2_pkt_defer_noqueue(ssh, ipkt, TRUE);
     }
     len = ssh2_pkt_construct(ssh, pkt);
@@ -5104,9 +5113,10 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
        const struct ssh_mac *scmac_tobe;
        const struct ssh_compress *cscomp_tobe;
        const struct ssh_compress *sccomp_tobe;
-       char *hostkeydata, *sigdata, *keystr, *fingerprint;
-       int hostkeylen, siglen;
+       char *hostkeydata, *sigdata, *rsakeydata, *keystr, *fingerprint;
+       int hostkeylen, siglen, rsakeylen;
        void *hkey;                    /* actual host key */
+       void *rsakey;                  /* for RSA kex */
        unsigned char exchange_hash[SSH2_KEX_MAX_HASH_LEN];
        int n_preferred_kex;
        const struct ssh_kexes *preferred_kex[KEX_MAX];
@@ -5160,6 +5170,10 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
                s->preferred_kex[s->n_preferred_kex++] =
                    &ssh_diffiehellman_group1;
                break;
+             case KEX_RSA:
+               s->preferred_kex[s->n_preferred_kex++] =
+                   &ssh_rsa_kex;
+               break;
              case KEX_WARN:
                /* Flag for later. Don't bother if it's the last in
                 * the list. */
@@ -5559,107 +5573,217 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
            crWaitUntil(pktin);                /* Ignore packet */
     }
 
-    /*
-     * Work out the number of bits of key we will need from the key
-     * exchange. We start with the maximum key length of either
-     * cipher...
-     */
-    {
-       int csbits, scbits;
+    if (ssh->kex->main_type == KEXTYPE_DH) {
+        /*
+         * Work out the number of bits of key we will need from the
+         * key exchange. We start with the maximum key length of
+         * either cipher...
+         */
+        {
+            int csbits, scbits;
 
-       csbits = s->cscipher_tobe->keylen;
-       scbits = s->sccipher_tobe->keylen;
-       s->nbits = (csbits > scbits ? csbits : scbits);
-    }
-    /* The keys only have hlen-bit entropy, since they're based on
-     * a hash. So cap the key size at hlen bits. */
-    if (s->nbits > ssh->kex->hash->hlen * 8)
-       s->nbits = ssh->kex->hash->hlen * 8;
+            csbits = s->cscipher_tobe->keylen;
+            scbits = s->sccipher_tobe->keylen;
+            s->nbits = (csbits > scbits ? csbits : scbits);
+        }
+        /* The keys only have hlen-bit entropy, since they're based on
+         * a hash. So cap the key size at hlen bits. */
+        if (s->nbits > ssh->kex->hash->hlen * 8)
+            s->nbits = ssh->kex->hash->hlen * 8;
 
-    /*
-     * If we're doing Diffie-Hellman group exchange, start by
-     * requesting a group.
-     */
-    if (!ssh->kex->pdata) {
-       logevent("Doing Diffie-Hellman group exchange");
-       ssh->pkt_ctx |= SSH2_PKTCTX_DHGEX;
-       /*
-        * Work out how big a DH group we will need to allow that
-        * much data.
-        */
-       s->pbits = 512 << ((s->nbits - 1) / 64);
-       s->pktout = ssh2_pkt_init(SSH2_MSG_KEX_DH_GEX_REQUEST);
-       ssh2_pkt_adduint32(s->pktout, s->pbits);
-       ssh2_pkt_send_noqueue(ssh, s->pktout);
+        /*
+         * If we're doing Diffie-Hellman group exchange, start by
+         * requesting a group.
+         */
+        if (!ssh->kex->pdata) {
+            logevent("Doing Diffie-Hellman group exchange");
+            ssh->pkt_ctx |= SSH2_PKTCTX_DHGEX;
+            /*
+             * Work out how big a DH group we will need to allow that
+             * much data.
+             */
+            s->pbits = 512 << ((s->nbits - 1) / 64);
+            s->pktout = ssh2_pkt_init(SSH2_MSG_KEX_DH_GEX_REQUEST);
+            ssh2_pkt_adduint32(s->pktout, s->pbits);
+            ssh2_pkt_send_noqueue(ssh, s->pktout);
+
+            crWaitUntil(pktin);
+            if (pktin->type != SSH2_MSG_KEX_DH_GEX_GROUP) {
+                bombout(("expected key exchange group packet from server"));
+                crStop(0);
+            }
+            s->p = ssh2_pkt_getmp(pktin);
+            s->g = ssh2_pkt_getmp(pktin);
+            if (!s->p || !s->g) {
+                bombout(("unable to read mp-ints from incoming group packet"));
+                crStop(0);
+            }
+            ssh->kex_ctx = dh_setup_gex(s->p, s->g);
+            s->kex_init_value = SSH2_MSG_KEX_DH_GEX_INIT;
+            s->kex_reply_value = SSH2_MSG_KEX_DH_GEX_REPLY;
+        } else {
+            ssh->pkt_ctx |= SSH2_PKTCTX_DHGROUP;
+            ssh->kex_ctx = dh_setup_group(ssh->kex);
+            s->kex_init_value = SSH2_MSG_KEXDH_INIT;
+            s->kex_reply_value = SSH2_MSG_KEXDH_REPLY;
+            logeventf(ssh, "Using Diffie-Hellman with standard group \"%s\"",
+                      ssh->kex->groupname);
+        }
 
-       crWaitUntil(pktin);
-       if (pktin->type != SSH2_MSG_KEX_DH_GEX_GROUP) {
-           bombout(("expected key exchange group packet from server"));
-           crStop(0);
-       }
-       s->p = ssh2_pkt_getmp(pktin);
-       s->g = ssh2_pkt_getmp(pktin);
-       if (!s->p || !s->g) {
-           bombout(("unable to read mp-ints from incoming group packet"));
-           crStop(0);
-       }
-       ssh->kex_ctx = dh_setup_gex(s->p, s->g);
-       s->kex_init_value = SSH2_MSG_KEX_DH_GEX_INIT;
-       s->kex_reply_value = SSH2_MSG_KEX_DH_GEX_REPLY;
+        logeventf(ssh, "Doing Diffie-Hellman key exchange with hash %s",
+                  ssh->kex->hash->text_name);
+        /*
+         * Now generate and send e for Diffie-Hellman.
+         */
+        set_busy_status(ssh->frontend, BUSY_CPU); /* this can take a while */
+        s->e = dh_create_e(ssh->kex_ctx, s->nbits * 2);
+        s->pktout = ssh2_pkt_init(s->kex_init_value);
+        ssh2_pkt_addmp(s->pktout, s->e);
+        ssh2_pkt_send_noqueue(ssh, s->pktout);
+
+        set_busy_status(ssh->frontend, BUSY_WAITING); /* wait for server */
+        crWaitUntil(pktin);
+        if (pktin->type != s->kex_reply_value) {
+            bombout(("expected key exchange reply packet from server"));
+            crStop(0);
+        }
+        set_busy_status(ssh->frontend, BUSY_CPU); /* cogitate */
+        ssh_pkt_getstring(pktin, &s->hostkeydata, &s->hostkeylen);
+        s->hkey = ssh->hostkey->newkey(s->hostkeydata, s->hostkeylen);
+        s->f = ssh2_pkt_getmp(pktin);
+        if (!s->f) {
+            bombout(("unable to parse key exchange reply packet"));
+            crStop(0);
+        }
+        ssh_pkt_getstring(pktin, &s->sigdata, &s->siglen);
+
+        s->K = dh_find_K(ssh->kex_ctx, s->f);
+
+        /* We assume everything from now on will be quick, and it might
+         * involve user interaction. */
+        set_busy_status(ssh->frontend, BUSY_NOT);
+
+        hash_string(ssh->kex->hash, ssh->exhash, s->hostkeydata, s->hostkeylen);
+        if (!ssh->kex->pdata) {
+            hash_uint32(ssh->kex->hash, ssh->exhash, s->pbits);
+            hash_mpint(ssh->kex->hash, ssh->exhash, s->p);
+            hash_mpint(ssh->kex->hash, ssh->exhash, s->g);
+        }
+        hash_mpint(ssh->kex->hash, ssh->exhash, s->e);
+        hash_mpint(ssh->kex->hash, ssh->exhash, s->f);
+
+        dh_cleanup(ssh->kex_ctx);
+        freebn(s->f);
+        if (!ssh->kex->pdata) {
+            freebn(s->g);
+            freebn(s->p);
+        }
     } else {
-       ssh->pkt_ctx |= SSH2_PKTCTX_DHGROUP;
-       ssh->kex_ctx = dh_setup_group(ssh->kex);
-       s->kex_init_value = SSH2_MSG_KEXDH_INIT;
-       s->kex_reply_value = SSH2_MSG_KEXDH_REPLY;
-       logeventf(ssh, "Using Diffie-Hellman with standard group \"%s\"",
-                 ssh->kex->groupname);
-    }
+       logeventf(ssh, "Doing RSA key exchange with hash %s",
+                 ssh->kex->hash->text_name);
+       ssh->pkt_ctx |= SSH2_PKTCTX_RSAKEX;
+        /*
+         * RSA key exchange. First expect a KEXRSA_PUBKEY packet
+         * from the server.
+         */
+        crWaitUntil(pktin);
+        if (pktin->type != SSH2_MSG_KEXRSA_PUBKEY) {
+            bombout(("expected RSA public key packet from server"));
+            crStop(0);
+        }
 
-    logeventf(ssh, "Doing Diffie-Hellman key exchange with hash %s",
-             ssh->kex->hash->text_name);
-    /*
-     * Now generate and send e for Diffie-Hellman.
-     */
-    set_busy_status(ssh->frontend, BUSY_CPU); /* this can take a while */
-    s->e = dh_create_e(ssh->kex_ctx, s->nbits * 2);
-    s->pktout = ssh2_pkt_init(s->kex_init_value);
-    ssh2_pkt_addmp(s->pktout, s->e);
-    ssh2_pkt_send_noqueue(ssh, s->pktout);
+        ssh_pkt_getstring(pktin, &s->hostkeydata, &s->hostkeylen);
+        hash_string(ssh->kex->hash, ssh->exhash,
+                   s->hostkeydata, s->hostkeylen);
+       s->hkey = ssh->hostkey->newkey(s->hostkeydata, s->hostkeylen);
 
-    set_busy_status(ssh->frontend, BUSY_WAITING); /* wait for server */
-    crWaitUntil(pktin);
-    if (pktin->type != s->kex_reply_value) {
-       bombout(("expected key exchange reply packet from server"));
-       crStop(0);
-    }
-    set_busy_status(ssh->frontend, BUSY_CPU); /* cogitate */
-    ssh_pkt_getstring(pktin, &s->hostkeydata, &s->hostkeylen);
-    s->f = ssh2_pkt_getmp(pktin);
-    if (!s->f) {
-       bombout(("unable to parse key exchange reply packet"));
-       crStop(0);
-    }
-    ssh_pkt_getstring(pktin, &s->sigdata, &s->siglen);
+        {
+            char *keydata;
+            ssh_pkt_getstring(pktin, &keydata, &s->rsakeylen);
+            s->rsakeydata = snewn(s->rsakeylen, char);
+            memcpy(s->rsakeydata, keydata, s->rsakeylen);
+        }
+
+        s->rsakey = ssh_rsakex_newkey(s->rsakeydata, s->rsakeylen);
+        if (!s->rsakey) {
+            sfree(s->rsakeydata);
+            bombout(("unable to parse RSA public key from server"));
+            crStop(0);
+        }
+
+        hash_string(ssh->kex->hash, ssh->exhash, s->rsakeydata, s->rsakeylen);
+
+        /*
+         * Next, set up a shared secret K, of precisely KLEN -
+         * 2*HLEN - 49 bits, where KLEN is the bit length of the
+         * RSA key modulus and HLEN is the bit length of the hash
+         * we're using.
+         */
+        {
+            int klen = ssh_rsakex_klen(s->rsakey);
+            int nbits = klen - (2*ssh->kex->hash->hlen*8 + 49);
+            int i, byte = 0;
+            unsigned char *kstr1, *kstr2, *outstr;
+            int kstr1len, kstr2len, outstrlen;
+
+            s->K = bn_power_2(nbits - 1);
+
+            for (i = 0; i < nbits; i++) {
+                if ((i & 7) == 0) {
+                    byte = random_byte();
+                }
+                bignum_set_bit(s->K, i, (byte >> (i & 7)) & 1);
+            }
+
+            /*
+             * Encode this as an mpint.
+             */
+            kstr1 = ssh2_mpint_fmt(s->K, &kstr1len);
+            kstr2 = snewn(kstr2len = 4 + kstr1len, unsigned char);
+            PUT_32BIT(kstr2, kstr1len);
+            memcpy(kstr2 + 4, kstr1, kstr1len);
+
+            /*
+             * Encrypt it with the given RSA key.
+             */
+            outstrlen = (klen + 7) / 8;
+            outstr = snewn(outstrlen, unsigned char);
+            ssh_rsakex_encrypt(ssh->kex->hash, kstr2, kstr2len,
+                              outstr, outstrlen, s->rsakey);
+
+            /*
+             * And send it off in a return packet.
+             */
+            s->pktout = ssh2_pkt_init(SSH2_MSG_KEXRSA_SECRET);
+            ssh2_pkt_addstring_start(s->pktout);
+            ssh2_pkt_addstring_data(s->pktout, (char *)outstr, outstrlen);
+            ssh2_pkt_send_noqueue(ssh, s->pktout);
 
-    s->K = dh_find_K(ssh->kex_ctx, s->f);
+           hash_string(ssh->kex->hash, ssh->exhash, outstr, outstrlen);
 
-    /* We assume everything from now on will be quick, and it might
-     * involve user interaction. */
-    set_busy_status(ssh->frontend, BUSY_NOT);
+            sfree(kstr2);
+            sfree(kstr1);
+            sfree(outstr);
+        }
 
-    hash_string(ssh->kex->hash, ssh->exhash, s->hostkeydata, s->hostkeylen);
-    if (!ssh->kex->pdata) {
-       hash_uint32(ssh->kex->hash, ssh->exhash, s->pbits);
-       hash_mpint(ssh->kex->hash, ssh->exhash, s->p);
-       hash_mpint(ssh->kex->hash, ssh->exhash, s->g);
+        ssh_rsakex_freekey(s->rsakey);
+
+        crWaitUntil(pktin);
+        if (pktin->type != SSH2_MSG_KEXRSA_DONE) {
+            sfree(s->rsakeydata);
+            bombout(("expected signature packet from server"));
+            crStop(0);
+        }
+
+        ssh_pkt_getstring(pktin, &s->sigdata, &s->siglen);
+
+        sfree(s->rsakeydata);
     }
-    hash_mpint(ssh->kex->hash, ssh->exhash, s->e);
-    hash_mpint(ssh->kex->hash, ssh->exhash, s->f);
+
     hash_mpint(ssh->kex->hash, ssh->exhash, s->K);
     assert(ssh->kex->hash->hlen <= sizeof(s->exchange_hash));
     ssh->kex->hash->final(ssh->exhash, s->exchange_hash);
 
-    dh_cleanup(ssh->kex_ctx);
     ssh->kex_ctx = NULL;
 
 #if 0
@@ -5667,7 +5791,6 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
     dmemdump(s->exchange_hash, ssh->kex->hash->hlen);
 #endif
 
-    s->hkey = ssh->hostkey->newkey(s->hostkeydata, s->hostkeylen);
     if (!s->hkey ||
        !ssh->hostkey->verifysig(s->hkey, s->sigdata, s->siglen,
                                 (char *)s->exchange_hash,
@@ -5849,14 +5972,9 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
                  ssh->sccomp->text_name);
 
     /*
-     * Free key exchange data.
+     * Free shared secret.
      */
-    freebn(s->f);
     freebn(s->K);
-    if (!ssh->kex->pdata) {
-       freebn(s->g);
-       freebn(s->p);
-    }
 
     /*
      * Key exchange is over. Loop straight back round if we have a
@@ -6397,71 +6515,49 @@ static void ssh2_msg_channel_request(Ssh ssh, struct Packet *pktin)
 
                    if (0)
                        ;
+#define TRANSLATE_SIGNAL(s) \
+    else if (siglen == lenof(#s)-1 && !memcmp(sig, #s, siglen)) \
+        ssh->exitcode = 128 + SIG ## s
 #ifdef SIGABRT
-                   else if (siglen == lenof("ABRT")-1 &&
-                            !memcmp(sig, "ABRT", siglen))
-                       ssh->exitcode = 128 + SIGABRT;
+                   TRANSLATE_SIGNAL(ABRT);
 #endif
 #ifdef SIGALRM
-                   else if (siglen == lenof("ALRM")-1 &&
-                            !memcmp(sig, "ALRM", siglen))
-                       ssh->exitcode = 128 + SIGALRM;
+                   TRANSLATE_SIGNAL(ALRM);
 #endif
 #ifdef SIGFPE
-                   else if (siglen == lenof("FPE")-1 &&
-                            !memcmp(sig, "FPE", siglen))
-                       ssh->exitcode = 128 + SIGFPE;
+                   TRANSLATE_SIGNAL(FPE);
 #endif
 #ifdef SIGHUP
-                   else if (siglen == lenof("HUP")-1 &&
-                            !memcmp(sig, "HUP", siglen))
-                       ssh->exitcode = 128 + SIGHUP;
+                   TRANSLATE_SIGNAL(HUP);
 #endif
 #ifdef SIGILL
-                   else if (siglen == lenof("ILL")-1 &&
-                            !memcmp(sig, "ILL", siglen))
-                       ssh->exitcode = 128 + SIGILL;
+                   TRANSLATE_SIGNAL(ILL);
 #endif
 #ifdef SIGINT
-                   else if (siglen == lenof("INT")-1 &&
-                            !memcmp(sig, "INT", siglen))
-                       ssh->exitcode = 128 + SIGINT;
+                   TRANSLATE_SIGNAL(INT);
 #endif
 #ifdef SIGKILL
-                   else if (siglen == lenof("KILL")-1 &&
-                            !memcmp(sig, "KILL", siglen))
-                       ssh->exitcode = 128 + SIGKILL;
+                   TRANSLATE_SIGNAL(KILL);
 #endif
 #ifdef SIGPIPE
-                   else if (siglen == lenof("PIPE")-1 &&
-                            !memcmp(sig, "PIPE", siglen))
-                       ssh->exitcode = 128 + SIGPIPE;
+                   TRANSLATE_SIGNAL(PIPE);
 #endif
 #ifdef SIGQUIT
-                   else if (siglen == lenof("QUIT")-1 &&
-                            !memcmp(sig, "QUIT", siglen))
-                       ssh->exitcode = 128 + SIGQUIT;
+                   TRANSLATE_SIGNAL(QUIT);
 #endif
 #ifdef SIGSEGV
-                   else if (siglen == lenof("SEGV")-1 &&
-                            !memcmp(sig, "SEGV", siglen))
-                       ssh->exitcode = 128 + SIGSEGV;
+                   TRANSLATE_SIGNAL(SEGV);
 #endif
 #ifdef SIGTERM
-                   else if (siglen == lenof("TERM")-1 &&
-                            !memcmp(sig, "TERM", siglen))
-                       ssh->exitcode = 128 + SIGTERM;
+                   TRANSLATE_SIGNAL(TERM);
 #endif
 #ifdef SIGUSR1
-                   else if (siglen == lenof("USR1")-1 &&
-                            !memcmp(sig, "USR1", siglen))
-                       ssh->exitcode = 128 + SIGUSR1;
+                   TRANSLATE_SIGNAL(USR1);
 #endif
 #ifdef SIGUSR2
-                   else if (siglen == lenof("USR2")-1 &&
-                            !memcmp(sig, "USR2", siglen))
-                       ssh->exitcode = 128 + SIGUSR2;
+                   TRANSLATE_SIGNAL(USR2);
 #endif
+#undef TRANSLATE_SIGNAL
                    else
                        ssh->exitcode = 128;
                }
@@ -8926,8 +9022,7 @@ static void ssh_unthrottle(void *handle, int bufsize)
            ssh1_throttle(ssh, -1);
        }
     } else {
-       if (ssh->mainchan && ssh->mainchan->closes == 0)
-           ssh2_set_window(ssh->mainchan, OUR_V2_WINSIZE - bufsize);
+       ssh2_set_window(ssh->mainchan, OUR_V2_WINSIZE - bufsize);
     }
 }