Generalize the SSH2 host key abstraction into a generic 'signing key'
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 26379e1..92afd08 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -171,8 +171,8 @@ const static struct ssh_cipher *ciphers[] = { &ssh_blowfish_ssh2, &ssh_3des_ssh2
 extern const struct ssh_kex ssh_diffiehellman;
 const static struct ssh_kex *kex_algs[] = { &ssh_diffiehellman };
 
-extern const struct ssh_hostkey ssh_dss;
-const static struct ssh_hostkey *hostkey_algs[] = { &ssh_dss };
+extern const struct ssh_signkey ssh_dss;
+const static struct ssh_signkey *hostkey_algs[] = { &ssh_dss };
 
 extern const struct ssh_mac ssh_md5, ssh_sha1, ssh_sha1_buggy;
 
@@ -237,6 +237,7 @@ static Socket s = NULL;
 
 static unsigned char session_key[32];
 static int ssh1_compressing;
+static int ssh_agentfwd_enabled;
 static const struct ssh_cipher *cipher = NULL;
 static const struct ssh_cipher *cscipher = NULL;
 static const struct ssh_cipher *sccipher = NULL;
@@ -245,7 +246,7 @@ static const struct ssh_mac *scmac = NULL;
 static const struct ssh_compress *cscomp = NULL;
 static const struct ssh_compress *sccomp = NULL;
 static const struct ssh_kex *kex = NULL;
-static const struct ssh_hostkey *hostkey = NULL;
+static const struct ssh_signkey *hostkey = NULL;
 int (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL;
 
 static char *savedhost;
@@ -1024,6 +1025,7 @@ static int do_ssh_init(unsigned char c) {
            break;
     }
 
+    ssh_agentfwd_enabled = FALSE;
     rdpkt2_state.incoming_sequence = 0;
 
     *vsp = 0;
@@ -1732,8 +1734,10 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
             crReturnV;
         } else if (pktin.type == SSH1_SMSG_FAILURE) {
             logevent("Agent forwarding refused");
-        } else
+        } else {
             logevent("Agent forwarding enabled");
+           ssh_agentfwd_enabled = TRUE;
+       }
     }
 
     if (!cfg.nopty) {
@@ -1797,26 +1801,35 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
             } else if (pktin.type == SSH1_SMSG_AGENT_OPEN) {
                 /* Remote side is trying to open a channel to talk to our
                  * agent. Give them back a local channel number. */
-                unsigned i = 1;
+                unsigned i;
                 struct ssh_channel *c;
                 enum234 e;
-                for (c = first234(ssh_channels, &e); c; c = next234(&e)) {
-                    if (c->localid > i)
-                        break;         /* found a free number */
-                    i = c->localid + 1;
-                }
-                c = malloc(sizeof(struct ssh_channel));
-                c->remoteid = GET_32BIT(pktin.body);
-                c->localid = i;
-                c->closes = 0;
-                c->type = SSH1_SMSG_AGENT_OPEN;   /* identify channel type */
-                c->u.a.lensofar = 0;
-                add234(ssh_channels, c);
-                send_packet(SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
-                            PKT_INT, c->remoteid, PKT_INT, c->localid,
-                            PKT_END);
-            } else if (pktin.type == SSH1_MSG_CHANNEL_CLOSE ||
-                       pktin.type == SSH1_MSG_CHANNEL_CLOSE_CONFIRMATION) {
+
+               /* Refuse if agent forwarding is disabled. */
+               if (!ssh_agentfwd_enabled) {
+                   send_packet(SSH1_MSG_CHANNEL_OPEN_FAILURE,
+                               PKT_INT, GET_32BIT(pktin.body),
+                               PKT_END);
+               } else {
+                   i = 1;
+                   for (c = first234(ssh_channels, &e); c; c = next234(&e)) {
+                       if (c->localid > i)
+                           break;     /* found a free number */
+                       i = c->localid + 1;
+                   }
+                   c = malloc(sizeof(struct ssh_channel));
+                   c->remoteid = GET_32BIT(pktin.body);
+                   c->localid = i;
+                   c->closes = 0;
+                   c->type = SSH1_SMSG_AGENT_OPEN;/* identify channel type */
+                   c->u.a.lensofar = 0;
+                   add234(ssh_channels, c);
+                   send_packet(SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
+                               PKT_INT, c->remoteid, PKT_INT, c->localid,
+                               PKT_END);
+               }
+           } else if (pktin.type == SSH1_MSG_CHANNEL_CLOSE ||
+                      pktin.type == SSH1_MSG_CHANNEL_CLOSE_CONFIRMATION) {
                 /* Remote side closes a channel. */
                 unsigned i = GET_32BIT(pktin.body);
                 struct ssh_channel *c;
@@ -1974,6 +1987,7 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     static const struct ssh_compress *sccomp_tobe = NULL;
     static char *hostkeydata, *sigdata, *keystr, *fingerprint;
     static int hostkeylen, siglen;
+    static void *hkey;                /* actual host key */
     static unsigned char exchange_hash[20];
     static unsigned char keyspace[40];
     static const struct ssh_cipher *preferred_cipher;
@@ -2203,8 +2217,8 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     debug(("\r\n"));
 #endif
 
-    hostkey->setkey(hostkeydata, hostkeylen);
-    if (!hostkey->verifysig(sigdata, siglen, exchange_hash, 20)) {
+    hkey = hostkey->newkey(hostkeydata, hostkeylen);
+    if (!hostkey->verifysig(hkey, sigdata, siglen, exchange_hash, 20)) {
         bombout(("Server failed host key check"));
         crReturn(0);
     }
@@ -2222,14 +2236,15 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
      * Authenticate remote host: verify host key. (We've already
      * checked the signature of the exchange hash.)
      */
-    keystr = hostkey->fmtkey();
-    fingerprint = hostkey->fingerprint();
+    keystr = hostkey->fmtkey(hkey);
+    fingerprint = hostkey->fingerprint(hkey);
     verify_ssh_host_key(savedhost, savedport, hostkey->keytype,
                         keystr, fingerprint);
     logevent("Host key fingerprint is:");
     logevent(fingerprint);
     free(fingerprint);
     free(keystr);
+    hostkey->freekey(hkey);
 
     /*
      * Send SSH2_MSG_NEWKEYS.
@@ -2680,7 +2695,7 @@ static char *ssh_init (char *host, int port, char **realhost) {
  * Called to send data down the Telnet connection.
  */
 static void ssh_send (char *buf, int len) {
-    if (s == NULL)
+    if (s == NULL || ssh_protocol == NULL)
        return;
 
     ssh_protocol(buf, len, 0);