Normalise Unequal (and latin.c) so that solver diagnostics start
[sgt/puzzles] / latin.c
1 #include <assert.h>
2 #include <string.h>
3 #include <stdarg.h>
4
5 #include "puzzles.h"
6 #include "tree234.h"
7 #include "maxflow.h"
8
9 #ifdef STANDALONE_LATIN_TEST
10 #define STANDALONE_SOLVER
11 #endif
12
13 #include "latin.h"
14
15 /* --------------------------------------------------------
16 * Solver.
17 */
18
#ifdef STANDALONE_SOLVER
/*
 * Diagnostic controls for the standalone solver build.
 *
 * solver_show_working: non-zero enables the step-by-step trace;
 * values greater than 1 additionally dump the candidate cube.
 *
 * solver_recurse_depth: current guessing depth, used to indent the
 * trace output.
 */
int solver_show_working;
int solver_recurse_depth;
#endif
22
23 /*
24 * Function called when we are certain that a particular square has
25 * a particular number in it. The y-coordinate passed in here is
26 * transformed.
27 */
28 void latin_solver_place(struct latin_solver *solver, int x, int y, int n)
29 {
30 int i, o = solver->o;
31
32 assert(n <= o);
33 assert(cube(x,y,n));
34
35 /*
36 * Rule out all other numbers in this square.
37 */
38 for (i = 1; i <= o; i++)
39 if (i != n)
40 cube(x,y,i) = FALSE;
41
42 /*
43 * Rule out this number in all other positions in the row.
44 */
45 for (i = 0; i < o; i++)
46 if (i != y)
47 cube(x,i,n) = FALSE;
48
49 /*
50 * Rule out this number in all other positions in the column.
51 */
52 for (i = 0; i < o; i++)
53 if (i != x)
54 cube(i,y,n) = FALSE;
55
56 /*
57 * Enter the number in the result grid.
58 */
59 solver->grid[YUNTRANS(y)*o+x] = n;
60
61 /*
62 * Cross out this number from the list of numbers left to place
63 * in its row, its column and its block.
64 */
65 solver->row[y*o+n-1] = solver->col[x*o+n-1] = TRUE;
66 }
67
68 int latin_solver_elim(struct latin_solver *solver, int start, int step
69 #ifdef STANDALONE_SOLVER
70 , char *fmt, ...
71 #endif
72 )
73 {
74 int o = solver->o;
75 int fpos, m, i;
76
77 /*
78 * Count the number of set bits within this section of the
79 * cube.
80 */
81 m = 0;
82 fpos = -1;
83 for (i = 0; i < o; i++)
84 if (solver->cube[start+i*step]) {
85 fpos = start+i*step;
86 m++;
87 }
88
89 if (m == 1) {
90 int x, y, n;
91 assert(fpos >= 0);
92
93 n = 1 + fpos % o;
94 y = fpos / o;
95 x = y / o;
96 y %= o;
97
98 if (!solver->grid[YUNTRANS(y)*o+x]) {
99 #ifdef STANDALONE_SOLVER
100 if (solver_show_working) {
101 va_list ap;
102 printf("%*s", solver_recurse_depth*4, "");
103 va_start(ap, fmt);
104 vprintf(fmt, ap);
105 va_end(ap);
106 printf(":\n%*s placing %d at (%d,%d)\n",
107 solver_recurse_depth*4, "", n, x+1, YUNTRANS(y)+1);
108 }
109 #endif
110 latin_solver_place(solver, x, y, n);
111 return +1;
112 }
113 } else if (m == 0) {
114 #ifdef STANDALONE_SOLVER
115 if (solver_show_working) {
116 va_list ap;
117 printf("%*s", solver_recurse_depth*4, "");
118 va_start(ap, fmt);
119 vprintf(fmt, ap);
120 va_end(ap);
121 printf(":\n%*s no possibilities available\n",
122 solver_recurse_depth*4, "");
123 }
124 #endif
125 return -1;
126 }
127
128 return 0;
129 }
130
/*
 * Workspace shared by the set-elimination and forcing-chain
 * deductions. Allocated by latin_solver_new_scratch() with the
 * element counts noted below (o is the solver's order).
 */
