Add a configuration option for TCP keepalives (SO_KEEPALIVE), default off.
[u/mdw/putty] / scp.c
diff --git a/scp.c b/scp.c
index 00852bc..f7fa255 100644 (file)
--- a/scp.c
+++ b/scp.c
@@ -36,6 +36,10 @@ static int prev_stats_len = 0;
 static int scp_unsafe_mode = 0;
 static int errs = 0;
 static int gui_mode = 0;
+static int try_scp = 1;
+static int try_sftp = 1;
+static int main_cmd_is_sftp = 0;
+static int fallback_cmd_is_sftp = 0;
 static int using_sftp = 0;
 
 static Backend *back;
@@ -178,14 +182,13 @@ int from_backend(void *frontend, int is_stderr, const char *data, int datalen)
     unsigned char *p = (unsigned char *) data;
     unsigned len = (unsigned) datalen;
 
-    assert(len > 0);
-
     /*
      * stderr data is just spouted to local stderr and otherwise
      * ignored.
      */
     if (is_stderr) {
-       fwrite(data, 1, len, stderr);
+       if (len > 0)
+           fwrite(data, 1, len, stderr);
        return 0;
     }
 
@@ -195,7 +198,7 @@ int from_backend(void *frontend, int is_stderr, const char *data, int datalen)
     if (!outptr)
        return 0;
 
