SSH port forwarding! How cool is that?
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 881919c..cbcf439 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -196,6 +196,12 @@ extern void x11_close(Socket);
 extern void x11_send(Socket, char *, int);
 extern void x11_invent_auth(char *, int, char *, int);
 
+extern char *pfd_newconnect(Socket * s, char *hostname, int port, void *c); /* NULL on success, else error text */
+extern char *pfd_addforward(char *desthost, int destport, int port); /* local `port` forwarded to desthost:destport */
+extern void pfd_close(Socket s);       /* remote side closed the channel */
+extern void pfd_send(Socket s, char *data, int len); /* deliver incoming channel data to the socket */
+extern void pfd_confirm(Socket s);     /* server confirmed a channel we opened */
+
 /*
  * Ciphers for SSH2. We miss out single-DES because it isn't
  * supported; also 3DES and Blowfish are both done differently from
@@ -263,6 +269,7 @@ enum {                                     /* channel types */
     CHAN_MAINSESSION,
     CHAN_X11,
     CHAN_AGENT,
+    CHAN_SOCKDATA                      /* forwarded-port connection; uses u.pfd */
 };
 
 /*
@@ -286,9 +293,22 @@ struct ssh_channel {
        struct ssh_x11_channel {
            Socket s;
        } x11;
+       struct ssh_pfd_channel {
+           Socket s;                  /* socket of the forwarded connection */
+       } pfd;
     } u;
 };
 
