Introduce framework for authenticating with the local X server.
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 0ac24d2..c0f06ef 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -166,7 +166,7 @@ static const char *const ssh2_disconnect_reasons[] = {
 
 #define translate(x) if (type == x) return #x
 #define translatec(x,ctx) if (type == x && (pkt_ctx & ctx)) return #x
-char *ssh1_pkt_type(int type)
+static char *ssh1_pkt_type(int type)
 {
     translate(SSH1_MSG_DISCONNECT);
     translate(SSH1_SMSG_PUBLIC_KEY);
@@ -211,7 +211,7 @@ char *ssh1_pkt_type(int type)
     translate(SSH1_CMSG_AUTH_CCARD_RESPONSE);
     return "unknown";
 }
-char *ssh2_pkt_type(int pkt_ctx, int type)
+static char *ssh2_pkt_type(int pkt_ctx, int type)
 {
     translate(SSH2_MSG_DISCONNECT);
     translate(SSH2_MSG_IGNORE);
@@ -291,22 +291,6 @@ enum { PKT_END, PKT_INT, PKT_CHAR, PKT_DATA, PKT_STR, PKT_BIGNUM };
 
 typedef struct ssh_tag *Ssh;
 
-extern char *x11_init(Socket *, char *, void *, void *);
-extern void x11_close(Socket);
-extern int x11_send(Socket, char *, int);
-extern void *x11_invent_auth(char *, int, char *, int);
-extern void x11_unthrottle(Socket s);
-extern void x11_override_throttle(Socket s, int enable);
-
-extern char *pfd_newconnect(Socket * s, char *hostname, int port, void *c);
-extern char *pfd_addforward(char *desthost, int destport, char *srcaddr,
-                           int port, void *backhandle);
-extern void pfd_close(Socket s);
-extern int pfd_send(Socket s, char *data, int len);
-extern void pfd_confirm(Socket s);
-extern void pfd_unthrottle(Socket s);
-extern void pfd_override_throttle(Socket s, int enable);
-
 static void ssh2_pkt_init(Ssh, int pkt_type);
 static void ssh2_pkt_addbool(Ssh, unsigned char value);
 static void ssh2_pkt_adduint32(Ssh, unsigned long value);
@@ -314,7 +298,7 @@ static void ssh2_pkt_addstring_start(Ssh);
 static void ssh2_pkt_addstring_str(Ssh, char *data);
 static void ssh2_pkt_addstring_data(Ssh, char *data, int len);
 static void ssh2_pkt_addstring(Ssh, char *data);
-static char *ssh2_mpint_fmt(Bignum b, int *len);
+static unsigned char *ssh2_mpint_fmt(Bignum b, int *len);
 static void ssh2_pkt_addmp(Ssh, Bignum b);
 static int ssh2_pkt_construct(Ssh);
 static void ssh2_pkt_send(Ssh);
@@ -638,16 +622,10 @@ struct ssh_tag {
     int (*s_rdpkt) (Ssh ssh, unsigned char **data, int *datalen);
 };
 
-#define logevent(s) do { \
-    logevent(ssh->frontend, s); \
-    if ((flags & FLAG_STDERR) && (flags & FLAG_VERBOSE)) { \
-       fprintf(stderr, "%s\n", s); \
-       fflush(stderr); \
-    } \
-} while (0)
+#define logevent(s) logevent(ssh->frontend, s)
 
 /* logevent, only printf-formatted. */
-void logeventf(Ssh ssh, char *fmt, ...)
+static void logeventf(Ssh ssh, char *fmt, ...)
 {
     va_list ap;
     char *buf;
@@ -1221,7 +1199,7 @@ static void s_wrpkt(Ssh ssh)
 {
     int len, backlog;
     len = s_wrpkt_prepare(ssh);
-    backlog = sk_write(ssh->s, ssh->pktout.data, len);
+    backlog = sk_write(ssh->s, (char *)ssh->pktout.data, len);
     if (backlog > SSH_MAX_BACKLOG)
        ssh_throttle_all(ssh, 1, backlog);
 }
@@ -1267,7 +1245,7 @@ static void construct_packet(Ssh ssh, int pkttype, va_list ap1, va_list ap2)
            break;
          case PKT_STR:
            argp = va_arg(ap1, unsigned char *);
-           arglen = strlen(argp);
+           arglen = strlen((char *)argp);
            pktlen += 4 + arglen;
            break;
          case PKT_BIGNUM:
@@ -1302,7 +1280,7 @@ static void construct_packet(Ssh ssh, int pkttype, va_list ap1, va_list ap2)
            break;
          case PKT_STR:
            argp = va_arg(ap2, unsigned char *);
-           arglen = strlen(argp);
+           arglen = strlen((char *)argp);
            PUT_32BIT(p, arglen);
            memcpy(p + 4, argp, arglen);
            p += 4 + arglen;
@@ -1433,7 +1411,7 @@ static void ssh2_pkt_addstring(Ssh ssh, char *data)
     ssh2_pkt_addstring_start(ssh);
     ssh2_pkt_addstring_str(ssh, data);
 }
-static char *ssh2_mpint_fmt(Bignum b, int *len)
+static unsigned char *ssh2_mpint_fmt(Bignum b, int *len)
 {
     unsigned char *p;
     int i, n = (bignum_bitcount(b) + 7) / 8;
@@ -1456,7 +1434,7 @@ static void ssh2_pkt_addmp(Ssh ssh, Bignum b)
     int len;
     p = ssh2_mpint_fmt(b, &len);
     ssh2_pkt_addstring_start(ssh);
-    ssh2_pkt_addstring_data(ssh, p, len);
+    ssh2_pkt_addstring_data(ssh, (char *)p, len);
     sfree(p);
 }
 
@@ -1527,7 +1505,7 @@ static void ssh2_pkt_send(Ssh ssh)
     int len;
     int backlog;
     len = ssh2_pkt_construct(ssh);
-    backlog = sk_write(ssh->s, ssh->pktout.data, len);
+    backlog = sk_write(ssh->s, (char *)ssh->pktout.data, len);
     if (backlog > SSH_MAX_BACKLOG)
        ssh_throttle_all(ssh, 1, backlog);
 }
