Replace PuTTY's 2-3-4 tree implementation with the shiny new counted
[u/mdw/putty] / plink.c
diff --git a/plink.c b/plink.c
index bee734f..6eb9a78 100644 (file)
--- a/plink.c
+++ b/plink.c
@@ -7,16 +7,17 @@
 #endif
 #include <windows.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <stdarg.h>
 
 #define PUTTY_DO_GLOBALS                      /* actually _define_ globals */
 #include "putty.h"
-#include "winstuff.h"
 #include "storage.h"
+#include "tree234.h"
 
 void fatalbox (char *p, ...) {
     va_list ap;
-    fprintf(stderr, "FATAL ERROR: ", p);
+    fprintf(stderr, "FATAL ERROR: ");
     va_start(ap, p);
     vfprintf(stderr, p, ap);
     va_end(ap);
@@ -26,7 +27,7 @@ void fatalbox (char *p, ...) {
 }
 void connection_fatal (char *p, ...) {
     va_list ap;
-    fprintf(stderr, "FATAL ERROR: ", p);
+    fprintf(stderr, "FATAL ERROR: ");
     va_start(ap, p);
     vfprintf(stderr, p, ap);
     va_end(ap);
@@ -118,41 +119,53 @@ void verify_ssh_host_key(char *host, int port, char *keytype,
     }
 }
 
-HANDLE outhandle;
+HANDLE inhandle, outhandle, errhandle;
 DWORD orig_console_mode;
 
-void begin_session(void) {
-    if (!cfg.ldisc_term)
-        SetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), ENABLE_PROCESSED_INPUT);
-    else
-        SetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), orig_console_mode);
-}
+WSAEVENT netevent;
 
