Rethink the whole line discipline architecture. Instead of having
[u/mdw/putty] / plink.c
diff --git a/plink.c b/plink.c
index 938bcb7..6cbc2f6 100644 (file)
--- a/plink.c
+++ b/plink.c
@@ -7,15 +7,17 @@
 #endif
 #include <windows.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <stdarg.h>
 
 #define PUTTY_DO_GLOBALS                      /* actually _define_ globals */
 #include "putty.h"
 #include "storage.h"
+#include "tree234.h"
 
 void fatalbox (char *p, ...) {
     va_list ap;
-    fprintf(stderr, "FATAL ERROR: ", p);
+    fprintf(stderr, "FATAL ERROR: ");
     va_start(ap, p);
     vfprintf(stderr, p, ap);
     va_end(ap);
@@ -25,7 +27,7 @@ void fatalbox (char *p, ...) {
 }
 void connection_fatal (char *p, ...) {
     va_list ap;
-    fprintf(stderr, "FATAL ERROR: ", p);
+    fprintf(stderr, "FATAL ERROR: ");
     va_start(ap, p);
     vfprintf(stderr, p, ap);
     va_end(ap);
@@ -41,6 +43,8 @@ void logevent(char *string) { }
 void verify_ssh_host_key(char *host, int port, char *keytype,
                          char *keystr, char *fingerprint) {
     int ret;
+    HANDLE hin;
+    DWORD savemode, i;
 
     static const char absentmsg[] =
         "The server's host key is not cached in the registry. You\n"
@@ -83,10 +87,21 @@ void verify_ssh_host_key(char *host, int port, char *keytype,
 
     if (ret == 0)                      /* success - key matched OK */
         return;
-    if (ret == 2) {                    /* key was different */
+
+    if (ret == 2)                      /* key was different */
         fprintf(stderr, wrongmsg, fingerprint);
-        if (fgets(line, sizeof(line), stdin) &&
-            line[0] != '\0' && line[0] != '\n') {
+    if (ret == 1)                      /* key was absent */
+        fprintf(stderr, absentmsg, fingerprint);
+
+    hin = GetStdHandle(STD_INPUT_HANDLE);
+    GetConsoleMode(hin, &savemode);
+    SetConsoleMode(hin, (savemode | ENABLE_ECHO_INPUT |
+                         ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT));
+    ReadFile(hin, line, sizeof(line)-1, &i, NULL);
+    SetConsoleMode(hin, savemode);
+
+    if (ret == 2) {                    /* key was different */
+        if (line[0] != '\0' && line[0] != '\r' && line[0] != '\n') {
             if (line[0] == 'y' || line[0] == 'Y')
                 store_host_key(host, port, keytype, keystr);
         } else {
@@ -95,9 +110,7 @@ void verify_ssh_host_key(char *host, int port, char *keytype,
         }
     }
     if (ret == 1) {                    /* key was absent */
-        fprintf(stderr, absentmsg, fingerprint);
-        if (fgets(line, sizeof(line), stdin) &&
-            (line[0] == 'y' || line[0] == 'Y'))
+        if (line[0] == 'y' || line[0] == 'Y')
             store_host_key(host, port, keytype, keystr);
         else {
             fprintf(stderr, abandoned);
@@ -106,33 +119,45 @@ void verify_ssh_host_key(char *host, int port, char *keytype,
     }
 }
 
-HANDLE outhandle;
+HANDLE inhandle, outhandle, errhandle;
 DWORD orig_console_mode;
 
-void begin_session(void) {
-    if (!cfg.ldisc_term)
-        SetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), ENABLE_PROCESSED_INPUT);
-    else
-        SetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), orig_console_mode);
-}
+WSAEVENT netevent;
 
-void term_out(void)
-{
-    int reap;
+void from_backend(int is_stderr, char *data, int len) {
+    int pos;
     DWORD ret;
-    reap = 0;
-    while (reap < inbuf_head) {
-        if (!WriteFile(outhandle, inbuf+reap, inbuf_head-reap, &ret, NULL))
+    HANDLE h = (is_stderr ? errhandle : outhandle);
+
+    pos = 0;
+    while (pos < len) {
+        if (!WriteFile(h, data+pos, len-pos, &ret, NULL))
             return;                    /* give up in panic */
-        reap += ret;
+        pos += ret;
     }
-    inbuf_head = 0;
+}
+
+int term_ldisc(int mode) { return FALSE; }
+void ldisc_update(int echo, int edit) {
+    /* Update stdin read mode to reflect changes in line discipline. */
+    DWORD mode;
+
+    mode = ENABLE_PROCESSED_INPUT;
+    if (echo)
+        mode = mode | ENABLE_ECHO_INPUT;
+    else
+        mode = mode &~ ENABLE_ECHO_INPUT;
+    if (edit)
+        mode = mode | ENABLE_LINE_INPUT;
+    else
+        mode = mode &~ ENABLE_LINE_INPUT;
+    SetConsoleMode(inhandle, mode);
 }
 
 struct input_data {
     DWORD len;
     char buffer[4096];
-    HANDLE event;
+    HANDLE event, eventback;
 };
 
 static int get_password(const char *prompt, char *str, int maxlen)
@@ -184,8 +209,9 @@ static DWORD WINAPI stdin_read_thread(void *param) {
     inhandle = GetStdHandle(STD_INPUT_HANDLE);
 
     while (ReadFile(inhandle, idata->buffer, sizeof(idata->buffer),
-                    &idata->len, NULL)) {
+                    &idata->len, NULL) && idata->len > 0) {
         SetEvent(idata->event);
+        WaitForSingleObject(idata->eventback, INFINITE);
     }
 
     idata->len = 0;
@@ -210,19 +236,39 @@ static void usage(void)
     exit(1);
 }
 
+char *do_select(SOCKET skt, int startup) {
+    int events;
+    if (startup) {
+       events = FD_READ | FD_WRITE | FD_OOB | FD_CLOSE;
+    } else {
+       events = 0;
+    }
+    if (WSAEventSelect (skt, netevent, events) == SOCKET_ERROR) {
+        switch (WSAGetLastError()) {
+          case WSAENETDOWN: return "Network is down";
+          default: return "WSAAsyncSelect(): unknown error";
+        }
+    }
+    return NULL;
+}
+
 int main(int argc, char **argv) {
     WSADATA wsadata;
     WORD winsock_ver;
-    WSAEVENT netevent, stdinevent;
+    WSAEVENT stdinevent;
     HANDLE handles[2];
-    SOCKET socket;
     DWORD threadid;
     struct input_data idata;
     int sending;
     int portnumber = -1;
+    SOCKET *sklist;
+    int skcount, sksize;
+    int connopen;
 
     ssh_get_password = get_password;
 
+    sklist = NULL; skcount = sksize = 0;
+
     flags = FLAG_STDERR;
     /*
      * Process the command line.
@@ -335,13 +381,15 @@ int main(int argc, char **argv) {
                         /*
                          * One string.
                          */
-                        do_defaults (p, &cfg);
-                        if (cfg.host[0] == '\0') {
+                        Config cfg2;
+                        do_defaults (p, &cfg2);
+                        if (cfg2.host[0] == '\0') {
                             /* No settings for this host; use defaults */
                             strncpy(cfg.host, p, sizeof(cfg.host)-1);
                             cfg.host[sizeof(cfg.host)-1] = '\0';
                             cfg.port = 22;
-                        }
+                        } else
+                            cfg = cfg2;
                     } else {
                         *r++ = '\0';
                         strncpy(cfg.username, p, sizeof(cfg.username)-1);
@@ -365,7 +413,6 @@ int main(int argc, char **argv) {
                     len2 = strlen(cp); len -= len2; cp += len2;
                 }
                 cfg.nopty = TRUE;      /* command => no terminal */
-                cfg.ldisc_term = TRUE; /* use stdin like a line buffer */
                 break;                 /* done with cmdline */
             }
        }
@@ -417,38 +464,37 @@ int main(int argc, char **argv) {
        WSACleanup();
        return 1;
     }
+    sk_init();
 
     /*
      * Start up the connection.
      */
+    netevent = CreateEvent(NULL, FALSE, FALSE, NULL);
     {
        char *error;
        char *realhost;
 
-       error = back->init (NULL, cfg.host, cfg.port, &realhost);
+       error = back->init (cfg.host, cfg.port, &realhost);
        if (error) {
            fprintf(stderr, "Unable to open connection:\n%s", error);
            return 1;
        }
     }
+    connopen = 1;
 
-    netevent = CreateEvent(NULL, FALSE, FALSE, NULL);
     stdinevent = CreateEvent(NULL, FALSE, FALSE, NULL);
 
-    GetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), &orig_console_mode);
-    SetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), ENABLE_PROCESSED_INPUT);
+    inhandle = GetStdHandle(STD_INPUT_HANDLE);
     outhandle = GetStdHandle(STD_OUTPUT_HANDLE);
