Add a new back-end function to return the exit code of the remote
[u/mdw/putty] / plink.c
diff --git a/plink.c b/plink.c
index 4113a81..f8a9f7c 100644 (file)
--- a/plink.c
+++ b/plink.c
@@ -15,6 +15,8 @@
 #include "storage.h"
 #include "tree234.h"
 
+#define MAX_STDIN_BACKLOG 4096
+
 void fatalbox(char *p, ...)
 {
     va_list ap;
@@ -120,25 +122,108 @@ void verify_ssh_host_key(char *host, int port, char *keytype,
     }
 }
 
-HANDLE inhandle, outhandle, errhandle;
-DWORD orig_console_mode;
+/*
+ * Ask whether the selected cipher is acceptable (since it was
+ * below the configured 'warn' threshold).
+ * cs: 0 = both ways, 1 = client->server, 2 = server->client
+ */
+void askcipher(char *ciphername, int cs)
+{
+    HANDLE hin;
+    DWORD savemode, i;
 
-WSAEVENT netevent;
+    static const char msg[] =
+       "The first %scipher supported by the server is\n"
+       "%s, which is below the configured warning threshold.\n"
+       "Continue with connection? (y/n) ";
+    static const char abandoned[] = "Connection abandoned.\n";
 
-void from_backend(int is_stderr, char *data, int len)
-{
-    int pos;
-    DWORD ret;
-    HANDLE h = (is_stderr ? errhandle : outhandle);
+    char line[32];
+
+    fprintf(stderr, msg,
+           (cs == 0) ? "" :
+           (cs == 1) ? "client-to-server " :
+                       "server-to-client ",
+           ciphername);
+    fflush(stderr);
+
+    hin = GetStdHandle(STD_INPUT_HANDLE);
+    GetConsoleMode(hin, &savemode);
+    SetConsoleMode(hin, (savemode | ENABLE_ECHO_INPUT |
+                        ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT));
+    ReadFile(hin, line, sizeof(line) - 1, &i, NULL);
+    SetConsoleMode(hin, savemode);
 
-    pos = 0;
-    while (pos < len) {
-       if (!WriteFile(h, data + pos, len - pos, &ret, NULL))
-           return;                    /* give up in panic */
-       pos += ret;
+    if (line[0] == 'y' || line[0] == 'Y') {
+       return;
+    } else {
+       fprintf(stderr, abandoned);
+       exit(0);
     }
 }
 
+/*
+ * Ask whether to wipe a session log file before writing to it.
+ * Returns 2 for wipe, 1 for append, 0 for cancel (don't log).
+ */
+int askappend(char *filename)
+{
+    HANDLE hin;
+    DWORD savemode, i;
+
+    static const char msgtemplate[] =
+       "The session log file \"%.*s\" already exists.\n"
+       "You can overwrite it with a new session log,\n"
+       "append your session log to the end of it,\n"
+       "or disable session logging for this session.\n"
+       "Enter \"y\" to wipe the file, \"n\" to append to it,\n"
+       "or just press Return to disable logging.\n"
+       "Wipe the log file? (y/n, Return cancels logging) ";
+
+    char line[32];
+
+    fprintf(stderr, msgtemplate, FILENAME_MAX, filename);
+    fflush(stderr);
+
+    hin = GetStdHandle(STD_INPUT_HANDLE);
+    GetConsoleMode(hin, &savemode);
+    SetConsoleMode(hin, (savemode | ENABLE_ECHO_INPUT |
+                        ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT));
+    ReadFile(hin, line, sizeof(line) - 1, &i, NULL);
+    SetConsoleMode(hin, savemode);
+
+    if (line[0] == 'y' || line[0] == 'Y')
+       return 2;
+    else if (line[0] == 'n' || line[0] == 'N')
+       return 1;
+    else
+       return 0;
+}
+
+/*
+ * Warn about the obsolescent key file format.
+ */
+void old_keyfile_warning(void)
+{
+    static const char message[] =
+       "You are loading an SSH 2 private key which has an\n"
+       "old version of the file format. This means your key\n"
+       "file is not fully tamperproof. Future versions of\n"
+       "PuTTY may stop supporting this private key format,\n"
+       "so we recommend you convert your key to the new\n"
+       "format.\n"
+       "\n"
+       "Once the key is loaded into PuTTYgen, you can perform\n"
+       "this conversion simply by saving it again.\n";
+
+    fputs(message, stderr);
+}
+
+HANDLE inhandle, outhandle, errhandle;
+DWORD orig_console_mode;
+
+WSAEVENT netevent;
+
 int term_ldisc(int mode)
 {
     return FALSE;
@@ -160,12 +245,6 @@ void ldisc_update(int echo, int edit)
     SetConsoleMode(inhandle, mode);
 }
 
-struct input_data {
-    DWORD len;
-    char buffer[4096];
-    HANDLE event, eventback;
-};
-
 static int get_line(const char *prompt, char *str, int maxlen, int is_pw)
 {
     HANDLE hin, hout;
@@ -216,6 +295,12 @@ static int get_line(const char *prompt, char *str, int maxlen, int is_pw)
     return 1;
 }
 
+struct input_data {
+    DWORD len;
+    char buffer[4096];
+    HANDLE event, eventback;
+};
+
 static DWORD WINAPI stdin_read_thread(void *param)
 {
     struct input_data *idata = (struct input_data *) param;
@@ -235,6 +320,74 @@ static DWORD WINAPI stdin_read_thread(void *param)
     return 0;
 }
 
+struct output_data {
+    DWORD len, lenwritten;
+    int writeret;
+    char *buffer;
+    int is_stderr, done;
+    HANDLE event, eventback;
+    int busy;
+};
+
+static DWORD WINAPI stdout_write_thread(void *param)
+{
+    struct output_data *odata = (struct output_data *) param;
+    HANDLE outhandle, errhandle;
+
+    outhandle = GetStdHandle(STD_OUTPUT_HANDLE);
+    errhandle = GetStdHandle(STD_ERROR_HANDLE);
+
+    while (1) {
+       WaitForSingleObject(odata->eventback, INFINITE);
+       if (odata->done)
+           break;
+       odata->writeret =
+           WriteFile(odata->is_stderr ? errhandle : outhandle,
+                     odata->buffer, odata->len, &odata->lenwritten, NULL);
+       SetEvent(odata->event);
+    }
+
+    return 0;
+}
+
+bufchain stdout_data, stderr_data;
+struct output_data odata, edata;
+
+void try_output(int is_stderr)
+{
+    struct output_data *data = (is_stderr ? &edata : &odata);
+    void *senddata;
+    int sendlen;
+
+    if (!data->busy) {
+       bufchain_prefix(is_stderr ? &stderr_data : &stdout_data,
+                       &senddata, &sendlen);
+       data->buffer = senddata;
+       data->len = sendlen;
+       SetEvent(data->eventback);
+       data->busy = 1;
+    }
+}
+
+int from_backend(int is_stderr, char *data, int len)
+{
+    HANDLE h = (is_stderr ? errhandle : outhandle);
+    int osize, esize;
+
+    if (is_stderr) {
+       bufchain_add(&stderr_data, data, len);
+       try_output(1);
+    } else {
+       bufchain_add(&stdout_data, data, len);
+       try_output(0);
+    }
+
+    osize = bufchain_size(&stdout_data);
+    esize = bufchain_size(&stderr_data);
+
+    return osize + esize;
+}
+
 /*
  *  Short description of parameters.
  */
@@ -243,12 +396,17 @@ static void usage(void)
     printf("PuTTY Link: command-line connection utility\n");
     printf("%s\n", ver);
     printf("Usage: plink [options] [user@]host [command]\n");
+    printf("       (\"host\" can also be a PuTTY saved session name)\n");
     printf("Options:\n");
     printf("  -v        show verbose messages\n");
     printf("  -ssh      force use of ssh protocol\n");
     printf("  -P port   connect to specified port\n");
     printf("  -pw passw login with specified password\n");
     printf("  -m file   read remote command(s) from file\n");
+    printf("  -L listen-port:host:port   Forward local port to "
+          "remote address\n");
+    printf("  -R listen-port:host:port   Forward remote port to"
+          " local address\n");
     exit(1);
 }
 
@@ -256,7 +414,8 @@ char *do_select(SOCKET skt, int startup)
 {
     int events;
     if (startup) {
-       events = FD_READ | FD_WRITE | FD_OOB | FD_CLOSE | FD_ACCEPT;
+       events = (FD_CONNECT | FD_READ | FD_WRITE |
+                 FD_OOB | FD_CLOSE | FD_ACCEPT);
     } else {
        events = 0;
     }
@@ -275,15 +434,18 @@ int main(int argc, char **argv)
 {
     WSADATA wsadata;
     WORD winsock_ver;
-    WSAEVENT stdinevent;
-    HANDLE handles[2];
-    DWORD threadid;
+    WSAEVENT stdinevent, stdoutevent, stderrevent;
+    HANDLE handles[4];
+    DWORD in_threadid, out_threadid, err_threadid;
     struct input_data idata;
+    int reading;
     int sending;
     int portnumber = -1;
     SOCKET *sklist;
     int skcount, sksize;
     int connopen;
+    int exitcode;
+    char extra_portfwd[sizeof(cfg.portfwd)];
 
     ssh_get_line = get_line;
 
@@ -329,6 +491,9 @@ int main(int argc, char **argv)
            } else if (!strcmp(p, "-telnet")) {
                default_protocol = cfg.protocol = PROT_TELNET;
                default_port = cfg.port = 23;
+           } else if (!strcmp(p, "-rlogin")) {
+               default_protocol = cfg.protocol = PROT_RLOGIN;
+               default_port = cfg.port = 513;
            } else if (!strcmp(p, "-raw")) {
                default_protocol = cfg.protocol = PROT_RAW;
            } else if (!strcmp(p, "-v")) {
@@ -342,6 +507,24 @@ int main(int argc, char **argv)
                --argc, username = *++argv;
                strncpy(cfg.username, username, sizeof(cfg.username));
                cfg.username[sizeof(cfg.username) - 1] = '\0';
+           } else if ((!strcmp(p, "-L") || !strcmp(p, "-R")) && argc > 1) {
+               char *fwd, *ptr, *q;
+               int i=0;
+               --argc, fwd = *++argv;
+               ptr = extra_portfwd;
+               /* if multiple forwards, find end of list */
+               if (ptr[0]=='R' || ptr[0]=='L') {
+                   for (i = 0; i < sizeof(extra_portfwd) - 2; i++)
+                       if (ptr[i]=='\000' && ptr[i+1]=='\000')
+                           break;
+                   ptr = ptr + i + 1;  /* point to next forward slot */
+               }
+               ptr[0] = p[1];  /* insert a 'L' or 'R' at the start */
+               strncpy(ptr+1, fwd, sizeof(extra_portfwd) - i);
+               q = strchr(ptr, ':');
+               if (q) *q = '\t';      /* replace first : with \t */
+               ptr[strlen(ptr)+1] = '\000';    /* append two '\000' */
+               extra_portfwd[sizeof(extra_portfwd) - 1] = '\0';
            } else if (!strcmp(p, "-m") && argc > 1) {
                char *filename, *command;
                int cmdlen, cmdsize;
@@ -370,6 +553,7 @@ int main(int argc, char **argv)
                    command[cmdlen++] = d;
                } while (c != EOF);
                cfg.remote_cmd_ptr = command;
+               cfg.remote_cmd_ptr2 = NULL;
                cfg.nopty = TRUE;      /* command => no terminal */
            } else if (!strcmp(p, "-P") && argc > 1) {
                --argc, portnumber = atoi(*++argv);
@@ -491,6 +675,32 @@ int main(int argc, char **argv)
        usage();
     }
 
+    /*
+     * Trim leading whitespace off the hostname if it's there.
+     */
+    {
+       int space = strspn(cfg.host, " \t");
+       memmove(cfg.host, cfg.host+space, 1+strlen(cfg.host)-space);
+    }
+
+    /* See if host is of the form user@host */
+    if (cfg.host[0] != '\0') {
+       char *atsign = strchr(cfg.host, '@');
+       /* Make sure we're not overflowing the user field */
+       if (atsign) {
+           if (atsign - cfg.host < sizeof cfg.username) {
+               strncpy(cfg.username, cfg.host, atsign - cfg.host);
+               cfg.username[atsign - cfg.host] = '\0';
+           }
+           memmove(cfg.host, atsign + 1, 1 + strlen(atsign + 1));
+       }
+    }
+
+    /*
+     * Trim a colon suffix off the hostname if it's there.
+     */
+    cfg.host[strcspn(cfg.host, ":")] = '\0';
+
     if (!*cfg.remote_cmd_ptr)
        flags |= FLAG_INTERACTIVE;
 
@@ -514,6 +724,30 @@ int main(int argc, char **argv)
     }
 
     /*
+     * Add extra port forwardings (accumulated on command line) to
+     * cfg.
+     */
+    {
+       int i;
+       char *p;
+       p = extra_portfwd;
+       i = 0;
+       while (cfg.portfwd[i])
+           i += strlen(cfg.portfwd+i) + 1;
+       while (*p) {
+           if (strlen(p)+2 > sizeof(cfg.portfwd)-i) {
+               fprintf(stderr, "Internal fault: not enough space for all"
+                       " port forwardings\n");
+               break;
+           }
+           strncpy(cfg.portfwd+i, p, sizeof(cfg.portfwd)-i-1);
+           i += strlen(cfg.portfwd+i) + 1;
+           cfg.portfwd[i] = '\0';
+           p += strlen(p)+1;
+       }
+    }
+
+    /*
      * Select port.
      */
     if (portnumber != -1)