+/*
+ * 2-3-4 tree storing remote->local port forwardings (so we can
+ * reject any attempt to open a port we didn't explicitly ask to
+ * have forwarded). Entries are ordered by ssh_rportcmp.
+ */
+struct ssh_rportfwd {
+    unsigned port;                     /* destination port */
+    char host[256];                    /* destination host, NUL-terminated */
+};
+
 struct Packet {
     long length;
     int type;
@@ -330,6 +350,8 @@ static int ssh_echoing, ssh_editing;
 static tree234 *ssh_channels;         /* indexed by local id */
 static struct ssh_channel *mainchan;   /* primary session channel */
 
+static tree234 *ssh_rportfwds;        /* struct ssh_rportfwd entries, keyed by host then port */
+
 static enum {
     SSH_STATE_PREPACKET,
     SSH_STATE_BEFORE_SIZE,
@@ -393,6 +415,25 @@ static int ssh_channelfind(void *av, void *bv)
     return 0;
 }
 
+/*
+ * Ordering for the ssh_rportfwds tree: by destination host, then by
+ * port. It must be a total order (-1/0/+1), otherwise lookups can
+ * treat two different ports on the same host as equal.
+ */
+static int ssh_rportcmp(void *av, void *bv)
+{
+    struct ssh_rportfwd *a = (struct ssh_rportfwd *) av;
+    struct ssh_rportfwd *b = (struct ssh_rportfwd *) bv;
+    int i;
+    if ( (i = strcmp(a->host, b->host)) != 0)
+       return i < 0 ? -1 : +1;
+    if (a->port > b->port)
+       return +1;
+    if (a->port < b->port)
+       return -1;
+    return 0;
+}
+
 static int alloc_channel_id(void)
 {
     const unsigned CHANNEL_NUMBER_OFFSET = 256;
@@ -1322,7 +1356,7 @@ static void ssh_detect_bugs(char *vstring)
 
 static int do_ssh_init(unsigned char c)
 {
-    static char vslen;
+    static int vslen;                  /* int, not char: presumably so larger counts cannot overflow it */
     static char version[10];
     static char *vstring;
     static int vstrsize;
@@ -1478,8 +1512,10 @@ static int ssh_closing(Plug plug, char *error_msg, int error_code,
                       int calling_back)
 {
     ssh_state = SSH_STATE_CLOSED;
-    sk_close(s);
-    s = NULL;
+    if (s) {                           /* s may already have been closed and set to NULL */
+        sk_close(s);
+        s = NULL;
+    }
     if (error_msg) {
        /* A socket error has occurred. */
        connection_fatal(error_msg);
@@ -1574,15 +1610,16 @@ static char *connect_to_host(char *host, int port, char **realhost)
  */
 static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
 {
-    int i, j, len;
-    unsigned char *rsabuf, *keystr1, *keystr2;
+    int i, j;
+    static int len;                    /* static: presumably must persist across crReturn */
+    static unsigned char *rsabuf, *keystr1, *keystr2;  /* static: presumably must persist across crReturn */
     unsigned char cookie[8];
     struct RSAKey servkey, hostkey;
     struct MD5Context md5c;
     static unsigned long supported_ciphers_mask, supported_auths_mask;
     static int tried_publickey;
     static unsigned char session_id[16];
-    int cipher_type;
+    static int cipher_type;            /* static: presumably must persist across crReturn */
     static char username[100];
 
     crBegin;
@@ -1783,7 +1820,8 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                            break;
                          default:
                            if (((c >= ' ' && c <= '~') ||
-                                ((unsigned char) c >= 160)) && pos < 40) {
+                                ((unsigned char) c >= 160))
+                               && pos < sizeof(username)-1) {  /* keep one byte spare, presumably for the NUL */
                                username[pos++] = c;
                                c_write(&c, 1);
                            }
@@ -1883,6 +1921,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                    ssh1_read_bignum(pktin.body, &challenge);
                    {
                        char *agentreq, *q, *ret;
+                       void *vret;    /* receives agent_query's reply pointer */
                        int len, retlen;
                        len = 1 + 4;   /* message type, bit count */
                        len += ssh1_bignum_length(key.exponent);
@@ -1902,7 +1941,8 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                        memcpy(q, session_id, 16);
                        q += 16;
                        PUT_32BIT(q, 1);        /* response format */
-                       agent_query(agentreq, len + 4, &ret, &retlen);
+                       agent_query(agentreq, len + 4, &vret, &retlen);
+                       ret = vret;    /* convert the void * reply to char * for byte access below */
                        sfree(agentreq);
                        if (ret) {
                            if (ret[4] == SSH1_AGENT_RSA_RESPONSE) {
@@ -2053,9 +2093,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                        exit(0);
                        break;
                      default:
-                       if (((c >= ' ' && c <= '~') ||
-                            ((unsigned char) c >= 160))
-                           && pos < sizeof(password))
+                       if (pos < sizeof(password)-1)   /* any byte accepted; one spare, presumably for the NUL */
                            password[pos++] = c;
                        break;
                    }
@@ -2259,7 +2297,10 @@ void sshfwd_close(struct ssh_channel *c)
        c->closes = 1;
        if (c->type == CHAN_X11) {
            c->u.x11.s = NULL;
-           logevent("X11 connection terminated");
+           logevent("Forwarded X11 connection terminated");
+       } else if (c->type == CHAN_SOCKDATA) {
+           c->u.pfd.s = NULL;         /* drop the socket reference, as in the X11 branch */
+           logevent("Forwarded port closed");
        }
     }
 }
@@ -2333,6 +2374,90 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt)
        }
     }
 
+    {
+       char type, *e;
+       int n;
+       int sport,dport;
+       char sports[256], dports[256], host[256];
+       char buf[1024];
+
+       ssh_rportfwds = newtree234(ssh_rportcmp);
+       /*
+        * Add port forwardings. cfg.portfwd holds NUL-terminated
+        * entries of the form "<type><sport>\t<host>:<dport>", ended
+        * by an empty entry. Each field is truncated to fit its buffer
+        * (surplus characters are still consumed), so an over-long
+        * entry cannot overflow the stack.
+        */
+       e = cfg.portfwd;
+       while (*e) {
+           type = *e++;
+           n = 0;
+           while (*e && *e != '\t') {
+               if (n < sizeof(sports)-1)
+                   sports[n++] = *e;
+               e++;
+           }
+           sports[n] = 0;
+           if (*e == '\t')
+               e++;
+           n = 0;
+           while (*e && *e != ':') {
+               if (n < sizeof(host)-1)
+                   host[n++] = *e;
+               e++;
+           }
+           host[n] = 0;
+           if (*e == ':')
+               e++;
+           n = 0;
+           while (*e) {
+               if (n < sizeof(dports)-1)
+                   dports[n++] = *e;
+               e++;
+           }
+           dports[n] = 0;
+           e++;
+           dport = atoi(dports);
+           sport = atoi(sports);
+           if (sport && dport) {
+               if (type == 'L') {
+                   /* NOTE(review): assumes NULL means success and anything
+                    * else is an error string, as with pfd_newconnect. */
+                   char *err = pfd_addforward(host, dport, sport);
+                   if (err != NULL)
+                       sprintf(buf, "Local port %d forwarding failed: %.512s",
+                               sport, err);
+                   else
+                       sprintf(buf, "Local port %d forwarding to %s:%d",
+                               sport, host, dport);
+                   logevent(buf);
+               } else {
+                   struct ssh_rportfwd *pf;
+                   pf = smalloc(sizeof(*pf));
+                   strcpy(pf->host, host);
+                   pf->port = dport;
+                   if (add234(ssh_rportfwds, pf) != pf) {
+                       sprintf(buf, 
+                               "Duplicate remote port forwarding to %s:%d",
+                               host, dport);
+                       logevent(buf);
+                       sfree(pf);      /* not inserted: the tree kept its existing entry */
+                   } else {
+                       sprintf(buf, "Requesting remote port %d forward to %s:%d",
+                               sport, host, dport);
+                       logevent(buf);
+                       send_packet(SSH1_CMSG_PORT_FORWARD_REQUEST,
+                                   PKT_INT, sport,
+                                   PKT_STR, host,
+                                   PKT_INT, dport,
+                                   PKT_END);
+                   }
+               }
+           }
+       }
+    }
+
     if (!cfg.nopty) {
        send_packet(SSH1_CMSG_REQUEST_PTY,
                    PKT_STR, cfg.termtype,
@@ -2456,6 +2559,88 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt)
                                PKT_INT, c->remoteid, PKT_INT, c->localid,
                                PKT_END);
                }
+           } else if (pktin.type == SSH1_MSG_PORT_OPEN &&
+                      (pktin.length < 12 ||
+                       GET_32BIT(pktin.body+4) >
+                       (unsigned long) (pktin.length - 12))) {
+               /* Malformed: the host string would run past the end of
+                * the packet. NOTE(review): assumes pktin.length is the
+                * number of bytes available at pktin.body. */
+               logevent("Rejected malformed remote port open request");
+               if (pktin.length >= 4)
+                   send_packet(SSH1_MSG_CHANNEL_OPEN_FAILURE,
+                               PKT_INT, GET_32BIT(pktin.body), PKT_END);
+           } else if (pktin.type == SSH1_MSG_PORT_OPEN) {
+               /* Remote side is trying to open a channel to talk to a
+                * forwarded port. Give them back a local channel number. */
+               struct ssh_channel *c;
+               struct ssh_rportfwd pf;
+               int hostsize, port;
+               char host[256], buf[1024];
+               char *p, *h, *e;
+
+               /* Body: uint32 remote channel, string host, uint32 port.
+                * The host name is truncated to fit host[]. */
+               hostsize = GET_32BIT(pktin.body+4);
+               for(h = host, p = pktin.body+8; hostsize != 0; hostsize--) {
+                   if (h+1 < host+sizeof(host))
+                       *h++ = *p;
+                   p++;
+               }
+               *h = 0;
+               port = GET_32BIT(p);
+
+               strcpy(pf.host, host);
+               pf.port = port;
+
+               if (find234(ssh_rportfwds, &pf, NULL) == NULL) {
+                   sprintf(buf, "Rejected remote port open request for %s:%d",
+                           host, port);
+                   logevent(buf);
+                   send_packet(SSH1_MSG_CHANNEL_OPEN_FAILURE,
+                               PKT_INT, GET_32BIT(pktin.body), PKT_END);
+               } else {
+                   sprintf(buf, "Received remote port open request for %s:%d",
+                           host, port);
+                   logevent(buf);
+                   /* Allocate only once the request is accepted, so the
+                    * rejection path above has nothing to free. */
+                   c = smalloc(sizeof(struct ssh_channel));
+                   e = pfd_newconnect(&c->u.pfd.s, host, port, c);
+                   if (e != NULL) {
+                       sprintf(buf, "Port open failed: %.512s", e);
+                       logevent(buf);
+                       sfree(c);
+                       send_packet(SSH1_MSG_CHANNEL_OPEN_FAILURE,
+                                   PKT_INT, GET_32BIT(pktin.body),
+                                   PKT_END);
+                   } else {
+                       c->remoteid = GET_32BIT(pktin.body);
+                       c->localid = alloc_channel_id();
+                       c->closes = 0;
+                       c->type = CHAN_SOCKDATA;        /* identify channel type */
+                       add234(ssh_channels, c);
+                       send_packet(SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
+                                   PKT_INT, c->remoteid, PKT_INT,
+                                   c->localid, PKT_END);
+                       logevent("Forwarded port opened successfully");
+                   }
+               }
+
+           } else if (pktin.type == SSH1_MSG_CHANNEL_OPEN_CONFIRMATION) {
+               /* Server accepted a channel we asked to open. Body:
+                * uint32 our channel id, uint32 the server's channel id.
+                * Ids matching no forwarded-port channel are ignored. */
+               unsigned int localid = GET_32BIT(pktin.body);
+               unsigned int remoteid = GET_32BIT(pktin.body+4);
+               struct ssh_channel *c;
+
+               c = find234(ssh_channels, &localid, ssh_channelfind);
+               if (c && c->type == CHAN_SOCKDATA) {
+                   c->remoteid = remoteid;
+                   pfd_confirm(c->u.pfd.s);
+               }
+
            } else if (pktin.type == SSH1_MSG_CHANNEL_CLOSE ||
                       pktin.type == SSH1_MSG_CHANNEL_CLOSE_CONFIRMATION) {
                /* Remote side closes a channel. */
@@ -2470,11 +2640,17 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt)
                        send_packet(pktin.type, PKT_INT, c->remoteid,
                                    PKT_END);
                    if ((c->closes == 0) && (c->type == CHAN_X11)) {
-                       logevent("X11 connection closed");
+                       logevent("Forwarded X11 connection terminated");
                        assert(c->u.x11.s != NULL);
                        x11_close(c->u.x11.s);
                        c->u.x11.s = NULL;
                    }
