Robert de Bath's Big Patch, part 1
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index a355c29..65f0075 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -81,7 +81,7 @@ static SOCKET s = INVALID_SOCKET;
 static unsigned char session_key[32];
 static struct ssh_cipher *cipher = NULL;
 int scp_flags = 0;
-void (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL;
+int (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL;
 
 static char *savedhost;
 
@@ -128,16 +128,8 @@ static void c_write (char *buf, int len) {
        if (len > 0) { fwrite(buf, len, 1, stderr); fputc('\n', stderr); }
        return;
     }
-    while (len--) {
-       int new_head = (inbuf_head + 1) & INBUF_MASK;
-       if (new_head != inbuf_reap) {
-           inbuf[inbuf_head] = *buf++;
-           inbuf_head = new_head;
-       } else {
-            term_out();
-            if( inbuf_head == inbuf_reap ) len++; else break;
-       }
-    }
+    while (len--) 
+        c_write1(*buf++);
 }
 
 struct Packet {
@@ -267,8 +259,12 @@ next_packet:
 static void ssh_gotdata(unsigned char *data, int datalen)
 {
     while (datalen > 0) {
-       if ( s_rdpkt(&data, &datalen) == 0 )
+       if ( s_rdpkt(&data, &datalen) == 0 ) {
            ssh_protocol(NULL, 0, 1);
+            if (ssh_state == SSH_STATE_CLOSED) {
+                return;
+            }
+        }
     }
 }
 
@@ -781,7 +777,17 @@ static int do_ssh_login(unsigned char *in, int inlen, int ispkt)
        if (IS_SCP) {
            char prompt[200];
            sprintf(prompt, "%s@%s's password: ", cfg.username, savedhost);
-           ssh_get_password(prompt, password, sizeof(password));
+           if (!ssh_get_password(prompt, password, sizeof(password))) {
+                /*
+                 * get_password failed to get a password (for
+                 * example because one was supplied on the command
+                 * line which has already failed to work).
+                 * Terminate.
+                 */
+                logevent("No more passwords to try");
+                ssh_state = SSH_STATE_CLOSED;
+                crReturn(1);
+            }
        } else {
 
         if (pktin.type == SSH_SMSG_FAILURE &&
@@ -845,7 +851,8 @@ static int do_ssh_login(unsigned char *in, int inlen, int ispkt)
            logevent("Authentication refused");
        } else if (pktin.type == SSH_MSG_DISCONNECT) {
            logevent("Received disconnect request");
-           crReturn(0);
+            ssh_state = SSH_STATE_CLOSED;
+           crReturn(1);
        } else if (pktin.type != SSH_SMSG_SUCCESS) {
            fatalbox("Strange packet received, type %d", pktin.type);
        }
@@ -861,8 +868,11 @@ static void ssh_protocol(unsigned char *in, int inlen, int ispkt) {
 
     random_init();
 
-    while (!do_ssh_login(in, inlen, ispkt))
+    while (!do_ssh_login(in, inlen, ispkt)) {
        crReturnV;
+    }
+    if (ssh_state == SSH_STATE_CLOSED)
+        crReturnV;
 
     if (!cfg.nopty) {
        send_packet(SSH_CMSG_REQUEST_PTY,
@@ -982,6 +992,11 @@ static int ssh_msg (WPARAM wParam, LPARAM lParam) {
            return 0;
        }
        ssh_gotdata (buf, ret);
+        if (ssh_state == SSH_STATE_CLOSED) {
+            closesocket(s);
+            s = INVALID_SOCKET;
+            return 0;
+        }
        return 1;
     }
     return 1;                         /* shouldn't happen, but WTF */
@@ -1180,7 +1195,13 @@ char *ssh_scp_init(char *host, int port, char *cmd, char **realhost)
        get_packet();
        if (s == INVALID_SOCKET)
            return "Connection closed by remote host";
-    } while (!do_ssh_login(NULL, 0, 1)); 
+    } while (!do_ssh_login(NULL, 0, 1));
+
+    if (ssh_state == SSH_STATE_CLOSED) {
+        closesocket(s);
+        s = INVALID_SOCKET;
+        return "Session initialisation error";
+    }
 
     /* Execute command */
     sprintf(buf, "Sending command: %.100s", cmd);