struct latin_solver_scratch {
    unsigned char *grid;        /* o*o: reduced matrix (set) or bfs numbers (forcing) */
    unsigned char *rowidx;      /* o */
    unsigned char *colidx;      /* o */
    unsigned char *set;         /* o */
    int *neighbours;            /* 3*o */
    int *bfsqueue;              /* o*o */
#ifdef STANDALONE_SOLVER
    int *bfsprev;               /* o*o: bfs predecessors, for diagnostics */
#endif
};
138
139 int latin_solver_set(struct latin_solver *solver,
140 struct latin_solver_scratch *scratch,
141 int start, int step1, int step2
142 #ifdef STANDALONE_SOLVER
143 , char *fmt, ...
144 #endif
145 )
146 {
147 int o = solver->o;
148 int i, j, n, count;
149 unsigned char *grid = scratch->grid;
150 unsigned char *rowidx = scratch->rowidx;
151 unsigned char *colidx = scratch->colidx;
152 unsigned char *set = scratch->set;
153
154 /*
155 * We are passed a o-by-o matrix of booleans. Our first job
156 * is to winnow it by finding any definite placements - i.e.
157 * any row with a solitary 1 - and discarding that row and the
158 * column containing the 1.
159 */
160 memset(rowidx, TRUE, o);
161 memset(colidx, TRUE, o);
162 for (i = 0; i < o; i++) {
163 int count = 0, first = -1;
164 for (j = 0; j < o; j++)
165 if (solver->cube[start+i*step1+j*step2])
166 first = j, count++;
167
168 if (count == 0) return -1;
169 if (count == 1)
170 rowidx[i] = colidx[first] = FALSE;
171 }
172
173 /*
174 * Convert each of rowidx/colidx from a list of 0s and 1s to a
175 * list of the indices of the 1s.
176 */
177 for (i = j = 0; i < o; i++)
178 if (rowidx[i])
179 rowidx[j++] = i;
180 n = j;
181 for (i = j = 0; i < o; i++)
182 if (colidx[i])
183 colidx[j++] = i;
184 assert(n == j);
185
186 /*
187 * And create the smaller matrix.
188 */
189 for (i = 0; i < n; i++)
190 for (j = 0; j < n; j++)
191 grid[i*o+j] = solver->cube[start+rowidx[i]*step1+colidx[j]*step2];
192
193 /*
194 * Having done that, we now have a matrix in which every row
195 * has at least two 1s in. Now we search to see if we can find
196 * a rectangle of zeroes (in the set-theoretic sense of
197 * `rectangle', i.e. a subset of rows crossed with a subset of
198 * columns) whose width and height add up to n.
199 */
200
201 memset(set, 0, n);
202 count = 0;
203 while (1) {
204 /*
205 * We have a candidate set. If its size is <=1 or >=n-1
206 * then we move on immediately.
207 */
208 if (count > 1 && count < n-1) {
209 /*
210 * The number of rows we need is n-count. See if we can
211 * find that many rows which each have a zero in all
212 * the positions listed in `set'.
213 */
214 int rows = 0;
215 for (i = 0; i < n; i++) {
216 int ok = TRUE;
217 for (j = 0; j < n; j++)
218 if (set[j] && grid[i*o+j]) {
219 ok = FALSE;
220 break;
221 }
222 if (ok)
223 rows++;
224 }
225
226 /*
227 * We expect never to be able to get _more_ than
228 * n-count suitable rows: this would imply that (for
229 * example) there are four numbers which between them
230 * have at most three possible positions, and hence it
231 * indicates a faulty deduction before this point or
232 * even a bogus clue.
233 */
234 if (rows > n - count) {
235 #ifdef STANDALONE_SOLVER
236 if (solver_show_working) {
237 va_list ap;
238 printf("%*s", solver_recurse_depth*4,
239 "");
240 va_start(ap, fmt);
241 vprintf(fmt, ap);
242 va_end(ap);
243 printf(":\n%*s contradiction reached\n",
244 solver_recurse_depth*4, "");
245 }
246 #endif
247 return -1;
248 }
249
250 if (rows >= n - count) {
251 int progress = FALSE;
252
253 /*
254 * We've got one! Now, for each row which _doesn't_
255 * satisfy the criterion, eliminate all its set
256 * bits in the positions _not_ listed in `set'.
257 * Return +1 (meaning progress has been made) if we
258 * successfully eliminated anything at all.
259 *
260 * This involves referring back through
261 * rowidx/colidx in order to work out which actual
262 * positions in the cube to meddle with.
263 */
264 for (i = 0; i < n; i++) {
265 int ok = TRUE;
266 for (j = 0; j < n; j++)
267 if (set[j] && grid[i*o+j]) {
268 ok = FALSE;
269 break;
270 }
271 if (!ok) {
272 for (j = 0; j < n; j++)
273 if (!set[j] && grid[i*o+j]) {
274 int fpos = (start+rowidx[i]*step1+
275 colidx[j]*step2);
276 #ifdef STANDALONE_SOLVER
277 if (solver_show_working) {
278 int px, py, pn;
279
280 if (!progress) {
281 va_list ap;
282 printf("%*s", solver_recurse_depth*4,
283 "");
284 va_start(ap, fmt);
285 vprintf(fmt, ap);
286 va_end(ap);
287 printf(":\n");
288 }
289
290 pn = 1 + fpos % o;
291 py = fpos / o;
292 px = py / o;
293 py %= o;
294
295 printf("%*s ruling out %d at (%d,%d)\n",
296 solver_recurse_depth*4, "",
297 pn, px+1, YUNTRANS(py)+1);
298 }
299 #endif
300 progress = TRUE;
301 solver->cube[fpos] = FALSE;
302 }
303 }
304 }
305
306 if (progress) {
307 return +1;
308 }
309 }
310 }
311
312 /*
313 * Binary increment: change the rightmost 0 to a 1, and
314 * change all 1s to the right of it to 0s.
315 */
316 i = n;
317 while (i > 0 && set[i-1])
318 set[--i] = 0, count--;
319 if (i > 0)
320 set[--i] = 1, count++;
321 else
322 break; /* done */
323 }
324
325 return 0;
326 }
327
328 /*
329 * Look for forcing chains. A forcing chain is a path of
330 * pairwise-exclusive squares (i.e. each pair of adjacent squares
331 * in the path are in the same row, column or block) with the
332 * following properties:
333 *
334 * (a) Each square on the path has precisely two possible numbers.
335 *
336 * (b) Each pair of squares which are adjacent on the path share
337 * at least one possible number in common.
338 *
339 * (c) Each square in the middle of the path shares _both_ of its
340 * numbers with at least one of its neighbours (not the same
341 * one with both neighbours).
342 *
343 * These together imply that at least one of the possible number
344 * choices at one end of the path forces _all_ the rest of the
345 * numbers along the path. In order to make real use of this, we
346 * need further properties:
347 *
348 * (c) Ruling out some number N from the square at one end
349 * of the path forces the square at the other end to
350 * take number N.
351 *
352 * (d) The two end squares are both in line with some third
353 * square.
354 *
355 * (e) That third square currently has N as a possibility.
356 *
357 * If we can find all of that lot, we can deduce that at least one
358 * of the two ends of the forcing chain has number N, and that
359 * therefore the mutually adjacent third square does not.
360 *
361 * To find forcing chains, we're going to start a bfs at each
362 * suitable square, once for each of its two possible numbers.
363 */
364 int latin_solver_forcing(struct latin_solver *solver,
365 struct latin_solver_scratch *scratch)
366 {
367 int o = solver->o;
368 int *bfsqueue = scratch->bfsqueue;
369 #ifdef STANDALONE_SOLVER
370 int *bfsprev = scratch->bfsprev;
371 #endif
372 unsigned char *number = scratch->grid;
373 int *neighbours = scratch->neighbours;
374 int x, y;
375
376 for (y = 0; y < o; y++)
377 for (x = 0; x < o; x++) {
378 int count, t, n;
379
380 /*
381 * If this square doesn't have exactly two candidate
382 * numbers, don't try it.
383 *
384 * In this loop we also sum the candidate numbers,
385 * which is a nasty hack to allow us to quickly find
386 * `the other one' (since we will shortly know there
387 * are exactly two).
388 */
389 for (count = t = 0, n = 1; n <= o; n++)
390 if (cube(x, y, n))
391 count++, t += n;
392 if (count != 2)
393 continue;
394
395 /*
396 * Now attempt a bfs for each candidate.
397 */
398 for (n = 1; n <= o; n++)
399 if (cube(x, y, n)) {
400 int orign, currn, head, tail;
401
402 /*
403 * Begin a bfs.
404 */
405 orign = n;
406
407 memset(number, o+1, o*o);
408 head = tail = 0;
409 bfsqueue[tail++] = y*o+x;
410 #ifdef STANDALONE_SOLVER
411 bfsprev[y*o+x] = -1;
412 #endif
413 number[y*o+x] = t - n;
414
415 while (head < tail) {
416 int xx, yy, nneighbours, xt, yt, i;
417
418 xx = bfsqueue[head++];
419 yy = xx / o;
420 xx %= o;
421
422 currn = number[yy*o+xx];
423
424 /*
425 * Find neighbours of yy,xx.
426 */
427 nneighbours = 0;
428 for (yt = 0; yt < o; yt++)
429 neighbours[nneighbours++] = yt*o+xx;
430 for (xt = 0; xt < o; xt++)
431 neighbours[nneighbours++] = yy*o+xt;
432
433 /*
434 * Try visiting each of those neighbours.
435 */
436 for (i = 0; i < nneighbours; i++) {
437 int cc, tt, nn;
438
439 xt = neighbours[i] % o;
440 yt = neighbours[i] / o;
441
442 /*
443 * We need this square to not be
444 * already visited, and to include
445 * currn as a possible number.
446 */
447 if (number[yt*o+xt] <= o)
448 continue;
449 if (!cube(xt, yt, currn))
450 continue;
451
452 /*
453 * Don't visit _this_ square a second
454 * time!
455 */
456 if (xt == xx && yt == yy)
457 continue;
458
459 /*
460 * To continue with the bfs, we need
461 * this square to have exactly two
462 * possible numbers.
463 */
464 for (cc = tt = 0, nn = 1; nn <= o; nn++)
465 if (cube(xt, yt, nn))
466 cc++, tt += nn;
467 if (cc == 2) {
468 bfsqueue[tail++] = yt*o+xt;
469 #ifdef STANDALONE_SOLVER
470 bfsprev[yt*o+xt] = yy*o+xx;
471 #endif
472 number[yt*o+xt] = tt - currn;
473 }
474
475 /*
476 * One other possibility is that this
477 * might be the square in which we can
478 * make a real deduction: if it's
479 * adjacent to x,y, and currn is equal
480 * to the original number we ruled out.
481 */
482 if (currn == orign &&
483 (xt == x || yt == y)) {
484 #ifdef STANDALONE_SOLVER
485 if (solver_show_working) {
486 char *sep = "";
487 int xl, yl;
488 printf("%*sforcing chain, %d at ends of ",
489 solver_recurse_depth*4, "", orign);
490 xl = xx;
491 yl = yy;
492 while (1) {
493 printf("%s(%d,%d)", sep, xl+1,
494 YUNTRANS(yl)+1);
495 xl = bfsprev[yl*o+xl];
496 if (xl < 0)
497 break;
498 yl = xl / o;
499 xl %= o;
500 sep = "-";
501 }
502 printf("\n%*s ruling out %d at (%d,%d)\n",
503 solver_recurse_depth*4, "",
504 orign, xt+1, YUNTRANS(yt)+1);
505 }
506 #endif
507 cube(xt, yt, orign) = FALSE;
508 return 1;
509 }
510 }
511 }
512 }
513 }
514
515 return 0;
516 }
517
518 struct latin_solver_scratch *latin_solver_new_scratch(struct latin_solver *solver)
519 {
520 struct latin_solver_scratch *scratch = snew(struct latin_solver_scratch);
521 int o = solver->o;
522 scratch->grid = snewn(o*o, unsigned char);
523 scratch->rowidx = snewn(o, unsigned char);
524 scratch->colidx = snewn(o, unsigned char);
525 scratch->set = snewn(o, unsigned char);
526 scratch->neighbours = snewn(3*o, int);
527 scratch->bfsqueue = snewn(o*o, int);
528 #ifdef STANDALONE_SOLVER
529 scratch->bfsprev = snewn(o*o, int);
530 #endif
531 return scratch;
532 }
533
534 void latin_solver_free_scratch(struct latin_solver_scratch *scratch)
535 {
536 #ifdef STANDALONE_SOLVER
537 sfree(scratch->bfsprev);
538 #endif
539 sfree(scratch->bfsqueue);
540 sfree(scratch->neighbours);
541 sfree(scratch->set);
542 sfree(scratch->colidx);
543 sfree(scratch->rowidx);
544 sfree(scratch->grid);
545 sfree(scratch);
546 }
547
548 void latin_solver_alloc(struct latin_solver *solver, digit *grid, int o)
549 {
550 int x, y;
551
552 solver->o = o;
553 solver->cube = snewn(o*o*o, unsigned char);
554 solver->grid = grid; /* write straight back to the input */
555 memset(solver->cube, TRUE, o*o*o);
556
557 solver->row = snewn(o*o, unsigned char);
558 solver->col = snewn(o*o, unsigned char);
559 memset(solver->row, FALSE, o*o);
560 memset(solver->col, FALSE, o*o);
561
562 for (x = 0; x < o; x++)
563 for (y = 0; y < o; y++)
564 if (grid[y*o+x])
565 latin_solver_place(solver, x, YTRANS(y), grid[y*o+x]);
566 }
567
568 void latin_solver_free(struct latin_solver *solver)
569 {
570 sfree(solver->cube);
571 sfree(solver->row);
572 sfree(solver->col);
573 }
574
575 int latin_solver_diff_simple(struct latin_solver *solver)
576 {
577 int x, y, n, ret, o = solver->o;
578 /*
579 * Row-wise positional elimination.
580 */
581 for (y = 0; y < o; y++)
582 for (n = 1; n <= o; n++)
583 if (!solver->row[y*o+n-1]) {
584 ret = latin_solver_elim(solver, cubepos(0,y,n), o*o
585 #ifdef STANDALONE_SOLVER
586 , "positional elimination,"
587 " %d in row %d", n, YUNTRANS(y)+1
588 #endif
589 );
590 if (ret != 0) return ret;
591 }
592 /*
593 * Column-wise positional elimination.
594 */
595 for (x = 0; x < o; x++)
596 for (n = 1; n <= o; n++)
597 if (!solver->col[x*o+n-1]) {
598 ret = latin_solver_elim(solver, cubepos(x,0,n), o
599 #ifdef STANDALONE_SOLVER
600 , "positional elimination,"
601 " %d in column %d", n, x+1
602 #endif
603 );
604 if (ret != 0) return ret;
605 }
606
607 /*
608 * Numeric elimination.
609 */
610 for (x = 0; x < o; x++)
611 for (y = 0; y < o; y++)
612 if (!solver->grid[YUNTRANS(y)*o+x]) {
613 ret = latin_solver_elim(solver, cubepos(x,y,1), 1
614 #ifdef STANDALONE_SOLVER
615 , "numeric elimination at (%d,%d)",
616 x+1, YUNTRANS(y)+1
617 #endif
618 );
619 if (ret != 0) return ret;
620 }
621 return 0;
622 }
623
624 int latin_solver_diff_set(struct latin_solver *solver,
625 struct latin_solver_scratch *scratch,
626 int extreme)
627 {
628 int x, y, n, ret, o = solver->o;
629
630 if (!extreme) {
631 /*
632 * Row-wise set elimination.
633 */
634 for (y = 0; y < o; y++) {
635 ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
636 #ifdef STANDALONE_SOLVER
637 , "set elimination, row %d", YUNTRANS(y)+1
638 #endif
639 );
640 if (ret != 0) return ret;
641 }
642 /*
643 * Column-wise set elimination.
644 */
645 for (x = 0; x < o; x++) {
646 ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
647 #ifdef STANDALONE_SOLVER
648 , "set elimination, column %d", x+1
649 #endif
650 );
651 if (ret != 0) return ret;
652 }
653 } else {
654 /*
655 * Row-vs-column set elimination on a single number
656 * (much tricker for a human to do!)
657 */
658 for (n = 1; n <= o; n++) {
659 ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
660 #ifdef STANDALONE_SOLVER
661 , "positional set elimination, number %d", n
662 #endif
663 );
664 if (ret != 0) return ret;
665 }
666 }
667 return 0;
668 }
669
670 /*
671 * Returns:
672 * 0 for 'didn't do anything' implying it was already solved.
673 * -1 for 'impossible' (no solution)
674 * 1 for 'single solution'
675 * >1 for 'multiple solutions' (you don't get to know how many, and
676 * the first such solution found will be set.
677 *
678 * and this function may well assert if given an impossible board.
679 */
680 static int latin_solver_recurse
681 (struct latin_solver *solver, int diff_simple, int diff_set_0,
682 int diff_set_1, int diff_forcing, int diff_recursive,
683 usersolver_t const *usersolvers, void *ctx,
684 ctxnew_t ctxnew, ctxfree_t ctxfree)
685 {
686 int best, bestcount;
687 int o = solver->o, x, y, n;
688
689 best = -1;
690 bestcount = o+1;
691
692 for (y = 0; y < o; y++)
693 for (x = 0; x < o; x++)
694 if (!solver->grid[y*o+x]) {
695 int count;
696
697 /*
698 * An unfilled square. Count the number of
699 * possible digits in it.
700 */
701 count = 0;
702 for (n = 1; n <= o; n++)
703 if (cube(x,YTRANS(y),n))
704 count++;
705
706 /*
707 * We should have found any impossibilities
708 * already, so this can safely be an assert.
709 */
710 assert(count > 1);
711
712 if (count < bestcount) {
713 bestcount = count;
714 best = y*o+x;
715 }
716 }
717
718 if (best == -1)
719 /* we were complete already. */
720 return 0;
721 else {
722 int i, j;
723 digit *list, *ingrid, *outgrid;
724 int diff = diff_impossible; /* no solution found yet */
725
726 /*
727 * Attempt recursion.
728 */
729 y = best / o;
730 x = best % o;
731
732 list = snewn(o, digit);
733 ingrid = snewn(o*o, digit);
734 outgrid = snewn(o*o, digit);
735 memcpy(ingrid, solver->grid, o*o);
736
737 /* Make a list of the possible digits. */
738 for (j = 0, n = 1; n <= o; n++)
739 if (cube(x,YTRANS(y),n))
740 list[j++] = n;
741
742 #ifdef STANDALONE_SOLVER
743 if (solver_show_working) {
744 char *sep = "";
745 printf("%*srecursing on (%d,%d) [",
746 solver_recurse_depth*4, "", x+1, y+1);
747 for (i = 0; i < j; i++) {
748 printf("%s%d", sep, list[i]);
749 sep = " or ";
750 }
751 printf("]\n");
752 }
753 #endif
754
755 /*
756 * And step along the list, recursing back into the
757 * main solver at every stage.
758 */
759 for (i = 0; i < j; i++) {
760 int ret;
761 void *newctx;
762
763 memcpy(outgrid, ingrid, o*o);
764 outgrid[y*o+x] = list[i];
765
766 #ifdef STANDALONE_SOLVER
767 if (solver_show_working)
768 printf("%*sguessing %d at (%d,%d)\n",
769 solver_recurse_depth*4, "", list[i], x+1, y+1);
770 solver_recurse_depth++;
771 #endif
772
773 if (ctxnew) {
774 newctx = ctxnew(ctx);
775 } else {
776 newctx = ctx;
777 }
778 ret = latin_solver(outgrid, o, diff_recursive,
779 diff_simple, diff_set_0, diff_set_1,
780 diff_forcing, diff_recursive,
781 usersolvers, newctx, ctxnew, ctxfree);
782 if (ctxnew)
783 ctxfree(newctx);
784
785 #ifdef STANDALONE_SOLVER
786 solver_recurse_depth--;
787 if (solver_show_working) {
788 printf("%*sretracting %d at (%d,%d)\n",
789 solver_recurse_depth*4, "", list[i], x+1, y+1);
790 }
791 #endif
792 /* we recurse as deep as we can, so we should never find
793 * find ourselves giving up on a puzzle without declaring it
794 * impossible. */
795 assert(ret != diff_unfinished);
796
797 /*
798 * If we have our first solution, copy it into the
799 * grid we will return.
800 */
801 if (diff == diff_impossible && ret != diff_impossible)
802 memcpy(solver->grid, outgrid, o*o);
803
804 if (ret == diff_ambiguous)
805 diff = diff_ambiguous;
806 else if (ret == diff_impossible)
807 /* do not change our return value */;
808 else {
809 /* the recursion turned up exactly one solution */
810 if (diff == diff_impossible)
811 diff = diff_recursive;
812 else
813 diff = diff_ambiguous;
814 }
815
816 /*
817 * As soon as we've found more than one solution,
818 * give up immediately.
819 */
820 if (diff == diff_ambiguous)
821 break;
822 }
823
824 sfree(outgrid);
825 sfree(ingrid);
826 sfree(list);
827
828 if (diff == diff_impossible)
829 return -1;
830 else if (diff == diff_ambiguous)
831 return 2;
832 else {
833 assert(diff == diff_recursive);
834 return 1;
835 }
836 }
837 }
838
839 int latin_solver_main(struct latin_solver *solver, int maxdiff,
840 int diff_simple, int diff_set_0, int diff_set_1,
841 int diff_forcing, int diff_recursive,
842 usersolver_t const *usersolvers, void *ctx,
843 ctxnew_t ctxnew, ctxfree_t ctxfree)
844 {
845 struct latin_solver_scratch *scratch = latin_solver_new_scratch(solver);
846 int ret, diff = diff_simple;
847
848 assert(maxdiff <= diff_recursive);
849 /*
850 * Now loop over the grid repeatedly trying all permitted modes
851 * of reasoning. The loop terminates if we complete an
852 * iteration without making any progress; we then return
853 * failure or success depending on whether the grid is full or
854 * not.
855 */
856 while (1) {
857 int i;
858
859 cont:
860
861 latin_solver_debug(solver->cube, solver->o);
862
863 for (i = 0; i <= maxdiff; i++) {
864 if (usersolvers[i])
865 ret = usersolvers[i](solver, ctx);
866 else
867 ret = 0;
868 if (ret == 0 && i == diff_simple)
869 ret = latin_solver_diff_simple(solver);
870 if (ret == 0 && i == diff_set_0)
871 ret = latin_solver_diff_set(solver, scratch, 0);
872 if (ret == 0 && i == diff_set_1)
873 ret = latin_solver_diff_set(solver, scratch, 1);
874 if (ret == 0 && i == diff_forcing)
875 ret = latin_solver_forcing(solver, scratch);
876
877 if (ret < 0) {
878 diff = diff_impossible;
879 goto got_result;
880 } else if (ret > 0) {
881 diff = max(diff, i);
882 goto cont;
883 }
884 }
885
886 /*
887 * If we reach here, we have made no deductions in this
888 * iteration, so the algorithm terminates.
889 */
890 break;
891 }
892
893 /*
894 * Last chance: if we haven't fully solved the puzzle yet, try
895 * recursing based on guesses for a particular square. We pick
896 * one of the most constrained empty squares we can find, which
897 * has the effect of pruning the search tree as much as
898 * possible.
899 */
900 if (maxdiff == diff_recursive) {
901 int nsol = latin_solver_recurse(solver,
902 diff_simple, diff_set_0, diff_set_1,
903 diff_forcing, diff_recursive,
904 usersolvers, ctx, ctxnew, ctxfree);
905 if (nsol < 0) diff = diff_impossible;
906 else if (nsol == 1) diff = diff_recursive;
907 else if (nsol > 1) diff = diff_ambiguous;
908 /* if nsol == 0 then we were complete anyway
909 * (and thus don't need to change diff) */
910 } else {
911 /*
912 * We're forbidden to use recursion, so we just see whether
913 * our grid is fully solved, and return diff_unfinished
914 * otherwise.
915 */
916 int x, y, o = solver->o;
917
918 for (y = 0; y < o; y++)
919 for (x = 0; x < o; x++)
920 if (!solver->grid[y*o+x])
921 diff = diff_unfinished;
922 }
923
924 got_result:
925
926 #ifdef STANDALONE_SOLVER
927 if (solver_show_working)
928 printf("%*s%s found\n",
929 solver_recurse_depth*4, "",
930 diff == diff_impossible ? "no solution (impossible)" :
931 diff == diff_unfinished ? "no solution (unfinished)" :
932 diff == diff_ambiguous ? "multiple solutions" :
933 "one solution");
934 #endif
935
936 latin_solver_free_scratch(scratch);
937
938 return diff;
939 }
940
941 int latin_solver(digit *grid, int o, int maxdiff,
942 int diff_simple, int diff_set_0, int diff_set_1,
943 int diff_forcing, int diff_recursive,
944 usersolver_t const *usersolvers, void *ctx,
945 ctxnew_t ctxnew, ctxfree_t ctxfree)
946 {
947 struct latin_solver solver;
948 int diff;
949
950 latin_solver_alloc(&solver, grid, o);
951 diff = latin_solver_main(&solver, maxdiff,
952 diff_simple, diff_set_0, diff_set_1,
953 diff_forcing, diff_recursive,
954 usersolvers, ctx, ctxnew, ctxfree);
955 latin_solver_free(&solver);
956 return diff;
957 }
958
/*
 * Standalone builds only: when solver_show_working > 1, dump the
 * candidate cube. Each square is shown as o characters, one per
 * number: the digit character if it is still a candidate, '.' if not.
 * Does nothing in other builds.
 */