@@ -543,8 +777,11 @@ int main(int argc, char **argv)
     {
        char *error;
        char *realhost;
+       /* nodelay is only useful if stdin is a character device (console) */
+       int nodelay = cfg.tcp_nodelay &&
+           (GetFileType(GetStdHandle(STD_INPUT_HANDLE)) == FILE_TYPE_CHAR);
 
-       error = back->init(cfg.host, cfg.port, &realhost);
+       error = back->init(cfg.host, cfg.port, &realhost, nodelay);
        if (error) {
            fprintf(stderr, "Unable to open connection:\n%s", error);
            return 1;
@@ -554,6 +791,8 @@ int main(int argc, char **argv)
     connopen = 1;
 
     stdinevent = CreateEvent(NULL, FALSE, FALSE, NULL);
+    stdoutevent = CreateEvent(NULL, FALSE, FALSE, NULL);
+    stderrevent = CreateEvent(NULL, FALSE, FALSE, NULL);
 
     inhandle = GetStdHandle(STD_INPUT_HANDLE);
     outhandle = GetStdHandle(STD_OUTPUT_HANDLE);
@@ -568,7 +807,33 @@ int main(int argc, char **argv)
      */
     handles[0] = netevent;
     handles[1] = stdinevent;
+    handles[2] = stdoutevent;
+    handles[3] = stderrevent;
     sending = FALSE;
+
+    /*
+     * Create spare threads to write to stdout and stderr, so we
+     * can arrange asynchronous writes.
+     */
+    odata.event = stdoutevent;
+    odata.eventback = CreateEvent(NULL, FALSE, FALSE, NULL);
+    odata.is_stderr = 0;
+    odata.busy = odata.done = 0;
+    if (!CreateThread(NULL, 0, stdout_write_thread,
+                     &odata, 0, &out_threadid)) {
+       fprintf(stderr, "Unable to create output thread\n");
+       exit(1);
+    }
+    edata.event = stderrevent;
+    edata.eventback = CreateEvent(NULL, FALSE, FALSE, NULL);
+    edata.is_stderr = 1;
+    edata.busy = edata.done = 0;
+    if (!CreateThread(NULL, 0, stdout_write_thread,
+                     &edata, 0, &err_threadid)) {
+       fprintf(stderr, "Unable to create error output thread\n");
+       exit(1);
+    }
+
     while (1) {
        int n;
 
@@ -592,14 +857,14 @@ int main(int argc, char **argv)
            idata.event = stdinevent;
            idata.eventback = CreateEvent(NULL, FALSE, FALSE, NULL);
            if (!CreateThread(NULL, 0, stdin_read_thread,
-                             &idata, 0, &threadid)) {
-               fprintf(stderr, "Unable to create second thread\n");
+                             &idata, 0, &in_threadid)) {
+               fprintf(stderr, "Unable to create input thread\n");
                exit(1);
            }
            sending = TRUE;
        }
 
-       n = WaitForMultipleObjects(2, handles, FALSE, INFINITE);
+       n = WaitForMultipleObjects(4, handles, FALSE, INFINITE);
        if (n == 0) {
            WSANETWORKEVENTS things;
            SOCKET socket;
@@ -641,6 +906,8 @@ int main(int argc, char **argv)
                if (!WSAEnumNetworkEvents(socket, NULL, &things)) {
                    noise_ultralight(socket);
                    noise_ultralight(things.lNetworkEvents);
+                   if (things.lNetworkEvents & FD_CONNECT)
+                       connopen &= select_result(wp, (LPARAM) FD_CONNECT);
                    if (things.lNetworkEvents & FD_READ)
                        connopen &= select_result(wp, (LPARAM) FD_READ);
                    if (things.lNetworkEvents & FD_CLOSE)
@@ -655,17 +922,56 @@ int main(int argc, char **argv)
                }
            }
        } else if (n == 1) {
+           reading = 0;
            noise_ultralight(idata.len);
-           if (idata.len > 0) {
-               back->send(idata.buffer, idata.len);
-           } else {
-               back->special(TS_EOF);
+           if (connopen && back->socket() != NULL) {
+               if (idata.len > 0) {
+                   back->send(idata.buffer, idata.len);
+               } else {
+                   back->special(TS_EOF);
+               }
+           }
+       } else if (n == 2) {
+           odata.busy = 0;
+           if (!odata.writeret) {
+               fprintf(stderr, "Unable to write to standard output\n");
+               exit(0);
+           }
+           bufchain_consume(&stdout_data, odata.lenwritten);
+           if (bufchain_size(&stdout_data) > 0)
+               try_output(0);
+           if (connopen && back->socket() != NULL) {
+               back->unthrottle(bufchain_size(&stdout_data) +
+                                bufchain_size(&stderr_data));
+           }
+       } else if (n == 3) {
+           edata.busy = 0;
+           if (!edata.writeret) {
+               fprintf(stderr, "Unable to write to standard output\n");
+               exit(0);
            }
+           bufchain_consume(&stderr_data, edata.lenwritten);
+           if (bufchain_size(&stderr_data) > 0)
+               try_output(1);
+           if (connopen && back->socket() != NULL) {
+               back->unthrottle(bufchain_size(&stdout_data) +
+                                bufchain_size(&stderr_data));
+           }
+       }
+       if (!reading && back->sendbuffer() < MAX_STDIN_BACKLOG) {
            SetEvent(idata.eventback);
+           reading = 1;
        }
-       if (!connopen || back->socket() == NULL)
+       if ((!connopen || back->socket() == NULL) &&
+           bufchain_size(&stdout_data) == 0 &&
+           bufchain_size(&stderr_data) == 0)
            break;                     /* we closed the connection */
     }
     WSACleanup();
-    return 0;
+    exitcode = back->exitcode();
+    if (exitcode < 0) {
+       fprintf(stderr, "Remote process exit code unavailable\n");
+       exitcode = 1;                  /* this is an error condition */
+    }
+    return exitcode;
 }