site.c: Allocate the Diffie-Hellman secret buffers when needed and free them afterwards, instead of once at setup.
[secnet] / site.c
diff --git a/site.c b/site.c
index 303dbbd..4f954e6 100644 (file)
--- a/site.c
+++ b/site.c
@@ -552,6 +552,8 @@ static _Bool set_new_transform(struct site *st, char *pk)
     _Bool ok;
 
     /* Generate the shared key */
+    assert(!st->sharedsecret);
+    st->sharedsecret = safe_malloc(st->dh->shared_len, "site:sharedsecret");
     if (!st->dh->makeshared(st->dh->st,st->dhsecret,st->dh->secret_len,
                            pk, st->sharedsecret,st->dh->shared_len))
        return False;
@@ -864,6 +866,8 @@ kind##_found:                                                               \
 
 static void generate_dhsecret(struct site *st)
 {
+    assert(!st->dhsecret); /* no buffer may be outstanding: site_apply starts it at 0, enter_state_run frees it and resets it to 0 */
+    st->dhsecret = safe_malloc(st->dh->secret_len, "site:dhsecret"); /* fresh buffer on every call, filled with random bytes below; zeroed and freed in enter_state_run */
     st->random->generate(st->random->st, st->dh->secret_len,st->dhsecret);
 }
 
@@ -1515,8 +1519,16 @@ static void enter_state_run(struct site *st)
     FILLZERO(st->localN);
     FILLZERO(st->remoteN);
     dispose_transform(&st->new_transform);
-    memset(st->dhsecret,0,st->dh->secret_len);
-    memset(st->sharedsecret,0,st->dh->shared_len);
+    if (st->dhsecret) {
+       memset(st->dhsecret, 0, st->dh->secret_len);
+       free(st->dhsecret);
+       st->dhsecret = 0;
+    }
+    if (st->sharedsecret) {
+       memset(st->sharedsecret, 0, st->dh->shared_len);
+       free(st->sharedsecret);
+       st->sharedsecret = 0;
+    }
     set_link_quality(st);
 
     if (st->keepalive && !current_valid(st))
@@ -2252,9 +2264,8 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     st->auxiliary_key.key_timeout=0;
     transport_peers_clear(st,&st->peers);
     transport_peers_clear(st,&st->setup_peers);
-    /* XXX mlock these */
-    st->dhsecret=safe_malloc(st->dh->secret_len,"site:dhsecret");
-    st->sharedsecret=safe_malloc(st->dh->shared_len, "site:sharedsecret");
+    st->dhsecret=0;
+    st->sharedsecret=0;
 
 #define SET_CAPBIT(bit) do {                                           \
     uint32_t capflag = 1UL << (bit);                                   \