void latin_solver_debug(unsigned char *cube, int o)
{
#ifdef STANDALONE_SOLVER
    if (solver_show_working > 1) {
        struct latin_solver ls, *solver = &ls;
        char *dbg;
        int x, y, i, c = 0;

        ls.cube = cube; ls.o = o; /* for cube() to work */

        /*
         * Each of the o lines holds o squares of (o candidate chars
         * plus a space) and a newline; then comes one extra newline
         * and the terminating NUL.
         */
        dbg = snewn(o*o*(o+1) + o + 2, char);
        for (y = 0; y < o; y++) {
            for (x = 0; x < o; x++) {
                for (i = 1; i <= o; i++) {
                    if (cube(x,y,i))
                        dbg[c++] = i + '0';
                    else
                        dbg[c++] = '.';
                }
                dbg[c++] = ' ';
            }
            dbg[c++] = '\n';
        }
        dbg[c++] = '\n';
        dbg[c++] = '\0';

        printf("%s", dbg);
        sfree(dbg);
    }
#endif
}
990
991 void latin_debug(digit *sq, int o)
992 {
993 #ifdef STANDALONE_SOLVER
994 if (solver_show_working) {
995 int x, y;
996
997 for (y = 0; y < o; y++) {
998 for (x = 0; x < o; x++) {
999 printf("%2d ", sq[y*o+x]);
1000 }
1001 printf("\n");
1002 }
1003 printf("\n");
1004 }
1005 #endif
1006 }
1007
1008 /* --------------------------------------------------------
1009 * Generation.
1010 */
1011
1012 digit *latin_generate(int o, random_state *rs)
1013 {
1014 digit *sq;
1015 int *edges, *backedges, *capacity, *flow;
1016 void *scratch;
1017 int ne, scratchsize;
1018 int i, j, k;
1019 digit *row, *col, *numinv, *num;
1020
1021 /*
1022 * To efficiently generate a latin square in such a way that
1023 * all possible squares are possible outputs from the function,
1024 * we make use of a theorem which states that any r x n latin
1025 * rectangle, with r < n, can be extended into an (r+1) x n
1026 * latin rectangle. In other words, we can reliably generate a
1027 * latin square row by row, by at every stage writing down any
1028 * row at all which doesn't conflict with previous rows, and
1029 * the theorem guarantees that we will never have to backtrack.
1030 *
1031 * To find a viable row at each stage, we can make use of the
1032 * support functions in maxflow.c.
1033 */
1034
1035 sq = snewn(o*o, digit);
1036
1037 /*
1038 * In case this method of generation introduces a really subtle
1039 * top-to-bottom directional bias, we'll generate the rows in
1040 * random order.
1041 */
1042 row = snewn(o, digit);
1043 col = snewn(o, digit);
1044 numinv = snewn(o, digit);
1045 num = snewn(o, digit);
1046 for (i = 0; i < o; i++)
1047 row[i] = i;
1048 shuffle(row, i, sizeof(*row), rs);
1049
1050 /*
1051 * Set up the infrastructure for the maxflow algorithm.
1052 */
1053 scratchsize = maxflow_scratch_size(o * 2 + 2);
1054 scratch = smalloc(scratchsize);
1055 backedges = snewn(o*o + 2*o, int);
1056 edges = snewn((o*o + 2*o) * 2, int);
1057 capacity = snewn(o*o + 2*o, int);
1058 flow = snewn(o*o + 2*o, int);
1059 /* Set up the edge array, and the initial capacities. */
1060 ne = 0;
1061 for (i = 0; i < o; i++) {
1062 /* Each LHS vertex is connected to all RHS vertices. */
1063 for (j = 0; j < o; j++) {
1064 edges[ne*2] = i;
1065 edges[ne*2+1] = j+o;
1066 /* capacity for this edge is set later on */
1067 ne++;
1068 }
1069 }
1070 for (i = 0; i < o; i++) {
1071 /* Each RHS vertex is connected to the distinguished sink vertex. */
1072 edges[ne*2] = i+o;
1073 edges[ne*2+1] = o*2+1;
1074 capacity[ne] = 1;
1075 ne++;
1076 }
1077 for (i = 0; i < o; i++) {
1078 /* And the distinguished source vertex connects to each LHS vertex. */
1079 edges[ne*2] = o*2;
1080 edges[ne*2+1] = i;
1081 capacity[ne] = 1;
1082 ne++;
1083 }
1084 assert(ne == o*o + 2*o);
1085 /* Now set up backedges. */
1086 maxflow_setup_backedges(ne, edges, backedges);
1087
1088 /*
1089 * Now generate each row of the latin square.
1090 */
1091 for (i = 0; i < o; i++) {
1092 /*
1093 * To prevent maxflow from behaving deterministically, we
1094 * separately permute the columns and the digits for the
1095 * purposes of the algorithm, differently for every row.
1096 */
1097 for (j = 0; j < o; j++)
1098 col[j] = num[j] = j;
1099 shuffle(col, j, sizeof(*col), rs);
1100 shuffle(num, j, sizeof(*num), rs);
1101 /* We need the num permutation in both forward and inverse forms. */
1102 for (j = 0; j < o; j++)
1103 numinv[num[j]] = j;
1104
1105 /*
1106 * Set up the capacities for the maxflow run, by examining
1107 * the existing latin square.
1108 */
1109 for (j = 0; j < o*o; j++)
1110 capacity[j] = 1;
1111 for (j = 0; j < i; j++)
1112 for (k = 0; k < o; k++) {
1113 int n = num[sq[row[j]*o + col[k]] - 1];
1114 capacity[k*o+n] = 0;
1115 }
1116
1117 /*
1118 * Run maxflow.
1119 */
1120 j = maxflow_with_scratch(scratch, o*2+2, 2*o, 2*o+1, ne,
1121 edges, backedges, capacity, flow, NULL);
1122 assert(j == o); /* by the above theorem, this must have succeeded */
1123
1124 /*
1125 * And examine the flow array to pick out the new row of
1126 * the latin square.
1127 */
1128 for (j = 0; j < o; j++) {
1129 for (k = 0; k < o; k++) {
1130 if (flow[j*o+k])
1131 break;
1132 }
1133 assert(k < o);
1134 sq[row[i]*o + col[j]] = numinv[k] + 1;
1135 }
1136 }
1137
1138 /*
1139 * Done. Free our internal workspaces...
1140 */
1141 sfree(flow);
1142 sfree(capacity);
1143 sfree(edges);
1144 sfree(backedges);
1145 sfree(scratch);
1146 sfree(numinv);
1147 sfree(num);
1148 sfree(col);
1149 sfree(row);
1150
1151 /*
1152 * ... and return our completed latin square.
1153 */
1154 return sq;
1155 }
1156
1157 /* --------------------------------------------------------
1158 * Checking.
1159 */
1160
1161 typedef struct lcparams {
1162 digit elt;
1163 int count;
1164 } lcparams;
1165
1166 static int latin_check_cmp(void *v1, void *v2)
1167 {
1168 lcparams *lc1 = (lcparams *)v1;
1169 lcparams *lc2 = (lcparams *)v2;
1170
1171 if (lc1->elt < lc2->elt) return -1;
1172 if (lc1->elt > lc2->elt) return 1;
1173 return 0;
1174 }
1175
1176 #define ELT(sq,x,y) (sq[((y)*order)+(x)])
1177
1178 /* returns non-zero if sq is not a latin square. */
1179 int latin_check(digit *sq, int order)
1180 {
1181 tree234 *dict = newtree234(latin_check_cmp);
1182 int c, r;
1183 int ret = 0;
1184 lcparams *lcp, lc, *aret;
1185
1186 /* Use a tree234 as a simple hash table, go through the square
1187 * adding elements as we go or incrementing their counts. */
1188 for (c = 0; c < order; c++) {
1189 for (r = 0; r < order; r++) {
1190 lc.elt = ELT(sq, c, r); lc.count = 0;
1191 lcp = find234(dict, &lc, NULL);
1192 if (!lcp) {
1193 lcp = snew(lcparams);
1194 lcp->elt = ELT(sq, c, r);
1195 lcp->count = 1;
1196 aret = add234(dict, lcp);
1197 assert(aret == lcp);
1198 } else {
1199 lcp->count++;
1200 }
1201 }
1202 }
1203
1204 /* There should be precisely 'order' letters in the alphabet,
1205 * each occurring 'order' times (making the OxO tree) */
1206 if (count234(dict) != order) ret = 1;
1207 else {
1208 for (c = 0; (lcp = index234(dict, c)) != NULL; c++) {
1209 if (lcp->count != order) ret = 1;
1210 }
1211 }
1212 for (c = 0; (lcp = index234(dict, c)) != NULL; c++)
1213 sfree(lcp);
1214 freetree234(dict);
1215
1216 return ret;
1217 }
1218
1219
1220 /* --------------------------------------------------------
1221 * Testing (and printing).
1222 */
1223
1224 #ifdef STANDALONE_LATIN_TEST
1225
1226 #include <stdio.h>
1227 #include <time.h>
1228
/* Program name (argv[0]); set by main() and used in usage messages. */
const char *quis;
1230
1231 static void latin_print(digit *sq, int order)
1232 {
1233 int x, y;
1234
1235 for (y = 0; y < order; y++) {
1236 for (x = 0; x < order; x++) {
1237 printf("%2u ", ELT(sq, x, y));
1238 }
1239 printf("\n");
1240 }
1241 printf("\n");
1242 }
1243
1244 static void gen(int order, random_state *rs, int debug)
1245 {
1246 digit *sq;
1247
1248 solver_show_working = debug;
1249
1250 sq = latin_generate(order, rs);
1251 latin_print(sq, order);
1252 if (latin_check(sq, order)) {
1253 fprintf(stderr, "Square is not a latin square!");
1254 exit(1);
1255 }
1256
1257 sfree(sq);
1258 }
1259
1260 void test_soak(int order, random_state *rs)
1261 {
1262 digit *sq;
1263 int n = 0;
1264 time_t tt_start, tt_now, tt_last;
1265
1266 solver_show_working = 0;
1267 tt_now = tt_start = time(NULL);
1268
1269 while(1) {
1270 sq = latin_generate(order, rs);
1271 sfree(sq);
1272 n++;
1273
1274 tt_last = time(NULL);
1275 if (tt_last > tt_now) {
1276 tt_now = tt_last;
1277 printf("%d total, %3.1f/s\n", n,
1278 (double)n / (double)(tt_now - tt_start));
1279 }
1280 }
1281 }
1282
1283 void usage_exit(const char *msg)
1284 {
1285 if (msg)
1286 fprintf(stderr, "%s: %s\n", quis, msg);
1287 fprintf(stderr, "Usage: %s [--seed SEED] --soak <params> | [game_id [game_id ...]]\n", quis);
1288 exit(1);
1289 }
1290
1291 int main(int argc, char *argv[])
1292 {
1293 int i, soak = 0;
1294 random_state *rs;
1295 time_t seed = time(NULL);
1296
1297 quis = argv[0];
1298 while (--argc > 0) {
1299 const char *p = *++argv;
1300 if (!strcmp(p, "--soak"))
1301 soak = 1;
1302 else if (!strcmp(p, "--seed")) {
1303 if (argc == 0)
1304 usage_exit("--seed needs an argument");
1305 seed = (time_t)atoi(*++argv);
1306 argc--;
1307 } else if (*p == '-')
1308 usage_exit("unrecognised option");
1309 else
1310 break; /* finished options */
1311 }
1312
1313 rs = random_new((void*)&seed, sizeof(time_t));
1314
1315 if (soak == 1) {
1316 if (argc != 1) usage_exit("only one argument for --soak");
1317 test_soak(atoi(*argv), rs);
1318 } else {
1319 if (argc > 0) {
1320 for (i = 0; i < argc; i++) {
1321 gen(atoi(*argv++), rs, 1);
1322 }
1323 } else {
1324 while (1) {
1325 i = random_upto(rs, 20) + 1;
1326 gen(i, rs, 0);
1327 }
1328 }
1329 }
1330 random_free(rs);
1331 return 0;
1332 }
1333
1334 #endif
1335
1336 /* vim: set shiftwidth=4 tabstop=8: */