Dave Hinton's modifications to the network layer interface, which
[u/mdw/putty] / winnet.c
index 8602802..01b4f1b 100644 (file)
--- a/winnet.c
+++ b/winnet.c
@@ -49,6 +49,7 @@
 #include <stdio.h>
 #include <stdlib.h>
 
+#define DEFINE_PLUG_METHOD_MACROS
 #include "putty.h"
 #include "network.h"
 #include "tree234.h"
 #define BUFFER_GRANULE 512
 
 struct Socket_tag {
+    struct socket_function_table *fn;
+    /* the above variable absolutely *must* be the first in this structure */
     char *error;
     SOCKET s;
-    sk_receiver_t receiver;
+    Plug plug;
     void *private_ptr;
     struct buffer *head, *tail;
     int writable;
     int sending_oob;
+    int oobinline;
 };
 
+/*
+ * We used to typedef struct Socket_tag *Socket.
+ *
+ * Since we have made the networking abstraction slightly more
+ * abstract, Socket no longer means a tcp socket (it could mean
+ * an ssl socket).  So now we must use Actual_Socket when we know
+ * we are talking about a tcp socket.
+ */
+typedef struct Socket_tag *Actual_Socket;
+
 struct SockAddr_tag {
     char *error;
     /* address family this belongs to, AF_INET for IPv4, AF_INET6 for IPv6. */
@@ -89,7 +103,7 @@ struct buffer {
 static tree234 *sktree;
 
 static int cmpfortree(void *av, void *bv) {
-    Socket a = (Socket)av, b = (Socket)bv;
+    Actual_Socket a = (Actual_Socket)av, b = (Actual_Socket)bv;
     unsigned long as = (unsigned long)a->s, bs = (unsigned long)b->s;
     if (as < bs) return -1;
     if (as > bs) return +1;
@@ -97,7 +111,7 @@ static int cmpfortree(void *av, void *bv) {
 }
 
 static int cmpforsearch(void *av, void *bv) {
-    Socket b = (Socket)bv;
+    Actual_Socket b = (Actual_Socket)bv;
     unsigned long as = (unsigned long)av, bs = (unsigned long)b->s;
     if (as < bs) return -1;
     if (as > bs) return +1;
@@ -288,7 +302,37 @@ void sk_addr_free(SockAddr addr) {
     sfree(addr);
 }
 
-Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
+static Plug sk_tcp_plug (Socket sock, Plug p) {
+    Actual_Socket s = (Actual_Socket) sock;
+    Plug ret = s->plug;
+    if (p) s->plug = p;
+    return ret;
+}
+
+static void sk_tcp_flush (Socket s) {
+    /*
+     * We send data to the socket as soon as we can anyway,
+     * so we don't need to do anything here.  :-)
+     */
+}
+
+void sk_tcp_close (Socket s);
+void sk_tcp_write (Socket s, char *data, int len);
+void sk_tcp_write_oob (Socket s, char *data, int len);
+char *sk_tcp_socket_error(Socket s);
+
+Socket sk_new(SockAddr addr, int port, int privport, int oobinline,
+              Plug plug)
+{
+    static struct socket_function_table fn_table = {
+       sk_tcp_plug,
+       sk_tcp_close,
+       sk_tcp_write,
+       sk_tcp_write_oob,
+       sk_tcp_flush,
+       sk_tcp_socket_error
+    };
+
     SOCKET s;
 #ifdef IPV6
     SOCKADDR_IN6 a6;
@@ -296,7 +340,7 @@ Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
     SOCKADDR_IN a;
     DWORD err;
     char *errstr;
-    Socket ret;
+    Actual_Socket ret;
     extern char *do_select(SOCKET skt, int startup);
     short localport;
 
@@ -304,8 +348,9 @@ Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
      * Create Socket structure.
      */
     ret = smalloc(sizeof(struct Socket_tag));
+    ret->fn = &fn_table;
     ret->error = NULL;
-    ret->receiver = receiver;
+    ret->plug = plug;
     ret->head = ret->tail = NULL;
     ret->writable = 1;                /* to start with */
     ret->sending_oob = 0;
@@ -319,9 +364,11 @@ Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
     if (s == INVALID_SOCKET) {
        err = WSAGetLastError();
         ret->error = winsock_error_string(err);
-       return ret;
+       return (Socket) ret;
     }
-    {
+
+    ret->oobinline = oobinline;
+    if (oobinline) {
        BOOL b = TRUE;
        setsockopt (s, SOL_SOCKET, SO_OOBINLINE, (void *)&b, sizeof(b));
     }
@@ -380,7 +427,7 @@ Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
     if (err)
     {
        ret->error = winsock_error_string(err);
-       return ret;
+       return (Socket) ret;
     }
 
     /*
@@ -409,7 +456,7 @@ Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
     {
        err = WSAGetLastError();
        ret->error = winsock_error_string(err);
-       return ret;
+       return (Socket) ret;
     }
 
     /* Set up a select mechanism. This could be an AsyncSelect on a
@@ -417,16 +464,17 @@ Socket sk_new(SockAddr addr, int port, int privport, sk_receiver_t receiver) {
     errstr = do_select(s, 1);
     if (errstr) {
        ret->error = errstr;
-       return ret;
+       return (Socket) ret;
     }
 
     add234(sktree, ret);
 
-    return ret;
+    return (Socket) ret;
 }
 
-void sk_close(Socket s) {
+static void sk_tcp_close(Socket sock) {
     extern char *do_select(SOCKET skt, int startup);
+    Actual_Socket s = (Actual_Socket) sock;
 
     del234(sktree, s);
     do_select(s->s, 0);
@@ -438,7 +486,7 @@ void sk_close(Socket s) {
  * The function which tries to send on a socket once it's deemed
  * writable.
  */
-void try_send(Socket s) {
+void try_send(Actual_Socket s) {
     while (s->head) {
        int nsent;
        DWORD err;
@@ -493,7 +541,9 @@ void try_send(Socket s) {
     }
 }
 
-void sk_write(Socket s, char *buf, int len) {
+static void sk_tcp_write(Socket sock, char *buf, int len) {
+    Actual_Socket s = (Actual_Socket) sock;
+
     /*
      * Add the data to the buffer list on the socket.
      */
@@ -528,7 +578,9 @@ void sk_write(Socket s, char *buf, int len) {
        try_send(s);
 }
 
-void sk_write_oob(Socket s, char *buf, int len) {
+static void sk_tcp_write_oob(Socket sock, char *buf, int len) {
+    Actual_Socket s = (Actual_Socket) sock;
+
     /*
      * Replace the buffer list on the socket with the data.
      */
@@ -560,10 +612,10 @@ void sk_write_oob(Socket s, char *buf, int len) {
 }
 
 int select_result(WPARAM wParam, LPARAM lParam) {
-    int ret;
+    int ret, open;
     DWORD err;
     char buf[BUFFER_GRANULE];
-    Socket s;
+    Actual_Socket s;
     u_long atmark;
 
     /* wParam is the socket itself */
@@ -574,22 +626,36 @@ int select_result(WPARAM wParam, LPARAM lParam) {
     if ((err = WSAGETSELECTERROR(lParam)) != 0) {
         /*
          * An error has occurred on this socket. Pass it to the
-         * receiver function.
+         * plug.
          */
-        return s->receiver(s, 3, winsock_error_string(err), err);
+        return plug_closing (s->plug, winsock_error_string(err), err, 0);
     }
 
     noise_ultralight(lParam);
 
     switch (WSAGETSELECTEVENT(lParam)) {
       case FD_READ:
-        atmark = 1;
-        /* Some WinSock wrappers don't support this call, so we
-         * deliberately don't check the return value. If the call
-         * fails and does nothing, we will get back atmark==1,
-         * which is good enough to keep going at least. */
-        ioctlsocket(s->s, SIOCATMARK, &atmark);
+        /*
+         * We have received data on the socket. For an oobinline
+         * socket, this might be data _before_ an urgent pointer,
+         * in which case we send it to the back end with type==1
+         * (data prior to urgent).
+         */
+        if (s->oobinline) {
+            atmark = 1;
+            ioctlsocket(s->s, SIOCATMARK, &atmark);
+            /*
+             * Avoid checking the return value from ioctlsocket(),
+             * on the grounds that some WinSock wrappers don't
+             * support it. If it does nothing, we get atmark==1,
+             * which is equivalent to `no OOB pending', so the
+             * effect will be to non-OOB-ify any OOB data.
+             */
+        } else
+            atmark = 1;
+
        ret = recv(s->s, buf, sizeof(buf), 0);
+        noise_ultralight(ret);
        if (ret < 0) {
            err = WSAGetLastError();
            if (err == WSAEWOULDBLOCK) {
@@ -597,28 +663,27 @@ int select_result(WPARAM wParam, LPARAM lParam) {
            }
        }
        if (ret < 0) {
-           return s->receiver(s, 3, winsock_error_string(err), err);
+           return plug_closing (s->plug, winsock_error_string(err), err, 0);
+       } else if (0 == ret) {
+           return plug_closing (s->plug, NULL, 0, 0);
        } else {
-            int type = 0;
-           if (atmark==0) {
-               ioctlsocket(s->s, SIOCATMARK, &atmark);
-               if(atmark) type = 2; else type = 1;
-           }
-           return s->receiver(s, type, buf, ret);
+           return plug_receive (s->plug, atmark ? 0 : 1, buf, ret);
        }
        break;
       case FD_OOB:
-       /*
-        * Read all data up to the OOB marker, and send it to the
-        * receiver with urgent==1 (OOB pending).
-        */
+        /*
+         * This will only happen on a non-oobinline socket. It
+         * indicates that we can immediately perform an OOB read
+         * and get back OOB data, which we will send to the back
+         * end with type==2 (urgent data).
+         */
         ret = recv(s->s, buf, sizeof(buf), MSG_OOB);
         noise_ultralight(ret);
         if (ret <= 0) {
             fatalbox(ret == 0 ? "Internal networking trouble" :
                      winsock_error_string(WSAGetLastError()));
         } else {
-            return s->receiver(s, 2, buf, ret);
+           return plug_receive (s->plug, 2, buf, ret);
         }
         break;
       case FD_WRITE:
@@ -626,9 +691,21 @@ int select_result(WPARAM wParam, LPARAM lParam) {
        try_send(s);
        break;
       case FD_CLOSE:
-       /* Signal a close on the socket. */
-       return s->receiver(s, 0, NULL, 0);
-       break;
+       /* Signal a close on the socket. First read any outstanding data. */
+        open = 1;
+        do {
+            ret = recv(s->s, buf, sizeof(buf), 0);
+            if (ret < 0) {
+                err = WSAGetLastError();
+                if (err == WSAEWOULDBLOCK)
+                    break;
+                return plug_closing (s->plug, winsock_error_string(err), err, 0);
+            } else {
+               if (ret) open &= plug_receive (s->plug, 0, buf, ret);
+               else open &= plug_closing (s->plug, NULL, 0, 0);
+           }
+       } while (ret > 0);
+        return open;
     }
 
     return 1;
@@ -638,10 +715,12 @@ int select_result(WPARAM wParam, LPARAM lParam) {
  * Each socket abstraction contains a `void *' private field in
  * which the client can keep state.
  */
-void sk_set_private_ptr(Socket s, void *ptr) {
+void sk_set_private_ptr(Socket sock, void *ptr) {
+    Actual_Socket s = (Actual_Socket) sock;
     s->private_ptr = ptr;
 }
-void *sk_get_private_ptr(Socket s) {
+void *sk_get_private_ptr(Socket sock) {
+    Actual_Socket s = (Actual_Socket) sock;
     return s->private_ptr;
 }
 
@@ -653,7 +732,8 @@ void *sk_get_private_ptr(Socket s) {
 char *sk_addr_error(SockAddr addr) {
     return addr->error;
 }
-char *sk_socket_error(Socket s) {
+static char *sk_tcp_socket_error(Socket sock) {
+    Actual_Socket s = (Actual_Socket) sock;
     return s->error;
 }
 
@@ -661,10 +741,10 @@ char *sk_socket_error(Socket s) {
  * For Plink: enumerate all sockets currently active.
  */
 SOCKET first_socket(enum234 *e) {
-    Socket s = first234(sktree, e);
+    Actual_Socket s = first234(sktree, e);
     return s ? s->s : INVALID_SOCKET;
 }
 SOCKET next_socket(enum234 *e) {
-    Socket s = next234(e);
+    Actual_Socket s = next234(e);
     return s ? s->s : INVALID_SOCKET;
 }