-void term_out(void)
-{
-    int reap;
+void from_backend(int is_stderr, char *data, int len) {
+    int pos;
     DWORD ret;
-    reap = 0;
-    while (reap < inbuf_head) {
-        if (!WriteFile(outhandle, inbuf+reap, inbuf_head-reap, &ret, NULL))
+    HANDLE h = (is_stderr ? errhandle : outhandle);
+
+    pos = 0;
+    while (pos < len) {
+        if (!WriteFile(h, data+pos, len-pos, &ret, NULL))
             return;                    /* give up in panic */
-        reap += ret;
+        pos += ret;
     }
-    inbuf_head = 0;
+}
+
+int term_ldisc(int mode) { return FALSE; }
+void ldisc_update(int echo, int edit) {
+    /* Update stdin read mode to reflect changes in line discipline. */
+    DWORD mode;
+
+    mode = ENABLE_PROCESSED_INPUT;
+    if (echo)
+        mode = mode | ENABLE_ECHO_INPUT;
+    else
+        mode = mode &~ ENABLE_ECHO_INPUT;
+    if (edit)
+        mode = mode | ENABLE_LINE_INPUT;
+    else
+        mode = mode &~ ENABLE_LINE_INPUT;
+    SetConsoleMode(inhandle, mode);
 }
 
 struct input_data {
     DWORD len;
     char buffer[4096];
-    HANDLE event;
+    HANDLE event, eventback;
 };
 
-static int get_password(const char *prompt, char *str, int maxlen)
+static int get_line(const char *prompt, char *str, int maxlen, int is_pw)
 {
     HANDLE hin, hout;
-    DWORD savemode, i;
+    DWORD savemode, newmode, i;
 
-    if (password) {
+    if (is_pw && password) {
         static int tried_once = 0;
 
         if (tried_once) {
@@ -173,8 +186,12 @@ static int get_password(const char *prompt, char *str, int maxlen)
     }
 
     GetConsoleMode(hin, &savemode);
-    SetConsoleMode(hin, (savemode & (~ENABLE_ECHO_INPUT)) |
-                   ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT);
+    newmode = savemode | ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT;
+    if (is_pw)
+        newmode &= ~ENABLE_ECHO_INPUT;
+    else
+        newmode |= ENABLE_ECHO_INPUT;
+    SetConsoleMode(hin, newmode);
 
     WriteFile(hout, prompt, strlen(prompt), &i, NULL);
     ReadFile(hin, str, maxlen-1, &i, NULL);
@@ -184,7 +201,8 @@ static int get_password(const char *prompt, char *str, int maxlen)
     if ((int)i > maxlen) i = maxlen-1; else i = i - 2;
     str[i] = '\0';
 
-    WriteFile(hout, "\r\n", 2, &i, NULL);
+    if (is_pw)
+        WriteFile(hout, "\r\n", 2, &i, NULL);
 
     return 1;
 }
@@ -196,8 +214,9 @@ static DWORD WINAPI stdin_read_thread(void *param) {
     inhandle = GetStdHandle(STD_INPUT_HANDLE);
 
     while (ReadFile(inhandle, idata->buffer, sizeof(idata->buffer),
-                    &idata->len, NULL)) {
+                    &idata->len, NULL) && idata->len > 0) {
         SetEvent(idata->event);
+        WaitForSingleObject(idata->eventback, INFINITE);
     }
 
     idata->len = 0;
@@ -219,21 +238,42 @@ static void usage(void)
     printf("  -ssh      force use of ssh protocol\n");
     printf("  -P port   connect to specified port\n");
     printf("  -pw passw login with specified password\n");
+    printf("  -m file   read remote command(s) from file\n");
     exit(1);
 }
 
+char *do_select(SOCKET skt, int startup) {
+    int events;
+    if (startup) {
+       events = FD_READ | FD_WRITE | FD_OOB | FD_CLOSE;
+    } else {
+       events = 0;
+    }
+    if (WSAEventSelect (skt, netevent, events) == SOCKET_ERROR) {
+        switch (WSAGetLastError()) {
+          case WSAENETDOWN: return "Network is down";
+          default: return "WSAAsyncSelect(): unknown error";
+        }
+    }
+    return NULL;
+}
+
 int main(int argc, char **argv) {
     WSADATA wsadata;
     WORD winsock_ver;
-    WSAEVENT netevent, stdinevent;
+    WSAEVENT stdinevent;
     HANDLE handles[2];
-    SOCKET socket;
     DWORD threadid;
     struct input_data idata;
     int sending;
     int portnumber = -1;
+    SOCKET *sklist;
+    int skcount, sksize;
+    int connopen;
+
+    ssh_get_line = get_line;
 
-    ssh_get_password = get_password;
+    sklist = NULL; skcount = sksize = 0;
 
     flags = FLAG_STDERR;
     /*
@@ -280,6 +320,35 @@ int main(int argc, char **argv) {
                 --argc, username = *++argv;
                 strncpy(cfg.username, username, sizeof(cfg.username));
                 cfg.username[sizeof(cfg.username)-1] = '\0';
+            } else if (!strcmp(p, "-m") && argc > 1) {
+                char *filename, *command;
+                int cmdlen, cmdsize;
+                FILE *fp;
+                int c, d;
+
+                --argc, filename = *++argv;
+
+                cmdlen = cmdsize = 0;
+                command = NULL;
+                fp = fopen(filename, "r");
+                if (!fp) {
+                    fprintf(stderr, "plink: unable to open command "
+                            "file \"%s\"\n", filename);
+                    return 1;
+                }
+                do {
+                    c = fgetc(fp);
+                    d = c;
+                    if (c == EOF)
+                        d = 0;
+                    if (cmdlen >= cmdsize) {
+                        cmdsize = cmdlen + 512;
+                        command = srealloc(command, cmdsize);
+                    }
+                    command[cmdlen++] = d;
+                } while (c != EOF);
+                cfg.remote_cmd_ptr = command;
+                cfg.nopty = TRUE;      /* command => no terminal */
             } else if (!strcmp(p, "-P") && argc > 1) {
                 --argc, portnumber = atoi(*++argv);
             }
@@ -347,12 +416,16 @@ int main(int argc, char **argv) {
                         /*
                          * One string.
                          */
-                        do_defaults (p, &cfg);
-                        if (cfg.host[0] == '\0') {
+                        Config cfg2;
+                        do_defaults (p, &cfg2);
+                        if (cfg2.host[0] == '\0') {
                             /* No settings for this host; use defaults */
                             strncpy(cfg.host, p, sizeof(cfg.host)-1);
                             cfg.host[sizeof(cfg.host)-1] = '\0';
-                            cfg.port = 22;
+                            cfg.port = default_port;
+                        } else {
+                            cfg = cfg2;
+                            cfg.remote_cmd_ptr = cfg.remote_cmd;
                         }
                     } else {
                         *r++ = '\0';
@@ -360,7 +433,7 @@ int main(int argc, char **argv) {
                         cfg.username[sizeof(cfg.username)-1] = '\0';
                         strncpy(cfg.host, r, sizeof(cfg.host)-1);
                         cfg.host[sizeof(cfg.host)-1] = '\0';
-                        cfg.port = 22;
+                        cfg.port = default_port;
                     }
                 }
             } else {
@@ -377,7 +450,6 @@ int main(int argc, char **argv) {
                     len2 = strlen(cp); len -= len2; cp += len2;
                 }
                 cfg.nopty = TRUE;      /* command => no terminal */
-                cfg.ldisc_term = TRUE; /* use stdin like a line buffer */
                 break;                 /* done with cmdline */
             }
        }
@@ -387,7 +459,7 @@ int main(int argc, char **argv) {
         usage();
     }
 
-    if (!*cfg.remote_cmd)
+    if (!*cfg.remote_cmd_ptr)
         flags |= FLAG_INTERACTIVE;
 
     /*
@@ -429,38 +501,37 @@ int main(int argc, char **argv) {
        WSACleanup();
        return 1;
     }
+    sk_init();
 
     /*
      * Start up the connection.
      */
+    netevent = CreateEvent(NULL, FALSE, FALSE, NULL);
     {
        char *error;
        char *realhost;
 
-       error = back->init (NULL, cfg.host, cfg.port, &realhost);
+       error = back->init (cfg.host, cfg.port, &realhost);
        if (error) {
            fprintf(stderr, "Unable to open connection:\n%s", error);
            return 1;
        }
     }
+    connopen = 1;
 
-    netevent = CreateEvent(NULL, FALSE, FALSE, NULL);
     stdinevent = CreateEvent(NULL, FALSE, FALSE, NULL);
 
-    GetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), &orig_console_mode);
-    SetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), ENABLE_PROCESSED_INPUT);
+    inhandle = GetStdHandle(STD_INPUT_HANDLE);
     outhandle = GetStdHandle(STD_OUTPUT_HANDLE);
