X-Git-Url: https://git.distorted.org.uk/u/mdw/putty/blobdiff_plain/db7d555c2c61c57c869329c34ba3bcb5d4f2f0db..c91409da0ac0d3fb4a225ab85e14370514e4094e:/ssh.c diff --git a/ssh.c b/ssh.c index e7518b6a..f75041ba 100644 --- a/ssh.c +++ b/ssh.c @@ -171,8 +171,8 @@ const static struct ssh_cipher *ciphers[] = { &ssh_blowfish_ssh2, &ssh_3des_ssh2 extern const struct ssh_kex ssh_diffiehellman; const static struct ssh_kex *kex_algs[] = { &ssh_diffiehellman }; -extern const struct ssh_hostkey ssh_dss; -const static struct ssh_hostkey *hostkey_algs[] = { &ssh_dss }; +extern const struct ssh_signkey ssh_dss; +const static struct ssh_signkey *hostkey_algs[] = { &ssh_dss }; extern const struct ssh_mac ssh_md5, ssh_sha1, ssh_sha1_buggy; @@ -231,7 +231,7 @@ struct Packet { long maxlen; }; -static SHA_State exhash; +static SHA_State exhash, exhashbase; static Socket s = NULL; @@ -246,7 +246,7 @@ static const struct ssh_mac *scmac = NULL; static const struct ssh_compress *cscomp = NULL; static const struct ssh_compress *sccomp = NULL; static const struct ssh_kex *kex = NULL; -static const struct ssh_hostkey *hostkey = NULL; +static const struct ssh_signkey *hostkey = NULL; int (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL; static char *savedhost; @@ -257,13 +257,14 @@ static tree234 *ssh_channels; /* indexed by local id */ static struct ssh_channel *mainchan; /* primary session channel */ static enum { + SSH_STATE_PREPACKET, SSH_STATE_BEFORE_SIZE, SSH_STATE_INTERMED, SSH_STATE_SESSION, SSH_STATE_CLOSED -} ssh_state = SSH_STATE_BEFORE_SIZE; +} ssh_state = SSH_STATE_PREPACKET; -static int size_needed = FALSE; +static int size_needed = FALSE, eof_needed = FALSE; static struct Packet pktin = { 0, 0, NULL, NULL, 0 }; static struct Packet pktout = { 0, 0, NULL, NULL, 0 }; @@ -273,6 +274,7 @@ static void (*ssh_protocol)(unsigned char *in, int inlen, int ispkt); static void ssh1_protocol(unsigned char *in, int inlen, int ispkt); static void ssh2_protocol(unsigned char *in, int inlen, int ispkt); static void ssh_size(void); +static void ssh_special (Telnet_Special); static int (*s_rdpkt)(unsigned char **data, int *datalen); @@ -355,8 +357,8 @@ next_packet: if (pktin.maxlen < st->biglen) { pktin.maxlen = st->biglen; - pktin.data = (pktin.data == NULL ? malloc(st->biglen+APIEXTRA) : - realloc(pktin.data, st->biglen+APIEXTRA)); + pktin.data = (pktin.data == NULL ? smalloc(st->biglen+APIEXTRA) : + srealloc(pktin.data, st->biglen+APIEXTRA)); if (!pktin.data) fatalbox("Out of memory"); } @@ -409,13 +411,14 @@ next_packet: if (pktin.maxlen < st->pad + decomplen) { pktin.maxlen = st->pad + decomplen; - pktin.data = realloc(pktin.data, pktin.maxlen+APIEXTRA); + pktin.data = srealloc(pktin.data, pktin.maxlen+APIEXTRA); + pktin.body = pktin.data + st->pad + 1; if (!pktin.data) fatalbox("Out of memory"); } memcpy(pktin.body-1, decompblk, decomplen); - free(decompblk); + sfree(decompblk); pktin.length = decomplen-1; #if 0 debug(("Packet payload post-decompression:\n")); @@ -475,8 +478,8 @@ next_packet: if (pktin.maxlen < st->cipherblk) { pktin.maxlen = st->cipherblk; - pktin.data = (pktin.data == NULL ? malloc(st->cipherblk+APIEXTRA) : - realloc(pktin.data, st->cipherblk+APIEXTRA)); + pktin.data = (pktin.data == NULL ? smalloc(st->cipherblk+APIEXTRA) : + srealloc(pktin.data, st->cipherblk+APIEXTRA)); if (!pktin.data) fatalbox("Out of memory"); } @@ -523,8 +526,8 @@ next_packet: */ if (pktin.maxlen < st->packetlen+st->maclen) { pktin.maxlen = st->packetlen+st->maclen; - pktin.data = (pktin.data == NULL ? malloc(pktin.maxlen+APIEXTRA) : - realloc(pktin.data, pktin.maxlen+APIEXTRA)); + pktin.data = (pktin.data == NULL ? smalloc(pktin.maxlen+APIEXTRA) : + srealloc(pktin.data, pktin.maxlen+APIEXTRA)); if (!pktin.data) fatalbox("Out of memory"); } @@ -569,8 +572,8 @@ next_packet: &newpayload, &newlen)) { if (pktin.maxlen < newlen+5) { pktin.maxlen = newlen+5; - pktin.data = (pktin.data == NULL ? malloc(pktin.maxlen+APIEXTRA) : - realloc(pktin.data, pktin.maxlen+APIEXTRA)); + pktin.data = (pktin.data == NULL ? smalloc(pktin.maxlen+APIEXTRA) : + srealloc(pktin.data, pktin.maxlen+APIEXTRA)); if (!pktin.data) fatalbox("Out of memory"); } @@ -583,7 +586,7 @@ next_packet: debug(("\r\n")); #endif - free(newpayload); + sfree(newpayload); } } @@ -609,11 +612,11 @@ static void ssh1_pktout_size(int len) { #ifdef MSCRYPTOAPI /* Allocate enough buffer space for extra block * for MS CryptEncrypt() */ - pktout.data = (pktout.data == NULL ? malloc(biglen+12) : - realloc(pktout.data, biglen+12)); + pktout.data = (pktout.data == NULL ? smalloc(biglen+12) : + srealloc(pktout.data, biglen+12)); #else - pktout.data = (pktout.data == NULL ? malloc(biglen+4) : - realloc(pktout.data, biglen+4)); + pktout.data = (pktout.data == NULL ? smalloc(biglen+4) : + srealloc(pktout.data, biglen+4)); #endif if (!pktout.data) fatalbox("Out of memory"); @@ -645,7 +648,7 @@ static void s_wrpkt(void) { &compblk, &complen); ssh1_pktout_size(complen-1); memcpy(pktout.body-1, compblk, complen); - free(compblk); + sfree(compblk); #if 0 debug(("Packet payload post-compression:\n")); for (i = -1; i < pktout.length; i++) @@ -795,8 +798,8 @@ static void ssh2_pkt_adddata(void *data, int len) { pktout.length += len; if (pktout.maxlen < pktout.length) { pktout.maxlen = pktout.length + 256; - pktout.data = (pktout.data == NULL ? malloc(pktout.maxlen+APIEXTRA) : - realloc(pktout.data, pktout.maxlen+APIEXTRA)); + pktout.data = (pktout.data == NULL ? smalloc(pktout.maxlen+APIEXTRA) : + srealloc(pktout.data, pktout.maxlen+APIEXTRA)); if (!pktout.data) fatalbox("Out of memory"); } @@ -838,7 +841,7 @@ static void ssh2_pkt_addstring(char *data) { static char *ssh2_mpint_fmt(Bignum b, int *len) { unsigned char *p; int i, n = b[0]; - p = malloc(n * 2 + 1); + p = smalloc(n * 2 + 1); if (!p) fatalbox("out of memory"); p[0] = 0; @@ -859,7 +862,7 @@ static void ssh2_pkt_addmp(Bignum b) { p = ssh2_mpint_fmt(b, &len); ssh2_pkt_addstring_start(); ssh2_pkt_addstring_data(p, len); - free(p); + sfree(p); } static void ssh2_pkt_send(void) { int cipherblk, maclen, padding, i; @@ -881,7 +884,7 @@ static void ssh2_pkt_send(void) { &newpayload, &newlen)) { pktout.length = 5; ssh2_pkt_adddata(newpayload, newlen); - free(newpayload); + sfree(newpayload); } } @@ -925,7 +928,7 @@ void bndebug(char *string, Bignum b) { for (i = 0; i < len; i++) debug((" %02x", p[i])); debug(("\r\n")); - free(p); + sfree(p); } #endif @@ -934,7 +937,7 @@ static void sha_mpint(SHA_State *s, Bignum b) { int len; p = ssh2_mpint_fmt(b, &len); sha_string(s, p, len); - free(p); + sfree(p); } /* @@ -1042,12 +1045,12 @@ static int do_ssh_init(unsigned char c) { * This is a v2 server. Begin v2 protocol. */ char *verstring = "SSH-2.0-PuTTY"; - SHA_Init(&exhash); + SHA_Init(&exhashbase); /* * Hash our version string and their version string. */ - sha_string(&exhash, verstring, strlen(verstring)); - sha_string(&exhash, vstring, strcspn(vstring, "\r\n")); + sha_string(&exhashbase, verstring, strlen(verstring)); + sha_string(&exhashbase, vstring, strcspn(vstring, "\r\n")); sprintf(vstring, "%s\n", verstring); sprintf(vlog, "We claim version: %s", verstring); logevent(vlog); @@ -1071,6 +1074,7 @@ static int do_ssh_init(unsigned char c) { ssh_version = 1; s_rdpkt = ssh1_rdpkt; } + ssh_state = SSH_STATE_BEFORE_SIZE; crFinish(0); } @@ -1120,6 +1124,7 @@ static void ssh_gotdata(unsigned char *data, int datalen) static int ssh_receive(Socket skt, int urgent, char *data, int len) { if (!len) { /* Connection has closed. */ + ssh_state = SSH_STATE_CLOSED; sk_close(s); s = NULL; return 0; @@ -1149,7 +1154,7 @@ static char *connect_to_host(char *host, int port, char **realhost) int FWport; #endif - savedhost = malloc(1+strlen(host)); + savedhost = smalloc(1+strlen(host)); if (!savedhost) fatalbox("Out of memory"); strcpy(savedhost, host); @@ -1179,7 +1184,7 @@ static char *connect_to_host(char *host, int port, char **realhost) /* * Open socket. */ - s = sk_new(addr, port, ssh_receive); + s = sk_new(addr, port, 0, ssh_receive); if ( (err = sk_socket_error(s)) ) return err; @@ -1255,7 +1260,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) len = (hostkey.bytes > servkey.bytes ? hostkey.bytes : servkey.bytes); - rsabuf = malloc(len); + rsabuf = smalloc(len); if (!rsabuf) fatalbox("Out of memory"); @@ -1268,13 +1273,13 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) */ int len = rsastr_len(&hostkey); char fingerprint[100]; - char *keystr = malloc(len); + char *keystr = smalloc(len); if (!keystr) fatalbox("Out of memory"); rsastr_fmt(keystr, &hostkey); rsa_fingerprint(fingerprint, sizeof(fingerprint), &hostkey); verify_ssh_host_key(savedhost, savedport, "rsa", keystr, fingerprint); - free(keystr); + sfree(keystr); } for (i=0; i<32; i++) { @@ -1316,7 +1321,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) logevent("Trying to enable encryption..."); - free(rsabuf); + sfree(rsabuf); cipher = cipher_type == SSH_CIPHER_BLOWFISH ? &ssh_blowfish_ssh1 : cipher_type == SSH_CIPHER_DES ? &ssh_des : @@ -1460,7 +1465,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) len += ssh1_bignum_length(challenge); len += 16; /* session id */ len += 4; /* response format */ - agentreq = malloc(4 + len); + agentreq = smalloc(4 + len); PUT_32BIT(agentreq, len); q = agentreq + 4; *q++ = SSH_AGENTC_RSA_CHALLENGE; @@ -1472,13 +1477,13 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) memcpy(q, session_id, 16); q += 16; PUT_32BIT(q, 1); /* response format */ agent_query(agentreq, len+4, &ret, &retlen); - free(agentreq); + sfree(agentreq); if (ret) { if (ret[4] == SSH_AGENT_RSA_RESPONSE) { logevent("Sending Pageant's response"); send_packet(SSH1_CMSG_AUTH_RSA_RESPONSE, PKT_DATA, ret+5, 16, PKT_END); - free(ret); + sfree(ret); crWaitUntil(ispkt); if (pktin.type == SSH1_SMSG_SUCCESS) { logevent("Pageant's response accepted"); @@ -1493,7 +1498,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) logevent("Pageant's response not accepted"); } else { logevent("Pageant failed to answer challenge"); - free(ret); + sfree(ret); } } else { logevent("No reply received from Pageant"); @@ -1573,7 +1578,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt) goto tryauth; } sprintf(prompt, "Passphrase for key \"%.100s\": ", comment); - free(comment); + sfree(comment); } if (ssh_get_password) { @@ -1782,6 +1787,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) { ssh_state = SSH_STATE_SESSION; if (size_needed) ssh_size(); + if (eof_needed) + ssh_special(TS_EOF); ssh_send_ok = 1; ssh_channels = newtree234(ssh_channelcmp); @@ -1817,7 +1824,7 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) { break; /* found a free number */ i = c->localid + 1; } - c = malloc(sizeof(struct ssh_channel)); + c = smalloc(sizeof(struct ssh_channel)); c->remoteid = GET_32BIT(pktin.body); c->localid = i; c->closes = 0; @@ -1841,7 +1848,7 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) { c->closes |= closetype; if (c->closes == 3) { del234(ssh_channels, c); - free(c); + sfree(c); } } } else if (pktin.type == SSH1_MSG_CHANNEL_DATA) { @@ -1863,7 +1870,7 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) { } if (c->u.a.lensofar == 4) { c->u.a.totallen = 4 + GET_32BIT(c->u.a.msglen); - c->u.a.message = malloc(c->u.a.totallen); + c->u.a.message = smalloc(c->u.a.totallen); memcpy(c->u.a.message, c->u.a.msglen, 4); } if (c->u.a.lensofar >= 4 && len > 0) { @@ -1889,8 +1896,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) { PKT_DATA, sentreply, replylen, PKT_END); if (reply) - free(reply); - free(c->u.a.message); + sfree(reply); + sfree(c->u.a.message); c->u.a.lensofar = 0; } } @@ -1987,13 +1994,16 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt) static const struct ssh_compress *sccomp_tobe = NULL; static char *hostkeydata, *sigdata, *keystr, *fingerprint; static int hostkeylen, siglen; + static void *hkey; /* actual host key */ static unsigned char exchange_hash[20]; static unsigned char keyspace[40]; static const struct ssh_cipher *preferred_cipher; static const struct ssh_compress *preferred_comp; + static int first_kex; crBegin; random_init(); + first_kex = 1; /* * Set up the preferred cipher and compression. @@ -2097,7 +2107,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt) ssh2_pkt_addbool(FALSE); /* Reserved. */ ssh2_pkt_adduint32(0); + + exhash = exhashbase; sha_string(&exhash, pktout.data+5, pktout.length-5); + ssh2_pkt_send(); if (!ispkt) crWaitUntil(ispkt); @@ -2216,8 +2229,8 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt) debug(("\r\n")); #endif - hostkey->setkey(hostkeydata, hostkeylen); - if (!hostkey->verifysig(sigdata, siglen, exchange_hash, 20)) { + hkey = hostkey->newkey(hostkeydata, hostkeylen); + if (!hostkey->verifysig(hkey, sigdata, siglen, exchange_hash, 20)) { bombout(("Server failed host key check")); crReturn(0); } @@ -2235,14 +2248,15 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt) * Authenticate remote host: verify host key. (We've already * checked the signature of the exchange hash.) */ - keystr = hostkey->fmtkey(); - fingerprint = hostkey->fingerprint(); + keystr = hostkey->fmtkey(hkey); + fingerprint = hostkey->fingerprint(hkey); verify_ssh_host_key(savedhost, savedport, hostkey->keytype, keystr, fingerprint); logevent("Host key fingerprint is:"); logevent(fingerprint); - free(fingerprint); - free(keystr); + sfree(fingerprint); + sfree(keystr); + hostkey->freekey(hkey); /* * Send SSH2_MSG_NEWKEYS. @@ -2265,13 +2279,26 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt) * Set IVs after keys. */ ssh2_mkkey(K, exchange_hash, 'C', keyspace); cscipher->setcskey(keyspace); - ssh2_mkkey(K, exchange_hash, 'D', keyspace); cscipher->setsckey(keyspace); + ssh2_mkkey(K, exchange_hash, 'D', keyspace); sccipher->setsckey(keyspace); ssh2_mkkey(K, exchange_hash, 'A', keyspace); cscipher->setcsiv(keyspace); ssh2_mkkey(K, exchange_hash, 'B', keyspace); sccipher->setsciv(keyspace); ssh2_mkkey(K, exchange_hash, 'E', keyspace); csmac->setcskey(keyspace); ssh2_mkkey(K, exchange_hash, 'F', keyspace); scmac->setsckey(keyspace); /* + * If this is the first key exchange phase, we must pass the + * SSH2_MSG_NEWKEYS packet to the next layer, not because it + * wants to see it but because it will need time to initialise + * itself before it sees an actual packet. In subsequent key + * exchange phases, we don't pass SSH2_MSG_NEWKEYS on, because + * it would only confuse the layer above. + */ + if (!first_kex) { + crReturn(0); + } + first_kex = 0; + + /* * Now we're encrypting. Begin returning 1 to the protocol main * function so that other things can run on top of the * transport. If we ever see a KEXINIT, we must go back to the @@ -2440,7 +2467,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt) /* * So now create a channel with a session in it. */ - mainchan = malloc(sizeof(struct ssh_channel)); + mainchan = smalloc(sizeof(struct ssh_channel)); mainchan->localid = 100; /* as good as any */ ssh2_pkt_init(SSH2_MSG_CHANNEL_OPEN); ssh2_pkt_addstring("session"); @@ -2541,6 +2568,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt) ssh_state = SSH_STATE_SESSION; if (size_needed) ssh_size(); + if (eof_needed) + ssh_special(TS_EOF); /* * Transfer data! @@ -2705,6 +2734,7 @@ static void ssh_send (char *buf, int len) { static void ssh_size(void) { switch (ssh_state) { case SSH_STATE_BEFORE_SIZE: + case SSH_STATE_PREPACKET: case SSH_STATE_CLOSED: break; /* do nothing */ case SSH_STATE_INTERMED: @@ -2728,6 +2758,7 @@ static void ssh_size(void) { ssh2_pkt_send(); } } + break; } } @@ -2738,6 +2769,15 @@ static void ssh_size(void) { */ static void ssh_special (Telnet_Special code) { if (code == TS_EOF) { + if (ssh_state != SSH_STATE_SESSION) { + /* + * Buffer the EOF in case we are pre-SESSION, so we can + * send it as soon as we reach SESSION. + */ + if (code == TS_EOF) + eof_needed = TRUE; + return; + } if (ssh_version == 1) { send_packet(SSH1_CMSG_EOF, PKT_END); } else { @@ -2747,6 +2787,8 @@ static void ssh_special (Telnet_Special code) { } logevent("Sent EOF message"); } else if (code == TS_PING) { + if (ssh_state == SSH_STATE_CLOSED || ssh_state == SSH_STATE_PREPACKET) + return; if (ssh_version == 1) { send_packet(SSH1_MSG_IGNORE, PKT_STR, "", PKT_END); } else {