site: Move current_transform, _key_timeout and remote_session_id into a struct
[secnet] / site.c
diff --git a/site.c b/site.c
index c0f4387..1bf8dec 100644 (file)
--- a/site.c
+++ b/site.c
@@ -214,6 +214,12 @@ static void transport_xmit(struct site *st, transport_peers *peers,
  /***** END of transport peers declarations *****/
 
 
+struct data_key { /* Per-key state of an established session */
+    struct transform_inst_if *transform; /* Instance whose forwards/reverse process packets */
+    uint64_t key_timeout; /* End of life of current key; 0 = no timeout active */
+    uint32_t remote_session_id; /* Prepended to outgoing packets; NAKs are matched against it */
+};
+
 struct site {
     closure_t cl;
     struct site_if ops;
@@ -259,10 +265,7 @@ struct site {
     uint64_t now; /* Most recently seen time */
 
     /* The currently established session */
-    uint32_t remote_session_id;
-    struct transform_inst_if *current_transform;
-    bool_t current_valid;
-    uint64_t current_key_timeout; /* End of life of current key */
+    struct data_key current;
     uint64_t renegotiate_key_time; /* When we can negotiate a new key */
     transport_peers peers; /* Current address(es) of peer for data traffic */
 
@@ -324,6 +327,11 @@ static bool_t enter_new_state(struct site *st,uint32_t next);
 static void enter_state_wait(struct site *st);
 static void activate_new_key(struct site *st);
 
+static bool_t current_valid(struct site *st) /* Is the current data key valid? */
+{
+    return st->current.transform->valid(st->current.transform->st); /* Ask the transform; replaces the old st->current_valid flag */
+}
+
 #define CHECK_AVAIL(b,l) do { if ((b)->size<(l)) return False; } while(0)
 #define CHECK_EMPTY(b) do { if ((b)->size!=0) return False; } while(0)
 #define CHECK_TYPE(b,t) do { uint32_t type; \
@@ -642,15 +650,15 @@ static bool_t generate_msg5(struct site *st)
 }
 
 static bool_t process_msg5(struct site *st, struct buffer_if *msg5,
-                          const struct comm_addr *src)
+                          const struct comm_addr *src,
+                          struct transform_inst_if *transform)
 {
     struct msg0 m;
     cstring_t transform_err;
 
     if (!unpick_msg0(st,msg5,&m)) return False;
 
-    if (st->new_transform->reverse(st->new_transform->st,
-                                  msg5,&transform_err)) {
+    if (transform->reverse(transform->st,msg5,&transform_err)) {
        /* There's a problem */
        slog(st,LOG_SEC,"process_msg5: transform: %s",transform_err);
        return False;
@@ -666,7 +674,8 @@ static bool_t process_msg5(struct site *st, struct buffer_if *msg5,
     return True;
 }
 
-static bool_t generate_msg6(struct site *st)
+static void create_msg6(struct site *st, struct transform_inst_if *transform,
+                       uint32_t session_id)
 {
     cstring_t transform_err;
 
@@ -676,12 +685,15 @@ static bool_t generate_msg6(struct site *st)
     /* Give the netlink code an opportunity to put its own stuff in the
        message (configuration information, etc.) */
     buf_prepend_uint32(&st->buffer,LABEL_MSG6);
-    st->new_transform->forwards(st->new_transform->st,&st->buffer,
-                               &transform_err);
+    transform->forwards(transform->st,&st->buffer,&transform_err);
     buf_prepend_uint32(&st->buffer,LABEL_MSG6);
     buf_prepend_uint32(&st->buffer,st->index);
-    buf_prepend_uint32(&st->buffer,st->setup_session_id);
+    buf_prepend_uint32(&st->buffer,session_id);
+}
 
+static bool_t generate_msg6(struct site *st) /* Build MSG6 in st->buffer during key setup; always returns True */
+{
+    create_msg6(st,st->new_transform,st->setup_session_id); /* Uses the transform and session id of the key being set up */
     st->retries=1; /* Peer will retransmit MSG5 if this packet gets lost */
     return True;
 }
@@ -722,7 +734,7 @@ static bool_t decrypt_msg0(struct site *st, struct buffer_if *msg0)
     /* Keep a copy so we can try decrypting it with multiple keys */
     buffer_copy(&st->scratch, msg0);
 
-    problem = st->current_transform->reverse(st->current_transform->st,
+    problem = st->current.transform->reverse(st->current.transform->st,
                                             msg0,&transform_err);
     if (!problem) return True;
 
@@ -863,17 +875,16 @@ static void activate_new_key(struct site *st)
 
     /* We have two transform instances, which we swap between active
        and setup */
-    t=st->current_transform;
-    st->current_transform=st->new_transform;
+    t=st->current.transform;
+    st->current.transform=st->new_transform;
     st->new_transform=t;
 
     t->delkey(t->st);
     st->timeout=0;
-    st->current_valid=True;
-    st->current_key_timeout=st->now+st->key_lifetime;
+    st->current.key_timeout=st->now+st->key_lifetime;
     st->renegotiate_key_time=st->now+st->key_renegotiate_time;
     transport_peers_copy(st,&st->peers,&st->setup_peers);
-    st->remote_session_id=st->setup_session_id;
+    st->current.remote_session_id=st->setup_session_id;
 
     slog(st,LOG_ACTIVATE_KEY,"new key activated");
     enter_state_run(st);
@@ -881,12 +892,11 @@ static void activate_new_key(struct site *st)
 
 static void delete_key(struct site *st, cstring_t reason, uint32_t loglevel)
 {
-    if (st->current_valid) {
+    if (current_valid(st)) { /* Nothing to do unless the current key is valid */
        slog(st,loglevel,"session closed (%s)",reason);
 
-       st->current_valid=False;
-       st->current_transform->delkey(st->current_transform->st);
-       st->current_key_timeout=0;
+       st->current.transform->delkey(st->current.transform->st); /* NOTE(review): presumably makes current_valid() return False afterwards - confirm in the transform */
+       st->current.key_timeout=0; /* 0 = no key timeout active */
        set_link_quality(st);
     }
 }
@@ -907,7 +917,7 @@ static void enter_state_stop(struct site *st)
 static void set_link_quality(struct site *st)
 {
     uint32_t quality;
-    if (st->current_valid)
+    if (current_valid(st))
        quality=LINK_QUALITY_UP;
     else if (st->state==SITE_WAIT || st->state==SITE_STOP)
        quality=LINK_QUALITY_DOWN;
@@ -1019,17 +1029,17 @@ static bool_t send_msg7(struct site *st, cstring_t reason)
 {
     cstring_t transform_err;
 
-    if (st->current_valid && st->buffer.free
+    if (current_valid(st) && st->buffer.free
        && transport_peers_valid(&st->peers)) {
        BUF_ALLOC(&st->buffer,"site:MSG7");
        buffer_init(&st->buffer,st->transform->max_start_pad+(4*3));
        buf_append_uint32(&st->buffer,LABEL_MSG7);
        buf_append_string(&st->buffer,reason);
-       st->current_transform->forwards(st->current_transform->st,
+       st->current.transform->forwards(st->current.transform->st,
                                        &st->buffer, &transform_err);
        buf_prepend_uint32(&st->buffer,LABEL_MSG0);
        buf_prepend_uint32(&st->buffer,st->index);
-       buf_prepend_uint32(&st->buffer,st->remote_session_id);
+       buf_prepend_uint32(&st->buffer,st->current.remote_session_id);
        transport_xmit(st,&st->peers,&st->buffer,True);
        BUF_FREE(&st->buffer);
        return True;
@@ -1070,10 +1080,10 @@ static int site_beforepoll(void *sst, struct pollfd *fds, int *nfds_io,
     st->now=*now;
 
     /* Work out when our next timeout is. The earlier of 'timeout' or
-       'current_key_timeout'. A stored value of '0' indicates no timeout
+       'current.key_timeout'. A stored value of '0' indicates no timeout
        active. */
     site_settimeout(st->timeout, timeout_io);
-    site_settimeout(st->current_key_timeout, timeout_io);
+    site_settimeout(st->current.key_timeout, timeout_io);
 
     return 0; /* success */
 }
@@ -1096,7 +1106,7 @@ static void site_afterpoll(void *sst, struct pollfd *fds, int nfds)
                 st->state);
        }
     }
-    if (st->current_key_timeout && *now>st->current_key_timeout) {
+    if (st->current.key_timeout && *now>st->current.key_timeout) {
        delete_key(st,"maximum key life exceeded",LOG_TIMEOUT_KEY);
     }
 }
@@ -1116,15 +1126,15 @@ static void site_outgoing(void *sst, struct buffer_if *buf)
 
     /* In all other states we consider delivering the packet if we have
        a valid key and a valid address to send it to. */
-    if (st->current_valid && transport_peers_valid(&st->peers)) {
+    if (current_valid(st) && transport_peers_valid(&st->peers)) {
        /* Transform it and send it */
        if (buf->size>0) {
            buf_prepend_uint32(buf,LABEL_MSG9);
-           st->current_transform->forwards(st->current_transform->st,
+           st->current.transform->forwards(st->current.transform->st,
                                            buf, &transform_err);
            buf_prepend_uint32(buf,LABEL_MSG0);
            buf_prepend_uint32(buf,st->index);
-           buf_prepend_uint32(buf,st->remote_session_id);
+           buf_prepend_uint32(buf,st->current.remote_session_id);
            transport_xmit(st,&st->peers,buf,False);
        }
        BUF_FREE(buf);
@@ -1204,7 +1214,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
        case 0: /* NAK */
            /* If the source is our current peer then initiate a key setup,
               because our peer's forgotten the key */
-           if (get_uint32(buf->start+4)==st->remote_session_id) {
+           if (get_uint32(buf->start+4)==st->current.remote_session_id) {
                initiate_key_setup(st,"received a NAK");
            } else {
                slog(st,LOG_SEC,"bad incoming NAK");
@@ -1258,13 +1268,26 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
               case we discard it. The peer will realise that we
               are using the new key when they see our data packets.
               Until then the peer's data packets to us get discarded. */
-           if (st->state!=SITE_SENTMSG4) {
-               slog(st,LOG_UNEXPECTED,"unexpected MSG5");
-           } else if (process_msg5(st,buf,source)) {
-               transport_setup_msgok(st,source);
-               enter_new_state(st,SITE_RUN);
+           if (st->state==SITE_SENTMSG4) {
+               if (process_msg5(st,buf,source,st->new_transform)) {
+                   transport_setup_msgok(st,source);
+                   enter_new_state(st,SITE_RUN);
+               } else {
+                   slog(st,LOG_SEC,"invalid MSG5");
+               }
+           } else if (st->state==SITE_RUN) {
+               if (process_msg5(st,buf,source,st->current.transform)) {
+                   slog(st,LOG_DROP,"got MSG5, retransmitting MSG6");
+                   transport_setup_msgok(st,source);
+                   create_msg6(st,st->current.transform,
+                               st->current.remote_session_id);
+                   transport_xmit(st,&st->peers,&st->buffer,True);
+                   BUF_FREE(&st->buffer);
+               } else {
+                   slog(st,LOG_SEC,"invalid MSG5 (in state RUN)");
+               }
            } else {
-               slog(st,LOG_SEC,"invalid MSG5");
+               slog(st,LOG_UNEXPECTED,"unexpected MSG5");
            }
            break;
        case LABEL_MSG6:
@@ -1454,8 +1477,7 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     register_for_poll(st, site_beforepoll, site_afterpoll, 0, "site");
     st->timeout=0;
 
-    st->current_valid=False;
-    st->current_key_timeout=0;
+    st->current.key_timeout=0;
     transport_peers_clear(st,&st->peers);
     transport_peers_clear(st,&st->setup_peers);
     /* XXX mlock these */
@@ -1482,7 +1504,7 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     for (i=0; i<st->ncomms; i++)
        st->comms[i]->request_notify(st->comms[i]->st, st, site_incoming);
 
-    st->current_transform=st->transform->create(st->transform->st);
+    st->current.transform=st->transform->create(st->transform->st);
     st->new_transform=st->transform->create(st->transform->st);
 
     enter_state_stop(st);