@@ -1562,7 +1540,8 @@ static void ssh2_pkt_defer(Ssh ssh)
 static void ssh_pkt_defersend(Ssh ssh)
 {
     int backlog;
-    backlog = sk_write(ssh->s, ssh->deferred_send_data, ssh->deferred_len);
+    backlog = sk_write(ssh->s, (char *)ssh->deferred_send_data,
+                      ssh->deferred_len);
     ssh->deferred_len = ssh->deferred_size = 0;
     sfree(ssh->deferred_send_data);
     ssh->deferred_send_data = NULL;
@@ -1628,7 +1607,7 @@ static void ssh2_pkt_getstring(Ssh ssh, char **p, int *length)
     ssh->pktin.savedpos += 4;
     if (ssh->pktin.length - ssh->pktin.savedpos < *length)
        return;
-    *p = ssh->pktin.data + ssh->pktin.savedpos;
+    *p = (char *)(ssh->pktin.data + ssh->pktin.savedpos);
     ssh->pktin.savedpos += *length;
 }
 static Bignum ssh2_pkt_getmp(Ssh ssh)
@@ -1644,7 +1623,7 @@ static Bignum ssh2_pkt_getmp(Ssh ssh)
        bombout((ssh,"internal error: Can't handle negative mpints"));
        return NULL;
     }
-    b = bignum_from_bytes(p, length);
+    b = bignum_from_bytes((unsigned char *)p, length);
     return b;
 }
 
