Line discipline module now uses dynamically allocated data. Also
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 115e902..0a815bb 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -371,19 +371,28 @@ const static struct ssh_kex *kex_algs[] = {
 
 const static struct ssh_signkey *hostkey_algs[] = { &ssh_rsa, &ssh_dss };
 
-static void nullmac_key(unsigned char *key)
+static void *nullmac_make_context(void)
+{
+    return NULL;
+}
+static void nullmac_free_context(void *handle)
+{
+}
+static void nullmac_key(void *handle, unsigned char *key)
 {
 }
-static void nullmac_generate(unsigned char *blk, int len,
+static void nullmac_generate(void *handle, unsigned char *blk, int len,
                             unsigned long seq)
 {
 }
-static int nullmac_verify(unsigned char *blk, int len, unsigned long seq)
+static int nullmac_verify(void *handle, unsigned char *blk, int len,
+                         unsigned long seq)
 {
     return 1;
 }
 const static struct ssh_mac ssh_mac_none = {
-    nullmac_key, nullmac_key, nullmac_generate, nullmac_verify, "none", 0
+    nullmac_make_context, nullmac_free_context, nullmac_key,
+    nullmac_generate, nullmac_verify, "none", 0
 };
 const static struct ssh_mac *macs[] = {
     &ssh_sha1, &ssh_md5, &ssh_mac_none
@@ -392,23 +401,27 @@ const static struct ssh_mac *buggymacs[] = {
     &ssh_sha1_buggy, &ssh_md5, &ssh_mac_none
 };
 
-static void ssh_comp_none_init(void)
+static void *ssh_comp_none_init(void)
+{
+    return NULL;
+}
+static void ssh_comp_none_cleanup(void *handle)
 {
 }
-static int ssh_comp_none_block(unsigned char *block, int len,
+static int ssh_comp_none_block(void *handle, unsigned char *block, int len,
                               unsigned char **outblock, int *outlen)
 {
     return 0;
 }
-static int ssh_comp_none_disable(void)
+static int ssh_comp_none_disable(void *handle)
 {
     return 0;
 }
 const static struct ssh_compress ssh_comp_none = {
     "none",
-    ssh_comp_none_init, ssh_comp_none_block,
-    ssh_comp_none_init, ssh_comp_none_block,
-    ssh_comp_none_disable
+    ssh_comp_none_init, ssh_comp_none_cleanup, ssh_comp_none_block,
+    ssh_comp_none_init, ssh_comp_none_cleanup, ssh_comp_none_block,
+    ssh_comp_none_disable, NULL
 };
 extern const struct ssh_compress ssh_zlib;
 const static struct ssh_compress *compressions[] = {
@@ -542,6 +555,8 @@ struct ssh_tag {
 
     Socket s;
 
+    void *ldisc;
+
     unsigned char session_key[32];
     int v1_compressing;
     int v1_remote_protoflags;
@@ -550,12 +565,18 @@ struct ssh_tag {
     int X11_fwd_enabled;
     int remote_bugs;
     const struct ssh_cipher *cipher;
+    void *v1_cipher_ctx;
+    void *crcda_ctx;
     const struct ssh2_cipher *cscipher, *sccipher;
+    void *cs_cipher_ctx, *sc_cipher_ctx;
     const struct ssh_mac *csmac, *scmac;
+    void *cs_mac_ctx, *sc_mac_ctx;
     const struct ssh_compress *cscomp, *sccomp;
+    void *cs_comp_ctx, *sc_comp_ctx;
     const struct ssh_kex *kex;
     const struct ssh_signkey *hostkey;
     unsigned char v2_session_id[20];
+    void *kex_ctx;
 
     char *savedhost;
     int savedport;
@@ -797,13 +818,14 @@ static int ssh1_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
        st->to_read -= st->chunk;
     }
 
-    if (ssh->cipher && detect_attack(ssh->pktin.data, st->biglen, NULL)) {
+    if (ssh->cipher && detect_attack(ssh->crcda_ctx, ssh->pktin.data,
+                                    st->biglen, NULL)) {
         bombout(("Network attack (CRC compensation) detected!"));
         crReturn(0);
     }
 
     if (ssh->cipher)
-       ssh->cipher->decrypt(ssh->pktin.data, st->biglen);
+       ssh->cipher->decrypt(ssh->v1_cipher_ctx, ssh->pktin.data, st->biglen);
 
     st->realcrc = crc32(ssh->pktin.data, st->biglen - 4);
     st->gotcrc = GET_32BIT(ssh->pktin.data + st->biglen - 4);
@@ -817,7 +839,8 @@ static int ssh1_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
     if (ssh->v1_compressing) {
        unsigned char *decompblk;
        int decomplen;
-       zlib_decompress_block(ssh->pktin.body - 1, ssh->pktin.length + 1,
+       zlib_decompress_block(ssh->sc_comp_ctx,
+                             ssh->pktin.body - 1, ssh->pktin.length + 1,
                              &decompblk, &decomplen);
 
        if (ssh->pktin.maxlen < st->pad + decomplen) {
@@ -917,7 +940,8 @@ static int ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
     }
 
     if (ssh->sccipher)
-       ssh->sccipher->decrypt(ssh->pktin.data, st->cipherblk);
+       ssh->sccipher->decrypt(ssh->sc_cipher_ctx,
+                              ssh->pktin.data, st->cipherblk);
 
     /*
      * Now get the length and padding figures.
@@ -968,14 +992,15 @@ static int ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
     }
     /* Decrypt everything _except_ the MAC. */
     if (ssh->sccipher)
-       ssh->sccipher->decrypt(ssh->pktin.data + st->cipherblk,
+       ssh->sccipher->decrypt(ssh->sc_cipher_ctx,
+                              ssh->pktin.data + st->cipherblk,
                               st->packetlen - st->cipherblk);
 
     /*
      * Check the MAC.
      */
     if (ssh->scmac
-       && !ssh->scmac->verify(ssh->pktin.data, st->len + 4,
+       && !ssh->scmac->verify(ssh->sc_mac_ctx, ssh->pktin.data, st->len + 4,
                               st->incoming_sequence)) {
        bombout(("Incorrect MAC received on packet"));
        crReturn(0);
@@ -989,7 +1014,8 @@ static int ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
        unsigned char *newpayload;
        int newlen;
        if (ssh->sccomp &&
-           ssh->sccomp->decompress(ssh->pktin.data + 5, ssh->pktin.length - 5,
+           ssh->sccomp->decompress(ssh->sc_comp_ctx,
+                                   ssh->pktin.data + 5, ssh->pktin.length - 5,
                                    &newpayload, &newlen)) {
            if (ssh->pktin.maxlen < newlen + 5) {
                ssh->pktin.maxlen = newlen + 5;
@@ -1154,7 +1180,8 @@ static int s_wrpkt_prepare(Ssh ssh)
     if (ssh->v1_compressing) {
        unsigned char *compblk;
        int complen;
-       zlib_compress_block(ssh->pktout.body - 1, ssh->pktout.length + 1,
+       zlib_compress_block(ssh->cs_comp_ctx,
+                           ssh->pktout.body - 1, ssh->pktout.length + 1,
                            &compblk, &complen);
        ssh1_pktout_size(ssh, complen - 1);
        memcpy(ssh->pktout.body - 1, compblk, complen);
@@ -1172,7 +1199,7 @@ static int s_wrpkt_prepare(Ssh ssh)
     PUT_32BIT(ssh->pktout.data, len);
 
     if (ssh->cipher)
-       ssh->cipher->encrypt(ssh->pktout.data + 4, biglen);
+       ssh->cipher->encrypt(ssh->v1_cipher_ctx, ssh->pktout.data + 4, biglen);
 
     return biglen + 4;
 }
@@ -1440,7 +1467,8 @@ static int ssh2_pkt_construct(Ssh ssh)
        unsigned char *newpayload;
        int newlen;
        if (ssh->cscomp &&
-           ssh->cscomp->compress(ssh->pktout.data + 5, ssh->pktout.length - 5,
+           ssh->cscomp->compress(ssh->cs_comp_ctx, ssh->pktout.data + 5,
+                                 ssh->pktout.length - 5,
                                  &newpayload, &newlen)) {
            ssh->pktout.length = 5;
            ssh2_pkt_adddata(ssh, newpayload, newlen);
@@ -1464,12 +1492,14 @@ static int ssh2_pkt_construct(Ssh ssh)
        ssh->pktout.data[ssh->pktout.length + i] = random_byte();
     PUT_32BIT(ssh->pktout.data, ssh->pktout.length + padding - 4);
     if (ssh->csmac)
-       ssh->csmac->generate(ssh->pktout.data, ssh->pktout.length + padding,
+       ssh->csmac->generate(ssh->cs_mac_ctx, ssh->pktout.data,
+                            ssh->pktout.length + padding,
                             ssh->v2_outgoing_sequence);
     ssh->v2_outgoing_sequence++;       /* whether or not we MACed */
 
     if (ssh->cscipher)
-       ssh->cscipher->encrypt(ssh->pktout.data, ssh->pktout.length + padding);
+       ssh->cscipher->encrypt(ssh->cs_cipher_ctx,
+                              ssh->pktout.data, ssh->pktout.length + padding);
 
     /* Ready-to-send packet starts at ssh->pktout.data. We return length. */
     return ssh->pktout.length + padding + maclen;
@@ -2353,7 +2383,16 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     ssh->cipher = (s->cipher_type == SSH_CIPHER_BLOWFISH ? &ssh_blowfish_ssh1 :
                   s->cipher_type == SSH_CIPHER_DES ? &ssh_des :
                   &ssh_3des);
-    ssh->cipher->sesskey(ssh->session_key);
+    ssh->v1_cipher_ctx = ssh->cipher->make_context();
+    ssh->cipher->sesskey(ssh->v1_cipher_ctx, ssh->session_key);
+    {
+       char buf[256];
+       sprintf(buf, "Initialised %.200s encryption", ssh->cipher->text_name);
+       logevent(buf);
+    }
+
+    ssh->crcda_ctx = crcda_make_context();
+    logevent("Installing CRC compensation attack detector");
 
     crWaitUntil(ispkt);
 
@@ -3163,8 +3202,10 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen, int ispkt)
        }
        logevent("Started compression");
        ssh->v1_compressing = TRUE;
-       zlib_compress_init();
-       zlib_decompress_init();
+       ssh->cs_comp_ctx = zlib_compress_init();
+       logevent("Initialised zlib (RFC1950) compression");
+       ssh->sc_comp_ctx = zlib_decompress_init();
+       logevent("Initialised zlib (RFC1950) decompression");
     }
 
     /*
@@ -3194,7 +3235,8 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     if (ssh->eof_needed)
        ssh_special(ssh, TS_EOF);
 
-    ldisc_send(NULL, 0, 0);           /* cause ldisc to notice changes */
+    if (ssh->ldisc)
+       ldisc_send(ssh->ldisc, NULL, 0, 0);/* cause ldisc to notice changes */
     ssh->send_ok = 1;
     ssh->channels = newtree234(ssh_channelcmp);
     while (1) {
@@ -3926,12 +3968,12 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
        }
        s->p = ssh2_pkt_getmp(ssh);
        s->g = ssh2_pkt_getmp(ssh);
-       dh_setup_group(s->p, s->g);
+       ssh->kex_ctx = dh_setup_group(s->p, s->g);
        s->kex_init_value = SSH2_MSG_KEX_DH_GEX_INIT;
        s->kex_reply_value = SSH2_MSG_KEX_DH_GEX_REPLY;
     } else {
        ssh->pkt_ctx |= SSH2_PKTCTX_DHGROUP1;
-       dh_setup_group1();
+       ssh->kex_ctx = dh_setup_group1();
        s->kex_init_value = SSH2_MSG_KEXDH_INIT;
        s->kex_reply_value = SSH2_MSG_KEXDH_REPLY;
     }
@@ -3940,7 +3982,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     /*
      * Now generate and send e for Diffie-Hellman.
      */
-    s->e = dh_create_e(s->nbits * 2);
+    s->e = dh_create_e(ssh->kex_ctx, s->nbits * 2);
     ssh2_pkt_init(ssh, s->kex_init_value);
     ssh2_pkt_addmp(ssh, s->e);
     ssh2_pkt_send(ssh);
@@ -3954,7 +3996,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     s->f = ssh2_pkt_getmp(ssh);
     ssh2_pkt_getstring(ssh, &s->sigdata, &s->siglen);
 
-    s->K = dh_find_K(s->f);
+    s->K = dh_find_K(ssh->kex_ctx, s->f);
 
     sha_string(&ssh->exhash, s->hostkeydata, s->hostkeylen);
     if (ssh->kex == &ssh_diffiehellman_gex) {
@@ -3967,7 +4009,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     sha_mpint(&ssh->exhash, s->K);
     SHA_Final(&ssh->exhash, s->exchange_hash);
 
-    dh_cleanup();
+    dh_cleanup(ssh->kex_ctx);
 
 #if 0
     debug(("Exchange hash is:\n"));
@@ -4016,14 +4058,36 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     /*
      * Create and initialise session keys.
      */
+    if (ssh->cs_cipher_ctx)
+       ssh->cscipher->free_context(ssh->cs_cipher_ctx);
     ssh->cscipher = s->cscipher_tobe;
+    ssh->cs_cipher_ctx = ssh->cscipher->make_context();
+
+    if (ssh->sc_cipher_ctx)
+       ssh->sccipher->free_context(ssh->sc_cipher_ctx);
     ssh->sccipher = s->sccipher_tobe;
+    ssh->sc_cipher_ctx = ssh->sccipher->make_context();
+
+    if (ssh->cs_mac_ctx)
+       ssh->csmac->free_context(ssh->cs_mac_ctx);
     ssh->csmac = s->csmac_tobe;
+    ssh->cs_mac_ctx = ssh->csmac->make_context();
+
+    if (ssh->sc_mac_ctx)
+       ssh->scmac->free_context(ssh->sc_mac_ctx);
     ssh->scmac = s->scmac_tobe;
+    ssh->sc_mac_ctx = ssh->scmac->make_context();
+
+    if (ssh->cs_comp_ctx)
+       ssh->cscomp->compress_cleanup(ssh->cs_comp_ctx);
     ssh->cscomp = s->cscomp_tobe;
+    ssh->cs_comp_ctx = ssh->cscomp->compress_init();
+
+    if (ssh->sc_comp_ctx)
+       ssh->sccomp->decompress_cleanup(ssh->sc_comp_ctx);
     ssh->sccomp = s->sccomp_tobe;
-    ssh->cscomp->compress_init();
-    ssh->sccomp->decompress_init();
+    ssh->sc_comp_ctx = ssh->sccomp->decompress_init();
+
     /*
      * Set IVs after keys. Here we use the exchange hash from the
      * _first_ key exchange.
@@ -4034,18 +4098,38 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
            memcpy(ssh->v2_session_id, s->exchange_hash,
                   sizeof(s->exchange_hash));
        ssh2_mkkey(ssh,s->K,s->exchange_hash,ssh->v2_session_id,'C',keyspace);
-       ssh->cscipher->setcskey(keyspace);
+       ssh->cscipher->setkey(ssh->cs_cipher_ctx, keyspace);
        ssh2_mkkey(ssh,s->K,s->exchange_hash,ssh->v2_session_id,'D',keyspace);
-       ssh->sccipher->setsckey(keyspace);
+       ssh->sccipher->setkey(ssh->sc_cipher_ctx, keyspace);
        ssh2_mkkey(ssh,s->K,s->exchange_hash,ssh->v2_session_id,'A',keyspace);
-       ssh->cscipher->setcsiv(keyspace);
+       ssh->cscipher->setiv(ssh->cs_cipher_ctx, keyspace);
        ssh2_mkkey(ssh,s->K,s->exchange_hash,ssh->v2_session_id,'B',keyspace);
-       ssh->sccipher->setsciv(keyspace);
+       ssh->sccipher->setiv(ssh->sc_cipher_ctx, keyspace);
        ssh2_mkkey(ssh,s->K,s->exchange_hash,ssh->v2_session_id,'E',keyspace);
-       ssh->csmac->setcskey(keyspace);
+       ssh->csmac->setkey(ssh->cs_mac_ctx, keyspace);
        ssh2_mkkey(ssh,s->K,s->exchange_hash,ssh->v2_session_id,'F',keyspace);
-       ssh->scmac->setsckey(keyspace);
+       ssh->scmac->setkey(ssh->sc_mac_ctx, keyspace);
     }
+    {
+       char buf[256];
+       sprintf(buf, "Initialised %.200s client->server encryption",
+               ssh->cscipher->text_name);
+       logevent(buf);
+       sprintf(buf, "Initialised %.200s server->client encryption",
+               ssh->sccipher->text_name);
+       logevent(buf);
+       if (ssh->cscomp->text_name) {
+           sprintf(buf, "Initialised %.200s compression",
+                   ssh->cscomp->text_name);
+           logevent(buf);
+       }
+       if (ssh->sccomp->text_name) {
+           sprintf(buf, "Initialised %.200s decompression",
+                   ssh->sccomp->text_name);
+           logevent(buf);
+       }
+    }
+
 
     /*
      * If this is the first key exchange phase, we must pass the
@@ -4872,7 +4956,8 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                         * bytes we should adjust our string length
                         * by.
                         */
-                       stringlen -= ssh->cscomp->disable_compression();
+                       stringlen -= 
+                           ssh->cscomp->disable_compression(ssh->cs_comp_ctx);
                    }
                    ssh2_pkt_init(ssh, SSH2_MSG_IGNORE);
                    ssh2_pkt_addstring_start(ssh);
@@ -5297,7 +5382,8 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     /*
      * Transfer data!
      */
-    ldisc_send(NULL, 0, 0);           /* cause ldisc to notice changes */
+    if (ssh->ldisc)
+       ldisc_send(ssh->ldisc, NULL, 0, 0);/* cause ldisc to notice changes */
     ssh->send_ok = 1;
     while (1) {
        crReturnV;
@@ -5752,18 +5838,27 @@ static char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh = smalloc(sizeof(*ssh));
     ssh->s = NULL;
     ssh->cipher = NULL;
+    ssh->v1_cipher_ctx = NULL;
+    ssh->crcda_ctx = NULL;
     ssh->cscipher = NULL;
+    ssh->cs_cipher_ctx = NULL;
     ssh->sccipher = NULL;
+    ssh->sc_cipher_ctx = NULL;
     ssh->csmac = NULL;
+    ssh->cs_mac_ctx = NULL;
     ssh->scmac = NULL;
+    ssh->sc_mac_ctx = NULL;
     ssh->cscomp = NULL;
+    ssh->cs_comp_ctx = NULL;
     ssh->sccomp = NULL;
+    ssh->sc_comp_ctx = NULL;
     ssh->kex = NULL;
     ssh->hostkey = NULL;
     ssh->exitcode = -1;
     ssh->state = SSH_STATE_PREPACKET;
     ssh->size_needed = FALSE;
     ssh->eof_needed = FALSE;
+    ssh->ldisc = NULL;
     {
        static const struct Packet empty = { 0, 0, NULL, NULL, 0 };
        ssh->pktin = ssh->pktout = empty;
@@ -6043,6 +6138,12 @@ static int ssh_ldisc(void *handle, int option)
     return FALSE;
 }
 
+static void ssh_provide_ldisc(void *handle, void *ldisc)
+{
+    Ssh ssh = (Ssh) handle;
+    ssh->ldisc = ldisc;
+}
+
 static int ssh_return_exitcode(void *handle)
 {
     Ssh ssh = (Ssh) handle;
@@ -6070,6 +6171,7 @@ Backend ssh_backend = {
     ssh_return_exitcode,
     ssh_sendok,
     ssh_ldisc,
+    ssh_provide_ldisc,
     ssh_unthrottle,
     22
 };