Experimental Rlogin support, thanks to Delian Delchev. Local flow
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index e86dff6..f75041b 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -231,7 +231,7 @@ struct Packet {
     long maxlen;
 };
 
-static SHA_State exhash;
+static SHA_State exhash, exhashbase;
 
 static Socket s = NULL;
 
@@ -257,13 +257,14 @@ static tree234 *ssh_channels;           /* indexed by local id */
 static struct ssh_channel *mainchan;   /* primary session channel */
 
 static enum {
+    SSH_STATE_PREPACKET,
     SSH_STATE_BEFORE_SIZE,
     SSH_STATE_INTERMED,
     SSH_STATE_SESSION,
     SSH_STATE_CLOSED
-} ssh_state = SSH_STATE_BEFORE_SIZE;
+} ssh_state = SSH_STATE_PREPACKET;
 
-static int size_needed = FALSE;
+static int size_needed = FALSE, eof_needed = FALSE;
 
 static struct Packet pktin = { 0, 0, NULL, NULL, 0 };
 static struct Packet pktout = { 0, 0, NULL, NULL, 0 };
@@ -273,6 +274,7 @@ static void (*ssh_protocol)(unsigned char *in, int inlen, int ispkt);
 static void ssh1_protocol(unsigned char *in, int inlen, int ispkt);
 static void ssh2_protocol(unsigned char *in, int inlen, int ispkt);
 static void ssh_size(void);
+static void ssh_special (Telnet_Special);
 
 static int (*s_rdpkt)(unsigned char **data, int *datalen);
 
@@ -410,6 +412,7 @@ next_packet:
        if (pktin.maxlen < st->pad + decomplen) {
            pktin.maxlen = st->pad + decomplen;
            pktin.data = srealloc(pktin.data, pktin.maxlen+APIEXTRA);
+            pktin.body = pktin.data + st->pad + 1;
            if (!pktin.data)
                fatalbox("Out of memory");
        }
@@ -1042,12 +1045,12 @@ static int do_ssh_init(unsigned char c) {
          * This is a v2 server. Begin v2 protocol.
          */
         char *verstring = "SSH-2.0-PuTTY";
-        SHA_Init(&exhash);
+        SHA_Init(&exhashbase);
         /*
          * Hash our version string and their version string.
          */
-        sha_string(&exhash, verstring, strlen(verstring));
-        sha_string(&exhash, vstring, strcspn(vstring, "\r\n"));
+        sha_string(&exhashbase, verstring, strlen(verstring));
+        sha_string(&exhashbase, vstring, strcspn(vstring, "\r\n"));
         sprintf(vstring, "%s\n", verstring);
         sprintf(vlog, "We claim version: %s", verstring);
         logevent(vlog);
@@ -1071,6 +1074,7 @@ static int do_ssh_init(unsigned char c) {
         ssh_version = 1;
         s_rdpkt = ssh1_rdpkt;
     }
+    ssh_state = SSH_STATE_BEFORE_SIZE;
 
     crFinish(0);
 }
