Add a new back-end function to return the exit code of the remote
[u/mdw/putty] / plink.c
diff --git a/plink.c b/plink.c
index 2802feb..f8a9f7c 100644 (file)
--- a/plink.c
+++ b/plink.c
@@ -162,6 +162,63 @@ void askcipher(char *ciphername, int cs)
     }
 }
 
+/*
+ * Ask whether to wipe a session log file before writing to it.
+ * Returns 2 for wipe, 1 for append, 0 for cancel (don't log).
+ */
+int askappend(char *filename)
+{
+    HANDLE hin;
+    DWORD savemode, i;
+
+    static const char msgtemplate[] =
+       "The session log file \"%.*s\" already exists.\n"
+       "You can overwrite it with a new session log,\n"
+       "append your session log to the end of it,\n"
+       "or disable session logging for this session.\n"
+       "Enter \"y\" to wipe the file, \"n\" to append to it,\n"
+       "or just press Return to disable logging.\n"
+       "Wipe the log file? (y/n, Return cancels logging) ";
+
+    char line[32];
+
+    fprintf(stderr, msgtemplate, FILENAME_MAX, filename);
+    fflush(stderr);
+
+    hin = GetStdHandle(STD_INPUT_HANDLE);
+    GetConsoleMode(hin, &savemode);
+    SetConsoleMode(hin, (savemode | ENABLE_ECHO_INPUT |
+                        ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT));
+    ReadFile(hin, line, sizeof(line) - 1, &i, NULL);
+    SetConsoleMode(hin, savemode);
+
+    if (line[0] == 'y' || line[0] == 'Y')
+       return 2;
+    else if (line[0] == 'n' || line[0] == 'N')
+       return 1;
+    else
+       return 0;
+}
+
+/*
+ * Warn about the obsolescent key file format.
+ */
+void old_keyfile_warning(void)
+{
+    static const char message[] =
+       "You are loading an SSH 2 private key which has an\n"
+       "old version of the file format. This means your key\n"
+       "file is not fully tamperproof. Future versions of\n"
+       "PuTTY may stop supporting this private key format,\n"
+       "so we recommend you convert your key to the new\n"
+       "format.\n"
+       "\n"
+       "Once the key is loaded into PuTTYgen, you can perform\n"
+       "this conversion simply by saving it again.\n";
+
+    fputs(message, stderr);
+}
+
 HANDLE inhandle, outhandle, errhandle;
 DWORD orig_console_mode;
 