+    errhandle = GetStdHandle(STD_ERROR_HANDLE);
+    GetConsoleMode(inhandle, &orig_console_mode);
+    SetConsoleMode(inhandle, ENABLE_PROCESSED_INPUT);
 
     /*
-     * Now we must send the back end oodles of stuff.
-     */
-    socket = back->socket();
-    /*
      * Turn off ECHO and LINE input modes. We don't care if this
      * call fails, because we know we aren't necessarily running in
      * a console.
      */
-    WSAEventSelect(socket, netevent, FD_READ | FD_CLOSE);
     handles[0] = netevent;
     handles[1] = stdinevent;
     sending = FALSE;
@@ -485,6 +556,7 @@ int main(int argc, char **argv) {
              *    - so we're back to ReadFile blocking.
              */
             idata.event = stdinevent;
+            idata.eventback = CreateEvent(NULL, FALSE, FALSE, NULL);
             if (!CreateThread(NULL, 0, stdin_read_thread,
                               &idata, 0, &threadid)) {
                 fprintf(stderr, "Unable to create second thread\n");
@@ -496,23 +568,64 @@ int main(int argc, char **argv) {
         n = WaitForMultipleObjects(2, handles, FALSE, INFINITE);
         if (n == 0) {
             WSANETWORKEVENTS things;
-            if (!WSAEnumNetworkEvents(socket, netevent, &things)) {
-                if (things.lNetworkEvents & FD_READ)
-                    back->msg(0, FD_READ);
-                if (things.lNetworkEvents & FD_CLOSE) {
-                    back->msg(0, FD_CLOSE);
-                    break;
-                }
+           SOCKET socket;
+           extern SOCKET first_socket(int *), next_socket(int *);
+           extern int select_result(WPARAM, LPARAM);
+            int i, socketstate;
+
+            /*
+             * We must not call select_result() for any socket
+             * until we have finished enumerating within the tree.
+             * This is because select_result() may close the socket
+             * and modify the tree.
+             */
+            /* Count the active sockets. */
+            i = 0;
+            for (socket = first_socket(&socketstate); socket != INVALID_SOCKET;
+                socket = next_socket(&socketstate))
+                i++;
+
+            /* Expand the buffer if necessary. */
+            if (i > sksize) {
+                sksize = i+16;
+                sklist = srealloc(sklist, sksize * sizeof(*sklist));
+            }
+
+            /* Retrieve the sockets into sklist. */
+            skcount = 0;
+           for (socket = first_socket(&socketstate); socket != INVALID_SOCKET;
+                socket = next_socket(&socketstate)) {
+                sklist[skcount++] = socket;
             }
-            term_out();
+
+            /* Now we're done enumerating; go through the list. */
+            for (i = 0; i < skcount; i++) {
+                WPARAM wp;
+                socket = sklist[i];
+                wp = (WPARAM)socket;
+               if (!WSAEnumNetworkEvents(socket, NULL, &things)) {
+                    noise_ultralight(socket);
+                    noise_ultralight(things.lNetworkEvents);
+                   if (things.lNetworkEvents & FD_READ)
+                       connopen &= select_result(wp, (LPARAM)FD_READ);
+                   if (things.lNetworkEvents & FD_CLOSE)
+                       connopen &= select_result(wp, (LPARAM)FD_CLOSE);
+                   if (things.lNetworkEvents & FD_OOB)
+                       connopen &= select_result(wp, (LPARAM)FD_OOB);
+                   if (things.lNetworkEvents & FD_WRITE)
+                        connopen &= select_result(wp, (LPARAM)FD_WRITE);
+               }
+           }
         } else if (n == 1) {
+            noise_ultralight(idata.len);
             if (idata.len > 0) {
                 back->send(idata.buffer, idata.len);
             } else {
                 back->special(TS_EOF);
             }
+            SetEvent(idata.eventback);
         }
-        if (back->socket() == INVALID_SOCKET)
+        if (!connopen || back->socket() == NULL)
             break;                 /* we closed the connection */
     }
     WSACleanup();