-    if (outlen > 0) {
+    if ((outlen > 0) && (len > 0)) {
        unsigned used = outlen;
        if (used > len)
            used = len;
@@ -263,7 +266,13 @@ static void ssh_scp_init(void)
        if (ssh_sftp_loop_iteration() < 0)
            return;                    /* doom */
     }
-    using_sftp = !ssh_fallback_cmd(backhandle);
+
+    /* Work out which backend we ended up using. */
+    if (!ssh_fallback_cmd(backhandle))
+       using_sftp = main_cmd_is_sftp;
+    else
+       using_sftp = fallback_cmd_is_sftp;
+
     if (verbose) {
        if (using_sftp)
            tell_user(stderr, "Using SFTP");
@@ -291,7 +300,7 @@ static void bump(char *fmt, ...)
     if (back != NULL && back->socket(backhandle) != NULL) {
        char ch;
        back->special(backhandle, TS_EOF);
-       ssh_scp_recv(&ch, 1);
+       ssh_scp_recv((unsigned char *) &ch, 1);
     }
 
     if (gui_mode)
@@ -403,18 +412,46 @@ static void do_cmd(char *host, char *user, char *cmd)
     cfg.portfwd[0] = cfg.portfwd[1] = '\0';
 
     /*
+     * Set up main and possibly fallback command depending on
+     * options specified by user.
      * Attempt to start the SFTP subsystem as a first choice,
      * falling back to the provided scp command if that fails.
      */
-    strcpy(cfg.remote_cmd, "sftp");
-    cfg.ssh_subsys = TRUE;
-    cfg.remote_cmd_ptr2 = cmd;
-    cfg.ssh_subsys2 = FALSE;
+    cfg.remote_cmd_ptr2 = NULL;
+    if (try_sftp) {
+       /* First choice is SFTP subsystem. */
+       main_cmd_is_sftp = 1;
+       strcpy(cfg.remote_cmd, "sftp");
+       cfg.ssh_subsys = TRUE;
+       if (try_scp) {
+           /* Fallback is to use the provided scp command. */
+           fallback_cmd_is_sftp = 0;
+           cfg.remote_cmd_ptr2 = cmd;
+           cfg.ssh_subsys2 = FALSE;
+       } else {
+           /* Since we're not going to try SCP, we may as well try
+            * harder to find an SFTP server, since in the current
+            * implementation we have a spare slot. */
+           fallback_cmd_is_sftp = 1;
+           /* see psftp.c for full explanation of this kludge */
+           cfg.remote_cmd_ptr2 = 
+               "test -x /usr/lib/sftp-server && exec /usr/lib/sftp-server\n"
+               "test -x /usr/local/lib/sftp-server && exec /usr/local/lib/sftp-server\n"
+               "exec sftp-server";
+           cfg.ssh_subsys2 = FALSE;
+       }
+    } else {
+       /* Don't try SFTP at all; just try the scp command. */
+       main_cmd_is_sftp = 0;
+       cfg.remote_cmd_ptr = cmd;
+       cfg.ssh_subsys = FALSE;
+    }
     cfg.nopty = TRUE;
 
     back = &ssh_backend;
 
-    err = back->init(NULL, &backhandle, &cfg, cfg.host, cfg.port, &realhost,0);
+    err = back->init(NULL, &backhandle, &cfg, cfg.host, cfg.port, &realhost, 
+                    0, cfg.tcp_keepalives);
     if (err != NULL)
        bump("ssh_init: %s", err);
     logctx = log_init(NULL, &cfg);
@@ -434,7 +471,7 @@ static void print_stats(char *name, unsigned long size, unsigned long done,
 {
     float ratebs;
     unsigned long eta;
-    char etastr[10];
+    char *etastr;
     int pct;
     int len;
     int elap;
@@ -450,13 +487,13 @@ static void print_stats(char *name, unsigned long size, unsigned long done,
        eta = size - done;
     else
        eta = (unsigned long) ((size - done) / ratebs);
-    sprintf(etastr, "%02ld:%02ld:%02ld",
-           eta / 3600, (eta % 3600) / 60, eta % 60);
+    etastr = dupprintf("%02ld:%02ld:%02ld",
+                      eta / 3600, (eta % 3600) / 60, eta % 60);
 
     pct = (int) (100 * (done * 1.0 / size));
 
     if (gui_mode) {
-       gui_update_stats(name, size, pct, elap, done, eta, 
+       gui_update_stats(name, size, pct, elap, done, eta,
                         (unsigned long) ratebs);
     } else {
        len = printf("\r%-25.25s | %10ld kB | %5.1f kB/s | ETA: %8s | %3d%%",
@@ -467,7 +504,11 @@ static void print_stats(char *name, unsigned long size, unsigned long done,
 
        if (done == size)
            printf("\n");
+
+       fflush(stdout);
     }
+
+    free(etastr);
 }
 
 /*
@@ -530,7 +571,7 @@ static int response(void)
     char ch, resp, rbuf[2048];
     int p;
 
-    if (ssh_scp_recv(&resp, 1) <= 0)
+    if (ssh_scp_recv((unsigned char *) &resp, 1) <= 0)
        bump("Lost connection");
 
     p = 0;
@@ -543,7 +584,7 @@ static int response(void)
       case 1:                         /* error */
       case 2:                         /* fatal error */
        do {
-           if (ssh_scp_recv(&ch, 1) <= 0)
+           if (ssh_scp_recv((unsigned char *) &ch, 1) <= 0)
                bump("Protocol error: Lost connection");
            rbuf[p++] = ch;
        } while (p < sizeof(rbuf) && ch != '\n');
@@ -559,11 +600,11 @@ static int response(void)
 
 int sftp_recvdata(char *buf, int len)
 {
-    return ssh_scp_recv(buf, len);
+    return ssh_scp_recv((unsigned char *) buf, len);
 }
 int sftp_senddata(char *buf, int len)
 {
-    back->send(backhandle, (unsigned char *) buf, len);
+    back->send(backhandle, buf, len);
     return 1;
 }
 
@@ -673,6 +714,7 @@ static int scp_sftp_preserve, scp_sftp_recursive;
 static unsigned long scp_sftp_mtime, scp_sftp_atime;
 static int scp_has_times;
 static struct fxp_handle *scp_sftp_filehandle;
+static struct fxp_xfer *scp_sftp_xfer;
 static uint64 scp_sftp_fileoffset;
 
 void scp_source_setup(char *target, int shouldbedir)
@@ -767,6 +809,8 @@ int scp_send_filename(char *name, unsigned long size, int modes)
            return 1;
        }
        scp_sftp_fileoffset = uint64_make(0, 0);
+       scp_sftp_xfer = xfer_upload_init(scp_sftp_filehandle,
+                                        scp_sftp_fileoffset);
        sfree(fullname);
        return 0;
     } else {
@@ -784,23 +828,23 @@ int scp_send_filedata(char *data, int len)
     if (using_sftp) {
        int ret;
        struct sftp_packet *pktin;
-       struct sftp_request *req, *rreq;
 
        if (!scp_sftp_filehandle) {
            return 1;
        }
 
-       sftp_register(req = fxp_write_send(scp_sftp_filehandle,
-                                          data, scp_sftp_fileoffset, len));
-       rreq = sftp_find_request(pktin = sftp_recv());
-       assert(rreq == req);
-       ret = fxp_write_recv(pktin, rreq);
-
-       if (!ret) {
-           tell_user(stderr, "error while writing: %s\n", fxp_error());
-           errs++;
-           return 1;
+       while (!xfer_upload_ready(scp_sftp_xfer)) {
+           pktin = sftp_recv();
+           ret = xfer_upload_gotpkt(scp_sftp_xfer, pktin);
+           if (!ret) {
+               tell_user(stderr, "error while writing: %s\n", fxp_error());
+               errs++;
+               return 1;
+           }
        }
+
+       xfer_upload_data(scp_sftp_xfer, data, len);
+
        scp_sftp_fileoffset = uint64_add32(scp_sftp_fileoffset, len);
        return 0;
     } else {
@@ -830,6 +874,12 @@ int scp_send_finish(void)
        struct sftp_request *req, *rreq;
        int ret;
 
+       while (!xfer_done(scp_sftp_xfer)) {
+           pktin = sftp_recv();
+           xfer_upload_gotpkt(scp_sftp_xfer, pktin);
+       }
+       xfer_cleanup(scp_sftp_xfer);
+
        if (!scp_sftp_filehandle) {
            return 1;
        }
@@ -1310,14 +1360,14 @@ int scp_get_sink_action(struct scp_sink_action *act)
        bufsize = 0;
 
        while (!done) {
-           if (ssh_scp_recv(&ch, 1) <= 0)
+           if (ssh_scp_recv((unsigned char *) &ch, 1) <= 0)
                return 1;
            if (ch == '\n')
                bump("Protocol error: Unexpected newline");
            i = 0;
            action = ch;
            do {
-               if (ssh_scp_recv(&ch, 1) <= 0)
+               if (ssh_scp_recv((unsigned char *) &ch, 1) <= 0)
                    bump("Lost connection");
                if (i >= bufsize) {
                    bufsize = i + 128;
@@ -1388,6 +1438,8 @@ int scp_accept_filexfer(void)
            return 1;
        }
        scp_sftp_fileoffset = uint64_make(0, 0);
+       scp_sftp_xfer = xfer_download_init(scp_sftp_filehandle,
+                                          scp_sftp_fileoffset);
        sfree(scp_sftp_currentname);
        return 0;
     } else {
@@ -1400,28 +1452,37 @@ int scp_recv_filedata(char *data, int len)
 {
     if (using_sftp) {
        struct sftp_packet *pktin;
-       struct sftp_request *req, *rreq;
-       int actuallen;
+       int ret, actuallen;
+       void *vbuf;
 
-       sftp_register(req = fxp_read_send(scp_sftp_filehandle,
-                                         scp_sftp_fileoffset, len));
-       rreq = sftp_find_request(pktin = sftp_recv());
-       assert(rreq == req);
-       actuallen = fxp_read_recv(pktin, rreq, data, len);
+       xfer_download_queue(scp_sftp_xfer);
+       pktin = sftp_recv();
+       ret = xfer_download_gotpkt(scp_sftp_xfer, pktin);
 
-       if (actuallen == -1 && fxp_error_type() != SSH_FX_EOF) {
+       if (ret < 0) {
            tell_user(stderr, "pscp: error while reading: %s", fxp_error());
            errs++;
            return -1;
        }
-       if (actuallen < 0)
+
+       if (xfer_download_data(scp_sftp_xfer, &vbuf, &actuallen)) {
+           /*
+            * This assertion relies on the fact that the natural
+            * block size used in the xfer manager is at most that
+            * used in this module. I don't like crossing layers in
+            * this way, but it'll do for now.
+            */
+           assert(actuallen <= len);
+           memcpy(data, vbuf, actuallen);
+           sfree(vbuf);
+       } else
            actuallen = 0;
 
        scp_sftp_fileoffset = uint64_add32(scp_sftp_fileoffset, actuallen);
 
        return actuallen;
     } else {
-       return ssh_scp_recv(data, len);
+       return ssh_scp_recv((unsigned char *) data, len);
     }
 }
 
@@ -1431,6 +1492,23 @@ int scp_finish_filerecv(void)
        struct sftp_packet *pktin;
        struct sftp_request *req, *rreq;
 
+       /*
+        * Ensure that xfer_done() will work correctly, so we can
+        * clean up any outstanding requests from the file
+        * transfer.
+        */
+       xfer_set_error(scp_sftp_xfer);
+       while (!xfer_done(scp_sftp_xfer)) {
+           void *vbuf;
+           int len;
+
+           pktin = sftp_recv();
+           xfer_download_gotpkt(scp_sftp_xfer, pktin);
+           if (xfer_download_data(scp_sftp_xfer, &vbuf, &len))
+               sfree(vbuf);
+       }
+       xfer_cleanup(scp_sftp_xfer);
+
        sftp_register(req = fxp_close_send(scp_sftp_filehandle));
        rreq = sftp_find_request(pktin = sftp_recv());
        assert(rreq == req);
@@ -1758,11 +1836,12 @@ static void sink(char *targ, char *src)
        received = 0;
        while (received < act.size) {
            char transbuf[4096];
-           int blksize, read;
+           unsigned long blksize;
+           int read;
            blksize = 4096;
-           if (blksize > (int)(act.size - received))
+           if (blksize > (act.size - received))
                blksize = act.size - received;
-           read = scp_recv_filedata(transbuf, blksize);
+           read = scp_recv_filedata(transbuf, (int)blksize);
            if (read <= 0)
                bump("Lost connection");
            if (wrerror)
@@ -1996,7 +2075,7 @@ static void get_dir_list(int argc, char *argv[])
     if (using_sftp) {
        scp_sftp_listdir(src);
     } else {
-       while (ssh_scp_recv(&c, 1) > 0)
+       while (ssh_scp_recv((unsigned char *) &c, 1) > 0)
            tell_char(stdout, c);
     }
 }
@@ -2011,7 +2090,7 @@ static void usage(void)
     printf("Usage: pscp [options] [user@]host:source target\n");
     printf
        ("       pscp [options] source [source...] [user@]host:target\n");
-    printf("       pscp [options] -ls user@host:filespec\n");
+    printf("       pscp [options] -ls [user@]host:filespec\n");
     printf("Options:\n");
     printf("  -p        preserve file attributes\n");
     printf("  -q        quiet, don't show statistics\n");
@@ -2026,6 +2105,9 @@ static void usage(void)
     printf("  -i key    private key file for authentication\n");
     printf("  -batch    disable all interactive prompts\n");
     printf("  -unsafe   allow server-side wildcards (DANGEROUS)\n");
+    printf("  -V        print version information\n");
+    printf("  -sftp     force use of SFTP protocol\n");
+    printf("  -scp      force use of SCP protocol\n");
 #if 0
     /*
      * -gui is an internal option, used by GUI front ends to get
@@ -2040,6 +2122,12 @@ static void usage(void)
     cleanup_exit(1);
 }
 
+void version(void)
+{
+    printf("pscp: %s\n", ver);
+    cleanup_exit(1);
+}
+
 void cmdline_error(char *p, ...)
 {
     va_list ap;
@@ -2091,6 +2179,8 @@ int psftp_main(int argc, char *argv[])
            statistics = 0;
        } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "-?") == 0) {
            usage();
+       } else if (strcmp(argv[i], "-V") == 0) {
+            version();
        } else if (strcmp(argv[i], "-gui") == 0 && i + 1 < argc) {
            gui_enable(argv[++i]);
            gui_mode = 1;
@@ -2101,6 +2191,10 @@ int psftp_main(int argc, char *argv[])
            console_batch_mode = 1;
        } else if (strcmp(argv[i], "-unsafe") == 0) {
            scp_unsafe_mode = 1;
+       } else if (strcmp(argv[i], "-sftp") == 0) {
+           try_scp = 0; try_sftp = 1;
+       } else if (strcmp(argv[i], "-scp") == 0) {
+           try_scp = 1; try_sftp = 0;
        } else if (strcmp(argv[i], "--") == 0) {
            i++;
            break;
@@ -2133,13 +2227,19 @@ int psftp_main(int argc, char *argv[])
     if (back != NULL && back->socket(backhandle) != NULL) {
        char ch;
        back->special(backhandle, TS_EOF);
-       ssh_scp_recv(&ch, 1);
+       ssh_scp_recv((unsigned char *) &ch, 1);
     }
     random_save_seed();
 
     if (gui_mode)
        gui_send_errcount(list, errs);
 
+    cmdline_cleanup();
+    console_provide_logctx(NULL);
+    back->free(backhandle);
+    backhandle = NULL;
+    back = NULL;
+    sk_cleanup();
     return (errs == 0 ? 0 : 1);
 }