@@ -1694,18 +1673,18 @@ static void ssh2_add_sigblob(Ssh ssh, void *pkblob_v, int pkblob_len,
        if (len != siglen) {
            unsigned char newlen[4];
            ssh2_pkt_addstring_start(ssh);
-           ssh2_pkt_addstring_data(ssh, sigblob, pos);
+           ssh2_pkt_addstring_data(ssh, (char *)sigblob, pos);
            /* dmemdump(sigblob, pos); */
            pos += 4;                  /* point to start of actual sig */
            PUT_32BIT(newlen, len);
-           ssh2_pkt_addstring_data(ssh, newlen, 4);
+           ssh2_pkt_addstring_data(ssh, (char *)newlen, 4);
            /* dmemdump(newlen, 4); */
            newlen[0] = 0;
            while (len-- > siglen) {
-               ssh2_pkt_addstring_data(ssh, newlen, 1);
+               ssh2_pkt_addstring_data(ssh, (char *)newlen, 1);
                /* dmemdump(newlen, 1); */
            }
-           ssh2_pkt_addstring_data(ssh, sigblob+pos, siglen);
+           ssh2_pkt_addstring_data(ssh, (char *)(sigblob+pos), siglen);
            /* dmemdump(sigblob+pos, siglen); */
            return;
        }
@@ -1714,7 +1693,7 @@ static void ssh2_add_sigblob(Ssh ssh, void *pkblob_v, int pkblob_len,
     }
 
     ssh2_pkt_addstring_start(ssh);
-    ssh2_pkt_addstring_data(ssh, sigblob, sigblob_len);
+    ssh2_pkt_addstring_data(ssh, (char *)sigblob, sigblob_len);
 }
 
 /*
@@ -1865,7 +1844,7 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
                s->i = -1;
            } else if (s->i < sizeof(s->version) - 1)
                s->version[s->i++] = c;
-       } else if (c == '\n')
+       } else if (c == '\012')
            break;
     }
 
@@ -1915,7 +1894,7 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
        sha_string(&ssh->exhashbase, s->vstring, strcspn(s->vstring, "\r\n"));
        sprintf(vlog, "We claim version: %s", verstring);
        logevent(vlog);
-       strcat(verstring, "\n");
+       strcat(verstring, "\012");
        logevent("Using SSH protocol version 2");
        sk_write(ssh->s, verstring, strlen(verstring));
        ssh->protocol = ssh2_protocol;
@@ -1931,7 +1910,7 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
                sshver);
        sprintf(vlog, "We claim version: %s", verstring);
        logevent(vlog);
-       strcat(verstring, "\n");
+       strcat(verstring, "\012");
 
        logevent("Using SSH protocol version 1");
        sk_write(ssh->s, verstring, strlen(verstring));
@@ -2014,7 +1993,7 @@ static int ssh_closing(Plug plug, char *error_msg, int error_code,
 static int ssh_receive(Plug plug, int urgent, char *data, int len)
 {
     Ssh ssh = (Ssh) plug;
-    ssh_gotdata(ssh, data, len);
+    ssh_gotdata(ssh, (unsigned char *)data, len);
     if (ssh->state == SSH_STATE_CLOSED) {
        if (ssh->s) {
            sk_close(ssh->s);
@@ -2069,7 +2048,7 @@ static char *connect_to_host(Ssh ssh, char *host, int port,
      */
     logeventf(ssh, "Looking up host \"%s\"", host);
     addr = name_lookup(host, port, realhost);
-    if ((err = sk_addr_error(addr)))
+    if ((err = sk_addr_error(addr)) != NULL)
        return err;
 
     /*
@@ -2082,7 +2061,7 @@ static char *connect_to_host(Ssh ssh, char *host, int port,
     }
     ssh->fn = &fn_table;
     ssh->s = new_connection(addr, *realhost, port, 0, 1, nodelay, (Plug) ssh);
-    if ((err = sk_socket_error(ssh->s))) {
+    if ((err = sk_socket_error(ssh->s)) != NULL) {
        ssh->s = NULL;
        return err;
     }
@@ -2146,7 +2125,7 @@ static void ssh_throttle_all(Ssh ssh, int enable, int bufsize)
  */
 
 /* Set up a username or password input loop on a given buffer. */
-void setup_userpass_input(Ssh ssh, char *buffer, int buflen, int echo)
+static void setup_userpass_input(Ssh ssh, char *buffer, int buflen, int echo)
 {
     ssh->userpass_input_buffer = buffer;
     ssh->userpass_input_buflen = buflen;
@@ -2160,7 +2139,7 @@ void setup_userpass_input(Ssh ssh, char *buffer, int buflen, int echo)
  * buffer), <0 for failure (user hit ^C/^D, bomb out and exit), 0
  * for inconclusive (keep waiting for more input please).
  */
-int process_userpass_input(Ssh ssh, unsigned char *in, int inlen)
+static int process_userpass_input(Ssh ssh, unsigned char *in, int inlen)
 {
     char c;
 
@@ -2521,7 +2500,7 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                    s->p += ssh1_read_bignum(s->p, &s->key.modulus);
                    s->commentlen = GET_32BIT(s->p);
                    s->p += 4;
-                   s->commentp = s->p;
+                   s->commentp = (char *)s->p;
                    s->p += s->commentlen;
                    send_packet(ssh, SSH1_CMSG_AUTH_RSA,
                                PKT_BIGNUM, s->key.modulus, PKT_END);
@@ -3040,10 +3019,12 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen, int ispkt)
        logevent("Requesting X11 forwarding");
        ssh->x11auth = x11_invent_auth(proto, sizeof(proto),
                                       data, sizeof(data));
