Cleanups of the GSSAPI support. On Windows, standard GSS libraries
[u/mdw/putty] / windows / wingss.c
index 90cd24e..9084c3c 100644 (file)
 
 /* Windows code to set up the GSSAPI library list. */
 
-struct ssh_gss_library ssh_gss_libraries[2];
-int n_ssh_gss_libraries = 0;
-static int initialised = FALSE;
-
-const int ngsslibs = 2;
-const char *const gsslibnames[2] = {
-    "GSSAPI32.DLL (MIT Kerberos)",
-    "SECUR32.DLL (Microsoft SSPI)",
+const char *const gsslibnames[] = {    /* entry i describes lib->id == i */
+    "MIT Kerberos GSSAPI32.DLL",
+    "Microsoft SSPI SECUR32.DLL",
+    "User-specified GSSAPI DLL",
 };
+const int ngsslibs = sizeof(gsslibnames) / sizeof(*gsslibnames);
 const struct keyval gsslibkeywords[] = {
     { "gssapi32", 0 },
     { "sspi", 1 },
+    { "custom", 2 },
 };
 
 DECL_WINDOWS_FUNCTION(static, SECURITY_STATUS,
@@ -67,22 +65,46 @@ const char *gsslogmsg = NULL;
 
 static void ssh_sspi_bind_fns(struct ssh_gss_library *lib);
 
-void ssh_gss_init(void)
+struct ssh_gss_liblist *ssh_gss_setup(const Config *cfg)
 {
     HMODULE module;
+    HKEY regkey;
+    struct ssh_gss_liblist *list = snew(struct ssh_gss_liblist);
 
-    if (initialised) return;
-    initialised = TRUE;
+    list->libraries = snewn(3, struct ssh_gss_library);
+    list->nlibraries = 0;
 
     /* MIT Kerberos GSSAPI implementation */
     /* TODO: For 64-bit builds, check for gssapi64.dll */
-    module = LoadLibrary("gssapi32.dll");
+    module = NULL;
+    if (RegOpenKey(HKEY_LOCAL_MACHINE, "SOFTWARE\\MIT\\Kerberos", &regkey)
+       == ERROR_SUCCESS) {
+       DWORD type, size;
+       LONG ret;
+       char *buffer;
+
+       /* Find out the string length */
+        ret = RegQueryValueEx(regkey, "InstallDir", NULL, &type, NULL, &size);
+
+       if (ret == ERROR_SUCCESS && type == REG_SZ) {
+           buffer = snewn(size + 20, char);
+           ret = RegQueryValueEx(regkey, "InstallDir", NULL,
+                                 &type, buffer, &size);
+           if (ret == ERROR_SUCCESS && type == REG_SZ) {
+               strcat(buffer, "\\bin\\gssapi32.dll");
+               module = LoadLibrary(buffer);
+           }
+           sfree(buffer);
+       }
+       RegCloseKey(regkey);
+    }
     if (module) {
        struct ssh_gss_library *lib =
-           &ssh_gss_libraries[n_ssh_gss_libraries++];
+           &list->libraries[list->nlibraries++];
 
        lib->id = 0;
        lib->gsslogmsg = "Using GSSAPI from GSSAPI32.DLL";
+       lib->handle = (void *)module;
 
 #define BIND_GSS_FN(name) \
     lib->u.gssapi.name = (t_gss_##name) GetProcAddress(module, "gss_" #name)
@@ -105,10 +127,11 @@ void ssh_gss_init(void)
     module = load_system32_dll("secur32.dll");
     if (module) {
        struct ssh_gss_library *lib =
-           &ssh_gss_libraries[n_ssh_gss_libraries++];
+           &list->libraries[list->nlibraries++];
 
        lib->id = 1;
        lib->gsslogmsg = "Using SSPI from SECUR32.DLL";
+       lib->handle = (void *)module;
 
        GET_WINDOWS_FUNCTION(module, AcquireCredentialsHandleA);
        GET_WINDOWS_FUNCTION(module, InitializeSecurityContextA);
@@ -120,6 +143,67 @@ void ssh_gss_init(void)
 
        ssh_sspi_bind_fns(lib);
     }
+
+    /*
+     * Custom GSSAPI DLL.
+     */
+    module = NULL;
+    if (cfg->ssh_gss_custom.path[0]) {
+       module = LoadLibrary(cfg->ssh_gss_custom.path);
+    }
+    if (module) {
+       struct ssh_gss_library *lib =
+           &list->libraries[list->nlibraries++];
+
+       lib->id = 2;
+       lib->gsslogmsg = dupprintf("Using GSSAPI from user-specified"
+                                  " library '%s'", cfg->ssh_gss_custom.path);
+       lib->handle = (void *)module;
+
+#define BIND_GSS_FN(name) \
+    lib->u.gssapi.name = (t_gss_##name) GetProcAddress(module, "gss_" #name)
+
+        BIND_GSS_FN(delete_sec_context);
+        BIND_GSS_FN(display_status);
+        BIND_GSS_FN(get_mic);
+        BIND_GSS_FN(import_name);
+        BIND_GSS_FN(init_sec_context);
+        BIND_GSS_FN(release_buffer);
+        BIND_GSS_FN(release_cred);
+        BIND_GSS_FN(release_name);
+
+#undef BIND_GSS_FN
+
+        ssh_gssapi_bind_fns(lib);
+    }
+
+
+    return list;
+}
+
+void ssh_gss_cleanup(struct ssh_gss_liblist *list)
+{
+    struct ssh_gss_library *lib, *end;
+
+    /*
+     * Drop every module handle recorded by ssh_gss_setup. Windows
+     * keeps a per-module reference count across LoadLibrary and
+     * FreeLibrary, so releasing exactly the handles we loaded cannot
+     * unload a DLL that another loader in this process still holds.
+     */
+    end = list->libraries + list->nlibraries;
+    for (lib = list->libraries; lib < end; lib++) {
+       FreeLibrary((HMODULE)lib->handle);
+
+       /*
+        * Only the user-specified library (id 2) gets a log message
+        * built by dupprintf; the others point at string literals.
+        * The field is const-qualified, hence the cast before sfree.
+        */
+       if (lib->id == 2)
+           sfree((char *)lib->gsslogmsg);
+    }
+    sfree(list->libraries);
+    sfree(list);
 }
 
 static Ssh_gss_stat ssh_sspi_indicate_mech(struct ssh_gss_library *lib,