Introduce, and implement as usefully as I can in all front ends, a
[sgt/puzzles] / latin.c
diff --git a/latin.c b/latin.c
index 4f6e1f3..34c06c4 100644 (file)
--- a/latin.c
+++ b/latin.c
@@ -619,55 +619,52 @@ int latin_solver_diff_simple(struct latin_solver *solver)
 
 int latin_solver_diff_set(struct latin_solver *solver,
                           struct latin_solver_scratch *scratch,
-                          int *extreme)
+                          int extreme)
 {
     int x, y, n, ret, o = solver->o;
-    /*
-     * Row-wise set elimination.
-     */
-    for (y = 0; y < o; y++) {
-        ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
+
+    if (!extreme) {
+        /*
+         * Row-wise set elimination.
+         */
+        for (y = 0; y < o; y++) {
+            ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
 #ifdef STANDALONE_SOLVER
-                              , "set elimination, row %d", YUNTRANS(y)
+                                   , "set elimination, row %d", YUNTRANS(y)
 #endif
-                              );
-        if (ret > 0) *extreme = 0;
-        if (ret != 0) return ret;
-    }
-
-    /*
-     * Column-wise set elimination.
-     */
-    for (x = 0; x < o; x++) {
-        ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
+                                  );
+            if (ret != 0) return ret;
+        }
+        /*
+         * Column-wise set elimination.
+         */
+        for (x = 0; x < o; x++) {
+            ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
 #ifdef STANDALONE_SOLVER
-                              , "set elimination, column %d", x
+                                   , "set elimination, column %d", x
 #endif
-                              );
-        if (ret > 0) *extreme = 0;
-        if (ret != 0) return ret;
-    }
-
-    /*
-     * Row-vs-column set elimination on a single number.
-     */
-    for (n = 1; n <= o; n++) {
-        ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
+                                  );
+            if (ret != 0) return ret;
+        }
+    } else {
+        /*
+         * Row-vs-column set elimination on a single number
+         * (much tricker for a human to do!)
+         */
+        for (n = 1; n <= o; n++) {
+            ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
 #ifdef STANDALONE_SOLVER
-                              , "positional set elimination, number %d", n
+                                   , "positional set elimination, number %d", n
 #endif
-                              );
-        if (ret > 0) *extreme = 1;
-        if (ret != 0) return ret;
+                                  );
+            if (ret != 0) return ret;
+        }
     }
     return 0;
 }
 
-/* This uses our own diff_* internally, but doesn't require callers
- * to; this is so it can be used by games that want to rewrite
- * the solver so as to use a different set of difficulties.
- *
- * It returns:
+/*
+ * Returns:
  * 0 for 'didn't do anything' implying it was already solved.
  * -1 for 'impossible' (no solution)
  * 1 for 'single solution'
@@ -676,8 +673,11 @@ int latin_solver_diff_set(struct latin_solver *solver,
  *
  * and this function may well assert if given an impossible board.
  */