+        x11_get_real_auth(ssh->x11auth, cfg.x11_display);
        if (ssh->v1_local_protoflags & SSH1_PROTOFLAG_SCREEN_NUMBER) {
            send_packet(ssh, SSH1_CMSG_X11_REQUEST_FORWARDING,
                        PKT_STR, proto, PKT_STR, data,
-                       PKT_INT, 0, PKT_END);
+                       PKT_INT, x11_get_screen_number(cfg.x11_display),
+                       PKT_END);
        } else {
            send_packet(ssh, SSH1_CMSG_X11_REQUEST_FORWARDING,
                        PKT_STR, proto, PKT_STR, data, PKT_END);
@@ -3271,7 +3252,7 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                int bufsize =
                    from_backend(ssh->frontend,
                                 ssh->pktin.type == SSH1_SMSG_STDERR_DATA,
-                                ssh->pktin.body + 4, len);
+                                (char *)(ssh->pktin.body) + 4, len);
                if (!ssh->v1_stdout_throttling && bufsize > SSH1_BUFFER_LIMIT) {
                    ssh->v1_stdout_throttling = 1;
                    ssh1_throttle(ssh, +1);
@@ -3352,7 +3333,8 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                c->ssh = ssh;
 
                hostsize = GET_32BIT(ssh->pktin.body+4);
-               for(h = host, p = ssh->pktin.body+8; hostsize != 0; hostsize--) {
+               for (h = host, p = (char *)(ssh->pktin.body+8);
+                    hostsize != 0; hostsize--) {
                    if (h+1 < host+sizeof(host))
                        *h++ = *p;
                    p++;
@@ -3484,10 +3466,10 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                    int bufsize;
                    switch (c->type) {
                      case CHAN_X11:
-                       bufsize = x11_send(c->u.x11.s, p, len);
+                       bufsize = x11_send(c->u.x11.s, (char *)p, len);
                        break;
                      case CHAN_SOCKDATA:
-                       bufsize = pfd_send(c->u.pfd.s, p, len);
+                       bufsize = pfd_send(c->u.pfd.s, (char *)p, len);
                        break;
                      case CHAN_AGENT:
                        /* Data for an agent message. Buffer it. */
@@ -3621,8 +3603,9 @@ static int in_commasep_string(char *needle, char *haystack, int haylen)
 /*
  * SSH2 key creation method.
  */
-static void ssh2_mkkey(Ssh ssh, Bignum K, char *H, char *sessid, char chr,
-                      char *keyspace)
+static void ssh2_mkkey(Ssh ssh, Bignum K, unsigned char *H,
+                      unsigned char *sessid, char chr,
+                      unsigned char *keyspace)
 {
     SHA_State s;
     /* First 20 bytes. */
@@ -4048,7 +4031,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen, int ispkt)
     s->hkey = ssh->hostkey->newkey(s->hostkeydata, s->hostkeylen);
     if (!s->hkey ||
        !ssh->hostkey->verifysig(s->hkey, s->sigdata, s->siglen,
-                                s->exchange_hash, 20)) {
+                                (char *)s->exchange_hash, 20)) {
        bombout((ssh,"Server's host key did not match the signature supplied"));
        crReturn(0);
     }
@@ -4586,13 +4569,13 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                            logevent("This key matches configured key file");
                            s->tried_pubkey_config = 1;
                        }
-                       s->pkblob = s->p;
+                       s->pkblob = (char *)s->p;
                        s->p += s->pklen;
                        s->alglen = GET_32BIT(s->pkblob);
                        s->alg = s->pkblob + 4;
                        s->commentlen = GET_32BIT(s->p);
                        s->p += 4;
-                       s->commentp = s->p;
+                       s->commentp = (char *)s->p;
                        s->p += s->commentlen;
                        ssh2_pkt_init(ssh, SSH2_MSG_USERAUTH_REQUEST);
                        ssh2_pkt_addstring(ssh, s->username);
@@ -4698,8 +4681,10 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                 * First, offer the public blob to see if the server is
                 * willing to accept it.
                 */
-               pub_blob = ssh2_userkey_loadpub(cfg.keyfile, &algorithm,
-                                               &pub_blob_len);
+               pub_blob =
+                   (unsigned char *)ssh2_userkey_loadpub(cfg.keyfile,
+                                                         &algorithm,
+                                                         &pub_blob_len);
                if (pub_blob) {
                    ssh2_pkt_init(ssh, SSH2_MSG_USERAUTH_REQUEST);
                    ssh2_pkt_addstring(ssh, s->username);
@@ -4708,7 +4693,8 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                    ssh2_pkt_addbool(ssh, FALSE);       /* no signature included */
                    ssh2_pkt_addstring(ssh, algorithm);
                    ssh2_pkt_addstring_start(ssh);
-                   ssh2_pkt_addstring_data(ssh, pub_blob, pub_blob_len);
+                   ssh2_pkt_addstring_data(ssh, (char *)pub_blob,
+                                           pub_blob_len);
                    ssh2_pkt_send(ssh);
                    logevent("Offered public key");     /* FIXME */
 
@@ -4909,7 +4895,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                    ssh2_pkt_addstring(ssh, key->alg->name);
                    pkblob = key->alg->public_blob(key->data, &pkblob_len);
                    ssh2_pkt_addstring_start(ssh);
-                   ssh2_pkt_addstring_data(ssh, pkblob, pkblob_len);
+                   ssh2_pkt_addstring_data(ssh, (char *)pkblob, pkblob_len);
 
                    /*
                     * The data to be signed is:
@@ -4925,7 +4911,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                    memcpy(sigdata + 4, ssh->v2_session_id, 20);
                    memcpy(sigdata + 24, ssh->pktout.data + 5,
                           ssh->pktout.length - 5);
-                   sigblob = key->alg->sign(key->data, sigdata,
+                   sigblob = key->alg->sign(key->data, (char *)sigdata,
                                             sigdata_len, &sigblob_len);
                    ssh2_add_sigblob(ssh, pkblob, pkblob_len,
                                     sigblob, sigblob_len);
@@ -4957,6 +4943,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
                ssh2_pkt_addstring(ssh, "password");
                ssh2_pkt_addbool(ssh, FALSE);
                ssh2_pkt_addstring(ssh, s->password);
+               memset(s->password, 0, sizeof(s->password));
                ssh2_pkt_defer(ssh);
                /*
                 * We'll include a string that's an exact multiple of the
@@ -5082,6 +5069,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
        logevent("Requesting X11 forwarding");
        ssh->x11auth = x11_invent_auth(proto, sizeof(proto),
                                       data, sizeof(data));
+        x11_get_real_auth(ssh->x11auth, cfg.x11_display);
        ssh2_pkt_init(ssh, SSH2_MSG_CHANNEL_REQUEST);
        ssh2_pkt_adduint32(ssh, ssh->mainchan->remoteid);
        ssh2_pkt_addstring(ssh, "x11-req");
@@ -5089,7 +5077,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
        ssh2_pkt_addbool(ssh, 0);              /* many connections */
        ssh2_pkt_addstring(ssh, proto);
        ssh2_pkt_addstring(ssh, data);
-       ssh2_pkt_adduint32(ssh, 0);            /* screen number */
+       ssh2_pkt_adduint32(ssh, x11_get_screen_number(cfg.x11_display));
        ssh2_pkt_send(ssh);
 
        do {
@@ -5813,7 +5801,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, int ispkt)
            /*
             * We have spare data. Add it to the channel buffer.
             */
-           ssh2_add_channel_data(ssh->mainchan, in, inlen);
+           ssh2_add_channel_data(ssh->mainchan, (char *)in, inlen);
            s->try_send = TRUE;
        }
        if (s->try_send) {
@@ -5941,6 +5929,8 @@ static char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh->overall_bufsize = 0;
     ssh->fallback_cmd = 0;
 
+    ssh->protocol = NULL;
+
     p = connect_to_host(ssh, host, port, realhost, nodelay);
     if (p != NULL)
        return p;
@@ -5958,7 +5948,7 @@ static int ssh_send(void *handle, char *buf, int len)
     if (ssh == NULL || ssh->s == NULL || ssh->protocol == NULL)
        return 0;
 
-    ssh->protocol(ssh, buf, len, 0);
+    ssh->protocol(ssh, (unsigned char *)buf, len, 0);
 
     return ssh_sendbuffer(ssh);
 }
@@ -6102,7 +6092,7 @@ void *new_sock_channel(void *handle, Socket s)
  * This is called when stdout/stderr (the entity to which
  * from_backend sends data) manages to clear some backlog.
  */
-void ssh_unthrottle(void *handle, int bufsize)
+static void ssh_unthrottle(void *handle, int bufsize)
 {
     Ssh ssh = (Ssh) handle;
     if (ssh->version == 1) {