@@ -1120,6 +1124,7 @@ static void ssh_gotdata(unsigned char *data, int datalen)
 static int ssh_receive(Socket skt, int urgent, char *data, int len) {
     if (!len) {
        /* Connection has closed. */
+       ssh_state = SSH_STATE_CLOSED;
        sk_close(s);
        s = NULL;
        return 0;
@@ -1179,7 +1184,7 @@ static char *connect_to_host(char *host, int port, char **realhost)
     /*
      * Open socket.
      */
-    s = sk_new(addr, port, ssh_receive);
+    s = sk_new(addr, port, 0, ssh_receive);
     if ( (err = sk_socket_error(s)) )
        return err;
 
@@ -1782,6 +1787,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
     ssh_state = SSH_STATE_SESSION;
     if (size_needed)
        ssh_size();
+    if (eof_needed)
+        ssh_special(TS_EOF);
 
     ssh_send_ok = 1;
     ssh_channels = newtree234(ssh_channelcmp);
@@ -1992,9 +1999,11 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     static unsigned char keyspace[40];
     static const struct ssh_cipher *preferred_cipher;
     static const struct ssh_compress *preferred_comp;
+    static int first_kex;
 
     crBegin;
     random_init();
+    first_kex = 1;
 
     /*
      * Set up the preferred cipher and compression.
@@ -2098,7 +2107,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     ssh2_pkt_addbool(FALSE);
     /* Reserved. */
     ssh2_pkt_adduint32(0);
+
+    exhash = exhashbase;
     sha_string(&exhash, pktout.data+5, pktout.length-5);
+
     ssh2_pkt_send();
 
     if (!ispkt) crWaitUntil(ispkt);
@@ -2267,13 +2279,26 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
      * Set IVs after keys.
      */
     ssh2_mkkey(K, exchange_hash, 'C', keyspace); cscipher->setcskey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'D', keyspace); cscipher->setsckey(keyspace);
+    ssh2_mkkey(K, exchange_hash, 'D', keyspace); sccipher->setsckey(keyspace);
     ssh2_mkkey(K, exchange_hash, 'A', keyspace); cscipher->setcsiv(keyspace);
     ssh2_mkkey(K, exchange_hash, 'B', keyspace); sccipher->setsciv(keyspace);
     ssh2_mkkey(K, exchange_hash, 'E', keyspace); csmac->setcskey(keyspace);
     ssh2_mkkey(K, exchange_hash, 'F', keyspace); scmac->setsckey(keyspace);
 
     /*
+     * If this is the first key exchange phase, we must pass the
+     * SSH2_MSG_NEWKEYS packet to the next layer, not because it
+     * wants to see it but because it will need time to initialise
+     * itself before it sees an actual packet. In subsequent key
+     * exchange phases, we don't pass SSH2_MSG_NEWKEYS on, because
+     * it would only confuse the layer above.
+     */
+    if (!first_kex) {
+        crReturn(0);
+    }
+    first_kex = 0;
+
+    /*
      * Now we're encrypting. Begin returning 1 to the protocol main
      * function so that other things can run on top of the
      * transport. If we ever see a KEXINIT, we must go back to the
@@ -2543,6 +2568,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
     ssh_state = SSH_STATE_SESSION;
     if (size_needed)
        ssh_size();
+    if (eof_needed)
+        ssh_special(TS_EOF);
 
     /*
      * Transfer data!
@@ -2707,6 +2734,7 @@ static void ssh_send (char *buf, int len) {
 static void ssh_size(void) {
     switch (ssh_state) {
       case SSH_STATE_BEFORE_SIZE:
+      case SSH_STATE_PREPACKET:
       case SSH_STATE_CLOSED:
        break;                         /* do nothing */
       case SSH_STATE_INTERMED:
@@ -2730,6 +2758,7 @@ static void ssh_size(void) {
                 ssh2_pkt_send();
             }
         }
+        break;
     }
 }
 
@@ -2740,6 +2769,15 @@ static void ssh_size(void) {
  */
 static void ssh_special (Telnet_Special code) {
     if (code == TS_EOF) {
+        if (ssh_state != SSH_STATE_SESSION) {
+            /*
+             * Buffer the EOF in case we are pre-SESSION, so we can
+             * send it as soon as we reach SESSION.
+             */
+            if (code == TS_EOF)
+                eof_needed = TRUE;
+            return;
+        }
         if (ssh_version == 1) {
             send_packet(SSH1_CMSG_EOF, PKT_END);
         } else {
@@ -2749,6 +2787,8 @@ static void ssh_special (Telnet_Special code) {
         }
         logevent("Sent EOF message");
     } else if (code == TS_PING) {
+        if (ssh_state == SSH_STATE_CLOSED || ssh_state == SSH_STATE_PREPACKET)
+            return;
         if (ssh_version == 1) {
             send_packet(SSH1_MSG_IGNORE, PKT_STR, "", PKT_END);
         } else {