-int latin_solver_recurse(struct latin_solver *solver, int recdiff,
-                         latin_solver_callback cb, void *ctx)
+static int latin_solver_recurse
+    (struct latin_solver *solver, int diff_simple, int diff_set_0,
+     int diff_set_1, int diff_forcing, int diff_recursive,
+     usersolver_t const *usersolvers, void *ctx,
+     ctxnew_t ctxnew, ctxfree_t ctxfree)
 {
     int best, bestcount;
     int o = solver->o, x, y, n;
@@ -754,6 +754,7 @@ int latin_solver_recurse(struct latin_solver *solver, int recdiff,
          */
         for (i = 0; i < j; i++) {
             int ret;
+           void *newctx;
 
             memcpy(outgrid, ingrid, o*o);
             outgrid[y*o+x] = list[i];
@@ -765,7 +766,17 @@ int latin_solver_recurse(struct latin_solver *solver, int recdiff,
             solver_recurse_depth++;
 #endif
 
-            ret = cb(outgrid, o, recdiff, ctx);
+           if (ctxnew) {
+               newctx = ctxnew(ctx);
+           } else {
+               newctx = ctx;
+           }
+            ret = latin_solver(outgrid, o, diff_recursive,
+                              diff_simple, diff_set_0, diff_set_1,
+                              diff_forcing, diff_recursive,
+                              usersolvers, newctx, ctxnew, ctxfree);
+           if (ctxnew)
+               ctxfree(newctx);
 
 #ifdef STANDALONE_SOLVER
             solver_recurse_depth--;
@@ -793,7 +804,7 @@ int latin_solver_recurse(struct latin_solver *solver, int recdiff,
             else {
                 /* the recursion turned up exactly one solution */
                 if (diff == diff_impossible)
-                    diff = recdiff;
+                    diff = diff_recursive;
                 else
                     diff = diff_ambiguous;
             }
@@ -815,18 +826,20 @@ int latin_solver_recurse(struct latin_solver *solver, int recdiff,
         else if (diff == diff_ambiguous)
             return 2;
         else {
-            assert(diff == recdiff);
+            assert(diff == diff_recursive);
             return 1;
         }
     }
 }
 
-enum { diff_simple = 1, diff_set, diff_extreme, diff_recursive };
-
-static int latin_solver_sub(struct latin_solver *solver, int maxdiff, void *ctx)
+int latin_solver_main(struct latin_solver *solver, int maxdiff,
+                     int diff_simple, int diff_set_0, int diff_set_1,
+                     int diff_forcing, int diff_recursive,
+                     usersolver_t const *usersolvers, void *ctx,
+                     ctxnew_t ctxnew, ctxfree_t ctxfree)
 {
     struct latin_solver_scratch *scratch = latin_solver_new_scratch(solver);
-    int ret, diff = diff_simple, extreme;
+    int ret, diff = diff_simple;
 
     assert(maxdiff <= diff_recursive);
     /*
@@ -837,47 +850,34 @@ static int latin_solver_sub(struct latin_solver *solver, int maxdiff, void *ctx)
      * not.
      */
     while (1) {
-        /*
-         * I'd like to write `continue;' inside each of the
-         * following loops, so that the solver returns here after
-         * making some progress. However, I can't specify that I
-         * want to continue an outer loop rather than the innermost
-         * one, so I'm apologetically resorting to a goto.
-         */
-       cont:
-        latin_solver_debug(solver->cube, solver->o);
-
-        ret = latin_solver_diff_simple(solver);
-        if (ret < 0) {
-            diff = diff_impossible;
-            goto got_result;
-        } else if (ret > 0) {
-            diff = max(diff, diff_simple);
-            goto cont;
-        }
-
-        if (maxdiff <= diff_simple)
-            break;
+       int i;
 
-        ret = latin_solver_diff_set(solver, scratch, &extreme);
-        if (ret < 0) {
-            diff = diff_impossible;
-            goto got_result;
-        } else if (ret > 0) {
-            diff = max(diff, extreme ? diff_extreme : diff_set);
-            goto cont;
-        }
+       cont:
 
-        if (maxdiff <= diff_set)
-            break;
+        latin_solver_debug(solver->cube, solver->o);
 
-        /*
-         * Forcing chains.
-         */
-        if (latin_solver_forcing(solver, scratch)) {
-            diff = max(diff, diff_extreme);
-            goto cont;
-        }
+       for (i = 0; i <= maxdiff; i++) {
+           if (usersolvers[i])
+               ret = usersolvers[i](solver, ctx);
+           else
+               ret = 0;
+           if (ret == 0 && i == diff_simple)
+               ret = latin_solver_diff_simple(solver);
+           if (ret == 0 && i == diff_set_0)
+               ret = latin_solver_diff_set(solver, scratch, 0);
+           if (ret == 0 && i == diff_set_1)
+               ret = latin_solver_diff_set(solver, scratch, 1);
+           if (ret == 0 && i == diff_forcing)
+               ret = latin_solver_forcing(solver, scratch);
+
+           if (ret < 0) {
+               diff = diff_impossible;
+               goto got_result;
+           } else if (ret > 0) {
+               diff = max(diff, i);
+               goto cont;
+           }
+       }
 
         /*
          * If we reach here, we have made no deductions in this
@@ -894,7 +894,10 @@ static int latin_solver_sub(struct latin_solver *solver, int maxdiff, void *ctx)
      * possible.
      */
     if (maxdiff == diff_recursive) {
-        int nsol = latin_solver_recurse(solver, diff_recursive, latin_solver, ctx);
+        int nsol = latin_solver_recurse(solver,
+                                       diff_simple, diff_set_0, diff_set_1,
+                                       diff_forcing, diff_recursive,
+                                       usersolvers, ctx, ctxnew, ctxfree);
         if (nsol < 0) diff = diff_impossible;
         else if (nsol == 1) diff = diff_recursive;
         else if (nsol > 1) diff = diff_ambiguous;
@@ -931,13 +934,20 @@ static int latin_solver_sub(struct latin_solver *solver, int maxdiff, void *ctx)
     return diff;
 }
 
-int latin_solver(digit *grid, int o, int maxdiff, void *ctx)
+int latin_solver(digit *grid, int o, int maxdiff,
+                int diff_simple, int diff_set_0, int diff_set_1,
+                int diff_forcing, int diff_recursive,
+                usersolver_t const *usersolvers, void *ctx,
+                ctxnew_t ctxnew, ctxfree_t ctxfree)
 {
     struct latin_solver solver;
     int diff;
 
     latin_solver_alloc(&solver, grid, o);
-    diff = latin_solver_sub(&solver, maxdiff, ctx);
+    diff = latin_solver_main(&solver, maxdiff,
+                            diff_simple, diff_set_0, diff_set_1,
+                            diff_forcing, diff_recursive,
+                            usersolvers, ctx, ctxnew, ctxfree);
     latin_solver_free(&solver);
     return diff;
 }
@@ -945,14 +955,14 @@ int latin_solver(digit *grid, int o, int maxdiff, void *ctx)
 void latin_solver_debug(unsigned char *cube, int o)
 {
 #ifdef STANDALONE_SOLVER
-    if (solver_show_working) {
+    if (solver_show_working > 1) {
         struct latin_solver ls, *solver = &ls;
-        unsigned char *dbg;
+        char *dbg;
         int x, y, i, c = 0;
 
         ls.cube = cube; ls.o = o; /* for cube() to work */
 
-        dbg = snewn(3*o*o*o, unsigned char);
+        dbg = snewn(3*o*o*o, char);
         for (y = 0; y < o; y++) {
             for (x = 0; x < o; x++) {
                 for (i = 1; i <= o; i++) {
@@ -1083,6 +1093,7 @@ digit *latin_generate(int o, random_state *rs)
        for (j = 0; j < o; j++)
            col[j] = num[j] = j;
        shuffle(col, j, sizeof(*col), rs);
+       shuffle(num, j, sizeof(*num), rs);
        /* We need the num permutation in both forward and inverse forms. */
        for (j = 0; j < o; j++)
            numinv[num[j]] = j;
@@ -1166,7 +1177,7 @@ int latin_check(digit *sq, int order)
     tree234 *dict = newtree234(latin_check_cmp);
     int c, r;
     int ret = 0;
-    lcparams *lcp, lc;
+    lcparams *lcp, lc, *aret;
 
     /* Use a tree234 as a simple hash table, go through the square
      * adding elements as we go or incrementing their counts. */
@@ -1178,7 +1189,8 @@ int latin_check(digit *sq, int order)
                lcp = snew(lcparams);
                lcp->elt = ELT(sq, c, r);
                lcp->count = 1;
-               assert(add234(dict, lcp) == lcp);
+                aret = add234(dict, lcp);
+               assert(aret == lcp);
            } else {
                lcp->count++;
            }