@@ -314,11 +371,7 @@ void try_output(int is_stderr)
 
 int from_backend(int is_stderr, char *data, int len)
 {
-    int pos;
-    DWORD ret;
     HANDLE h = (is_stderr ? errhandle : outhandle);
-    void *writedata;
-    int writelen;
     int osize, esize;
 
     if (is_stderr) {
@@ -343,12 +396,17 @@ static void usage(void)
     printf("PuTTY Link: command-line connection utility\n");
     printf("%s\n", ver);
     printf("Usage: plink [options] [user@]host [command]\n");
+    printf("       (\"host\" can also be a PuTTY saved session name)\n");
     printf("Options:\n");
     printf("  -v        show verbose messages\n");
     printf("  -ssh      force use of ssh protocol\n");
     printf("  -P port   connect to specified port\n");
     printf("  -pw passw login with specified password\n");
     printf("  -m file   read remote command(s) from file\n");
+    printf("  -L listen-port:host:port   Forward local port to "
+          "remote address\n");
+    printf("  -R listen-port:host:port   Forward remote port to"
+          " local address\n");
     exit(1);
 }
 
@@ -356,7 +414,8 @@ char *do_select(SOCKET skt, int startup)
 {
     int events;
     if (startup) {
-       events = FD_READ | FD_WRITE | FD_OOB | FD_CLOSE | FD_ACCEPT;
+       events = (FD_CONNECT | FD_READ | FD_WRITE |
+                 FD_OOB | FD_CLOSE | FD_ACCEPT);
     } else {
        events = 0;
     }
@@ -385,6 +444,8 @@ int main(int argc, char **argv)
     SOCKET *sklist;
     int skcount, sksize;
     int connopen;
+    int exitcode;
+    char extra_portfwd[sizeof(cfg.portfwd)];
 
     ssh_get_line = get_line;
 
@@ -430,6 +491,9 @@ int main(int argc, char **argv)
            } else if (!strcmp(p, "-telnet")) {
                default_protocol = cfg.protocol = PROT_TELNET;
                default_port = cfg.port = 23;
+           } else if (!strcmp(p, "-rlogin")) {
+               default_protocol = cfg.protocol = PROT_RLOGIN;
+               default_port = cfg.port = 513;
            } else if (!strcmp(p, "-raw")) {
                default_protocol = cfg.protocol = PROT_RAW;
            } else if (!strcmp(p, "-v")) {
@@ -443,6 +507,24 @@ int main(int argc, char **argv)
                --argc, username = *++argv;
                strncpy(cfg.username, username, sizeof(cfg.username));
                cfg.username[sizeof(cfg.username) - 1] = '\0';
+           } else if ((!strcmp(p, "-L") || !strcmp(p, "-R")) && argc > 1) {
+               char *fwd, *ptr, *q;
+               int i=0;
+               --argc, fwd = *++argv;
+               ptr = extra_portfwd;
+               /* if multiple forwards, find end of list */
+               if (ptr[0]=='R' || ptr[0]=='L') {
+                   for (i = 0; i < sizeof(extra_portfwd) - 2; i++)
+                       if (ptr[i]=='\000' && ptr[i+1]=='\000')
+                           break;
+                   ptr = ptr + i + 1;  /* point to next forward slot */
+               }
+               ptr[0] = p[1];  /* insert a 'L' or 'R' at the start */
+               strncpy(ptr+1, fwd, sizeof(extra_portfwd) - i);
+               q = strchr(ptr, ':');
+               if (q) *q = '\t';      /* replace first : with \t */
+               ptr[strlen(ptr)+1] = '\000';    /* append two '\000' */
+               extra_portfwd[sizeof(extra_portfwd) - 1] = '\0';
            } else if (!strcmp(p, "-m") && argc > 1) {
                char *filename, *command;
                int cmdlen, cmdsize;
@@ -471,6 +553,7 @@ int main(int argc, char **argv)
                    command[cmdlen++] = d;
                } while (c != EOF);
                cfg.remote_cmd_ptr = command;
+               cfg.remote_cmd_ptr2 = NULL;
                cfg.nopty = TRUE;      /* command => no terminal */
            } else if (!strcmp(p, "-P") && argc > 1) {
                --argc, portnumber = atoi(*++argv);
@@ -592,6 +675,32 @@ int main(int argc, char **argv)
        usage();
     }
 
+    /*
+     * Trim leading whitespace off the hostname if it's there.
+     */
+    {
+       int space = strspn(cfg.host, " \t");
+       memmove(cfg.host, cfg.host+space, 1+strlen(cfg.host)-space);
+    }
+
+    /* See if host is of the form user@host */
+    if (cfg.host[0] != '\0') {
+       char *atsign = strchr(cfg.host, '@');
+       /* Make sure we're not overflowing the user field */
+       if (atsign) {
+           if (atsign - cfg.host < sizeof cfg.username) {
+               strncpy(cfg.username, cfg.host, atsign - cfg.host);
+               cfg.username[atsign - cfg.host] = '\0';
+           }
+           memmove(cfg.host, atsign + 1, 1 + strlen(atsign + 1));
+       }
+    }
+
+    /*
+     * Trim a colon suffix off the hostname if it's there.
+     */
+    cfg.host[strcspn(cfg.host, ":")] = '\0';
+
     if (!*cfg.remote_cmd_ptr)
        flags |= FLAG_INTERACTIVE;
 