+    errhandle = GetStdHandle(STD_ERROR_HANDLE);
+    GetConsoleMode(inhandle, &orig_console_mode);
+    SetConsoleMode(inhandle, ENABLE_PROCESSED_INPUT);
 
     /*
-     * Now we must send the back end oodles of stuff.
-     */
-    socket = back->socket();
-    /*
      * Turn off ECHO and LINE input modes. We don't care if this
      * call fails, because we know we aren't necessarily running in
      * a console.
      */
-    WSAEventSelect(socket, netevent, FD_READ | FD_CLOSE);
     handles[0] = netevent;
     handles[1] = stdinevent;
     sending = FALSE;
@@ -473,6 +519,7 @@ int main(int argc, char **argv) {
              *    - so we're back to ReadFile blocking.
              */
             idata.event = stdinevent;
+            idata.eventback = CreateEvent(NULL, FALSE, FALSE, NULL);
             if (!CreateThread(NULL, 0, stdin_read_thread,
                               &idata, 0, &threadid)) {
                 fprintf(stderr, "Unable to create second thread\n");
@@ -484,22 +531,66 @@ int main(int argc, char **argv) {
         n = WaitForMultipleObjects(2, handles, FALSE, INFINITE);
         if (n == 0) {
             WSANETWORKEVENTS things;
-            if (!WSAEnumNetworkEvents(socket, netevent, &things)) {
-                if (things.lNetworkEvents & FD_READ)
-                    back->msg(0, FD_READ);
-                if (things.lNetworkEvents & FD_CLOSE) {
-                    back->msg(0, FD_CLOSE);
-                    break;
-                }
+           enum234 e;
+           SOCKET socket;
+           extern SOCKET first_socket(enum234 *), next_socket(enum234 *);
+           extern int select_result(WPARAM, LPARAM);
+            int i;
+
+            /*
+             * We must not call select_result() for any socket
+             * until we have finished enumerating within the tree.
+             * This is because select_result() may close the socket
+             * and modify the tree.
+             */
+            /* Count the active sockets. */
+            i = 0;
+            for (socket = first_socket(&e); socket != INVALID_SOCKET;
+                socket = next_socket(&e))
+                i++;
+
+            /* Expand the buffer if necessary. */
+            if (i > sksize) {
+                sksize = i+16;
+                sklist = srealloc(sklist, sksize * sizeof(*sklist));
             }
-            term_out();
+
+            /* Retrieve the sockets into sklist. */
+            skcount = 0;
+           for (socket = first_socket(&e); socket != INVALID_SOCKET;
+                socket = next_socket(&e)) {
+                sklist[skcount++] = socket;
+            }
+
+            /* Now we're done enumerating; go through the list. */
+            for (i = 0; i < skcount; i++) {
+                WPARAM wp;
+                socket = sklist[i];
+                wp = (WPARAM)socket;
+               if (!WSAEnumNetworkEvents(socket, NULL, &things)) {
+                    noise_ultralight(socket);
+                    noise_ultralight(things.lNetworkEvents);
+                   if (things.lNetworkEvents & FD_READ)
+                       connopen &= select_result(wp, (LPARAM)FD_READ);
+                   if (things.lNetworkEvents & FD_CLOSE)
+                       connopen &= select_result(wp, (LPARAM)FD_CLOSE);
+                   if (things.lNetworkEvents & FD_OOB)
+                       connopen &= select_result(wp, (LPARAM)FD_OOB);
+                   if (things.lNetworkEvents & FD_WRITE)
+                        connopen &= select_result(wp, (LPARAM)FD_WRITE);
+               }
+           }
         } else if (n == 1) {
+            noise_ultralight(idata.len);
             if (idata.len > 0) {
                 back->send(idata.buffer, idata.len);
             } else {
                 back->special(TS_EOF);
             }
+            SetEvent(idata.eventback);
         }
+        if (!connopen || back->socket() == NULL)
+            break;                 /* we closed the connection */
     }
     WSACleanup();
     return 0;