+                   if ((c->closes == 0) && (c->type == CHAN_SOCKDATA)) {   /* mirrors the X11 case above */
+                       logevent("Forwarded port closed");
+                       assert(c->u.pfd.s != NULL);
+                       pfd_close(c->u.pfd.s);
+                       c->u.pfd.s = NULL;
+                   }
                    c->closes |= closetype;
                    if (c->closes == 3) {
                        del234(ssh_channels, c);
@@ -2493,6 +2669,9 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt)
                      case CHAN_X11:
                        x11_send(c->u.x11.s, p, len);
                        break;
+                     case CHAN_SOCKDATA:
+                       pfd_send(c->u.pfd.s, p, len);
+                       break;
                      case CHAN_AGENT:
                        /* Data for an agent message. Buffer it. */
                        while (len > 0) {
@@ -2951,8 +3130,9 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
 #endif
 
     hkey = hostkey->newkey(hostkeydata, hostkeylen);
-    if (!hostkey->verifysig(hkey, sigdata, siglen, exchange_hash, 20)) {
-       bombout(("Server failed host key check"));
+    if (!hkey ||                       /* newkey can yield NULL; never pass that to verifysig */
+       !hostkey->verifysig(hkey, sigdata, siglen, exchange_hash, 20)) {
+       bombout(("Server's host key did not match the signature supplied"));
        crReturn(0);
     }
 
@@ -3086,8 +3266,6 @@ static void ssh2_try_send(struct ssh_channel *c)
  */
 static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
 {
-    static unsigned long remote_winsize;
-    static unsigned long remote_maxpkt;
     static enum {
        AUTH_INVALID, AUTH_PUBLICKEY_AGENT, AUTH_PUBLICKEY_FILE,
        AUTH_PASSWORD
@@ -3197,7 +3375,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                            break;
                          default:
                            if (((c >= ' ' && c <= '~') ||
-                                ((unsigned char) c >= 160)) && pos < 40) {
+                                ((unsigned char) c >= 160))
+                               && pos < sizeof(username)-1) {  /* assumes username is an array; spare byte presumably for the NUL */
                                username[pos++] = c;
                                c_write(&c, 1);
                            }
@@ -3366,6 +3545,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                        static int pklen, alglen, commentlen;
                        static int siglen, retlen, len;
                        static char *q, *agentreq, *ret;
+                       void *vret;    /* not static: only used between agent_query and the following assignment */
 
                        {
                            char buf[64];
@@ -3444,7 +3624,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                        q += pktout.length - 5;
                        /* And finally the (zero) flags word. */
                        PUT_32BIT(q, 0);
-                       agent_query(agentreq, len + 4, &ret, &retlen);
+                       agent_query(agentreq, len + 4, &vret, &retlen);
+                       ret = vret;    /* convert the void * reply to char * for byte access below */
                        sfree(agentreq);
                        if (ret) {
                            if (ret[4] == SSH2_AGENT_SIGN_RESPONSE) {
@@ -3582,9 +3763,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                                exit(0);
                                break;
                              default:
-                               if (((c >= ' ' && c <= '~') ||
-                                    ((unsigned char) c >= 160))
-                                   && pos < 40)
+                               if (pos < sizeof(password)-1)   /* any byte accepted; assumes password is an array; spare byte presumably for the NUL */
                                    password[pos++] = c;
                                break;
                            }
@@ -3972,6 +4151,9 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                      case CHAN_X11:
                        x11_send(c->u.x11.s, data, length);
                        break;
+                     case CHAN_SOCKDATA:
+                       pfd_send(c->u.pfd.s, data, length);
+                       break;
                      case CHAN_AGENT:
                        while (length > 0) {
                            if (c->u.a.lensofar < 4) {
@@ -4055,6 +4237,9 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                    sshfwd_close(c);
                } else if (c->type == CHAN_AGENT) {
                    sshfwd_close(c);
+               } else if (c->type == CHAN_SOCKDATA) {
+                   pfd_close(c->u.pfd.s);     /* before sshfwd_close, which clears u.pfd.s */
+                   sshfwd_close(c);
                }
            } else if (pktin.type == SSH2_MSG_CHANNEL_CLOSE) {
                unsigned i = ssh2_pkt_getuint32();
@@ -4076,6 +4261,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                    break;
                  case CHAN_AGENT:
                    break;
+                 case CHAN_SOCKDATA:
+                   break;
                }
                del234(ssh_channels, c);
                sfree(c->v2.outbuffer);
@@ -4302,6 +4489,51 @@ static void ssh_special(Telnet_Special code)
     }
 }
 
+/*
+ * Create a CHAN_SOCKDATA channel wrapping socket s and register it in
+ * ssh_channels. The server's channel id is only learned from
+ * SSH1_MSG_CHANNEL_OPEN_CONFIRMATION, so remoteid starts out as a
+ * placeholder.
+ */
+void *new_sock_channel(Socket s)
+{
+    struct ssh_channel *c;
+    c = smalloc(sizeof(struct ssh_channel));
+
+    if (c) {
+       /* Not known yet; pktin holds no request for this channel, so
+        * nothing is read from it here. */
+       c->remoteid = -1;
+       c->localid = alloc_channel_id();
+       c->closes = 0;
+       c->type = CHAN_SOCKDATA;        /* identify channel type */
+       c->u.pfd.s = s;
+       add234(ssh_channels, c);
+    }
+    return c;
+}
+
+/*
+ * Send SSH1_MSG_PORT_OPEN asking the server to connect to
+ * hostname:port for `channel` (presumably made by new_sock_channel).
+ */
+void ssh_send_port_open(void *channel, char *hostname, int port, char *org)
+{
+    struct ssh_channel *c = (struct ssh_channel *)channel;
+    char buf[1024];
+
+    sprintf(buf, "Opening forwarded connection to %.512s:%d", hostname, port);
+    logevent(buf);
+
+    /* NOTE(review): the originator string `org` is accepted but not sent. */
+    send_packet(SSH1_MSG_PORT_OPEN,
+               PKT_INT, c->localid,
+               PKT_STR, hostname,
+               PKT_INT, port,
+               PKT_END);
+}
+
+
 static Socket ssh_socket(void)
 {
     return s;
@@ -4330,4 +4550,4 @@ Backend ssh_backend = {
     ssh_sendok,
     ssh_ldisc,
     22
 };