@@ -615,6 +724,30 @@ int main(int argc, char **argv)
     }
 
     /*
+     * Add extra port forwardings (accumulated on command line) to
+     * cfg.
+     */
+    {
+       int i;
+       char *p;
+       p = extra_portfwd;
+       i = 0;
+       while (cfg.portfwd[i])
+           i += strlen(cfg.portfwd+i) + 1;
+       while (*p) {
+           if (strlen(p)+2 > sizeof(cfg.portfwd)-i) {
+               fprintf(stderr, "Internal fault: not enough space for all"
+                       " port forwardings\n");
+               break;
+           }
+           strncpy(cfg.portfwd+i, p, sizeof(cfg.portfwd)-i-1);
+           i += strlen(cfg.portfwd+i) + 1;
+           cfg.portfwd[i] = '\0';
+           p += strlen(p)+1;
+       }
+    }
+
+    /*
      * Select port.
      */
     if (portnumber != -1)
@@ -644,8 +777,11 @@ int main(int argc, char **argv)
     {
        char *error;
        char *realhost;
+       /* nodelay is only useful if stdin is a character device (console) */
+       int nodelay = cfg.tcp_nodelay &&
+           (GetFileType(GetStdHandle(STD_INPUT_HANDLE)) == FILE_TYPE_CHAR);
 
-       error = back->init(cfg.host, cfg.port, &realhost);
+       error = back->init(cfg.host, cfg.port, &realhost, nodelay);
        if (error) {
            fprintf(stderr, "Unable to open connection:\n%s", error);
            return 1;
@@ -770,6 +906,8 @@ int main(int argc, char **argv)
                if (!WSAEnumNetworkEvents(socket, NULL, &things)) {
                    noise_ultralight(socket);
                    noise_ultralight(things.lNetworkEvents);
+                   if (things.lNetworkEvents & FD_CONNECT)
+                       connopen &= select_result(wp, (LPARAM) FD_CONNECT);
                    if (things.lNetworkEvents & FD_READ)
                        connopen &= select_result(wp, (LPARAM) FD_READ);
                    if (things.lNetworkEvents & FD_CLOSE)
@@ -786,10 +924,12 @@ int main(int argc, char **argv)
        } else if (n == 1) {
            reading = 0;
            noise_ultralight(idata.len);
-           if (idata.len > 0) {
-               back->send(idata.buffer, idata.len);
-           } else {
-               back->special(TS_EOF);
+           if (connopen && back->socket() != NULL) {
+               if (idata.len > 0) {
+                   back->send(idata.buffer, idata.len);
+               } else {
+                   back->special(TS_EOF);
+               }
            }
        } else if (n == 2) {
            odata.busy = 0;
@@ -800,8 +940,10 @@ int main(int argc, char **argv)
            bufchain_consume(&stdout_data, odata.lenwritten);
            if (bufchain_size(&stdout_data) > 0)
                try_output(0);
-           back->unthrottle(bufchain_size(&stdout_data) +
-                            bufchain_size(&stderr_data));
+           if (connopen && back->socket() != NULL) {
+               back->unthrottle(bufchain_size(&stdout_data) +
+                                bufchain_size(&stderr_data));
+           }
        } else if (n == 3) {
            edata.busy = 0;
            if (!edata.writeret) {
@@ -811,16 +953,25 @@ int main(int argc, char **argv)
            bufchain_consume(&stderr_data, edata.lenwritten);
            if (bufchain_size(&stderr_data) > 0)
                try_output(1);
-           back->unthrottle(bufchain_size(&stdout_data) +
-                            bufchain_size(&stderr_data));
+           if (connopen && back->socket() != NULL) {
+               back->unthrottle(bufchain_size(&stdout_data) +
+                                bufchain_size(&stderr_data));
+           }
        }
        if (!reading && back->sendbuffer() < MAX_STDIN_BACKLOG) {
            SetEvent(idata.eventback);
            reading = 1;
        }
-       if (!connopen || back->socket() == NULL)
+       if ((!connopen || back->socket() == NULL) &&
+           bufchain_size(&stdout_data) == 0 &&
+           bufchain_size(&stderr_data) == 0)
            break;                     /* we closed the connection */
     }
     WSACleanup();
-    return 0;
+    exitcode = back->exitcode();
+    if (exitcode < 0) {
+       fprintf(stderr, "Remote process exit code unavailable\n");
+       exitcode = 1;                  /* this is an error condition */
+    }
+    return exitcode;
 }