Refactor latin.c to make it easier to reuse. Instead of client
[sgt/puzzles] / latin.c
1 #include <assert.h>
2 #include <string.h>
3 #include <stdarg.h>
4
5 #include "puzzles.h"
6 #include "tree234.h"
7 #include "maxflow.h"
8
9 #ifdef STANDALONE_LATIN_TEST
10 #define STANDALONE_SOLVER
11 #endif
12
13 #include "latin.h"
14
15 /* --------------------------------------------------------
16 * Solver.
17 */
18
19 /*
20 * Function called when we are certain that a particular square has
21 * a particular number in it. The y-coordinate passed in here is
22 * transformed.
23 */
24 void latin_solver_place(struct latin_solver *solver, int x, int y, int n)
25 {
26 int i, o = solver->o;
27
28 assert(n <= o);
29 assert(cube(x,y,n));
30
31 /*
32 * Rule out all other numbers in this square.
33 */
34 for (i = 1; i <= o; i++)
35 if (i != n)
36 cube(x,y,i) = FALSE;
37
38 /*
39 * Rule out this number in all other positions in the row.
40 */
41 for (i = 0; i < o; i++)
42 if (i != y)
43 cube(x,i,n) = FALSE;
44
45 /*
46 * Rule out this number in all other positions in the column.
47 */
48 for (i = 0; i < o; i++)
49 if (i != x)
50 cube(i,y,n) = FALSE;
51
52 /*
53 * Enter the number in the result grid.
54 */
55 solver->grid[YUNTRANS(y)*o+x] = n;
56
57 /*
58 * Cross out this number from the list of numbers left to place
59 * in its row, its column and its block.
60 */
61 solver->row[y*o+n-1] = solver->col[x*o+n-1] = TRUE;
62 }
63
64 int latin_solver_elim(struct latin_solver *solver, int start, int step
65 #ifdef STANDALONE_SOLVER
66 , char *fmt, ...
67 #endif
68 )
69 {
70 int o = solver->o;
71 int fpos, m, i;
72
73 /*
74 * Count the number of set bits within this section of the
75 * cube.
76 */
77 m = 0;
78 fpos = -1;
79 for (i = 0; i < o; i++)
80 if (solver->cube[start+i*step]) {
81 fpos = start+i*step;
82 m++;
83 }
84
85 if (m == 1) {
86 int x, y, n;
87 assert(fpos >= 0);
88
89 n = 1 + fpos % o;
90 y = fpos / o;
91 x = y / o;
92 y %= o;
93
94 if (!solver->grid[YUNTRANS(y)*o+x]) {
95 #ifdef STANDALONE_SOLVER
96 if (solver_show_working) {
97 va_list ap;
98 printf("%*s", solver_recurse_depth*4, "");
99 va_start(ap, fmt);
100 vprintf(fmt, ap);
101 va_end(ap);
102 printf(":\n%*s placing %d at (%d,%d)\n",
103 solver_recurse_depth*4, "", n, x, YUNTRANS(y));
104 }
105 #endif
106 latin_solver_place(solver, x, y, n);
107 return +1;
108 }
109 } else if (m == 0) {
110 #ifdef STANDALONE_SOLVER
111 if (solver_show_working) {
112 va_list ap;
113 printf("%*s", solver_recurse_depth*4, "");
114 va_start(ap, fmt);
115 vprintf(fmt, ap);
116 va_end(ap);
117 printf(":\n%*s no possibilities available\n",
118 solver_recurse_depth*4, "");
119 }
120 #endif
121 return -1;
122 }
123
124 return 0;
125 }
126
/*
 * Workspace shared by latin_solver_set() and latin_solver_forcing(),
 * allocated once per solve by latin_solver_new_scratch().
 */
struct latin_solver_scratch {
    unsigned char *grid;    /* o*o: reduced matrix (set) / chain numbers (forcing) */
    unsigned char *rowidx;  /* o: surviving row indices in latin_solver_set */
    unsigned char *colidx;  /* o: surviving column indices in latin_solver_set */
    unsigned char *set;     /* o: current column subset in latin_solver_set */
    int *neighbours;        /* neighbour list for the forcing-chain bfs */
    int *bfsqueue;          /* o*o: bfs queue of square indices */
#ifdef STANDALONE_SOLVER
    int *bfsprev;           /* o*o: bfs predecessor links, for printing chains */
#endif
};
134
135 int latin_solver_set(struct latin_solver *solver,
136 struct latin_solver_scratch *scratch,
137 int start, int step1, int step2
138 #ifdef STANDALONE_SOLVER
139 , char *fmt, ...
140 #endif
141 )
142 {
143 int o = solver->o;
144 int i, j, n, count;
145 unsigned char *grid = scratch->grid;
146 unsigned char *rowidx = scratch->rowidx;
147 unsigned char *colidx = scratch->colidx;
148 unsigned char *set = scratch->set;
149
150 /*
151 * We are passed a o-by-o matrix of booleans. Our first job
152 * is to winnow it by finding any definite placements - i.e.
153 * any row with a solitary 1 - and discarding that row and the
154 * column containing the 1.
155 */
156 memset(rowidx, TRUE, o);
157 memset(colidx, TRUE, o);
158 for (i = 0; i < o; i++) {
159 int count = 0, first = -1;
160 for (j = 0; j < o; j++)
161 if (solver->cube[start+i*step1+j*step2])
162 first = j, count++;
163
164 if (count == 0) return -1;
165 if (count == 1)
166 rowidx[i] = colidx[first] = FALSE;
167 }
168
169 /*
170 * Convert each of rowidx/colidx from a list of 0s and 1s to a
171 * list of the indices of the 1s.
172 */
173 for (i = j = 0; i < o; i++)
174 if (rowidx[i])
175 rowidx[j++] = i;
176 n = j;
177 for (i = j = 0; i < o; i++)
178 if (colidx[i])
179 colidx[j++] = i;
180 assert(n == j);
181
182 /*
183 * And create the smaller matrix.
184 */
185 for (i = 0; i < n; i++)
186 for (j = 0; j < n; j++)
187 grid[i*o+j] = solver->cube[start+rowidx[i]*step1+colidx[j]*step2];
188
189 /*
190 * Having done that, we now have a matrix in which every row
191 * has at least two 1s in. Now we search to see if we can find
192 * a rectangle of zeroes (in the set-theoretic sense of
193 * `rectangle', i.e. a subset of rows crossed with a subset of
194 * columns) whose width and height add up to n.
195 */
196
197 memset(set, 0, n);
198 count = 0;
199 while (1) {
200 /*
201 * We have a candidate set. If its size is <=1 or >=n-1
202 * then we move on immediately.
203 */
204 if (count > 1 && count < n-1) {
205 /*
206 * The number of rows we need is n-count. See if we can
207 * find that many rows which each have a zero in all
208 * the positions listed in `set'.
209 */
210 int rows = 0;
211 for (i = 0; i < n; i++) {
212 int ok = TRUE;
213 for (j = 0; j < n; j++)
214 if (set[j] && grid[i*o+j]) {
215 ok = FALSE;
216 break;
217 }
218 if (ok)
219 rows++;
220 }
221
222 /*
223 * We expect never to be able to get _more_ than
224 * n-count suitable rows: this would imply that (for
225 * example) there are four numbers which between them
226 * have at most three possible positions, and hence it
227 * indicates a faulty deduction before this point or
228 * even a bogus clue.
229 */
230 if (rows > n - count) {
231 #ifdef STANDALONE_SOLVER
232 if (solver_show_working) {
233 va_list ap;
234 printf("%*s", solver_recurse_depth*4,
235 "");
236 va_start(ap, fmt);
237 vprintf(fmt, ap);
238 va_end(ap);
239 printf(":\n%*s contradiction reached\n",
240 solver_recurse_depth*4, "");
241 }
242 #endif
243 return -1;
244 }
245
246 if (rows >= n - count) {
247 int progress = FALSE;
248
249 /*
250 * We've got one! Now, for each row which _doesn't_
251 * satisfy the criterion, eliminate all its set
252 * bits in the positions _not_ listed in `set'.
253 * Return +1 (meaning progress has been made) if we
254 * successfully eliminated anything at all.
255 *
256 * This involves referring back through
257 * rowidx/colidx in order to work out which actual
258 * positions in the cube to meddle with.
259 */
260 for (i = 0; i < n; i++) {
261 int ok = TRUE;
262 for (j = 0; j < n; j++)
263 if (set[j] && grid[i*o+j]) {
264 ok = FALSE;
265 break;
266 }
267 if (!ok) {
268 for (j = 0; j < n; j++)
269 if (!set[j] && grid[i*o+j]) {
270 int fpos = (start+rowidx[i]*step1+
271 colidx[j]*step2);
272 #ifdef STANDALONE_SOLVER
273 if (solver_show_working) {
274 int px, py, pn;
275
276 if (!progress) {
277 va_list ap;
278 printf("%*s", solver_recurse_depth*4,
279 "");
280 va_start(ap, fmt);
281 vprintf(fmt, ap);
282 va_end(ap);
283 printf(":\n");
284 }
285
286 pn = 1 + fpos % o;
287 py = fpos / o;
288 px = py / o;
289 py %= o;
290
291 printf("%*s ruling out %d at (%d,%d)\n",
292 solver_recurse_depth*4, "",
293 pn, px, YUNTRANS(py));
294 }
295 #endif
296 progress = TRUE;
297 solver->cube[fpos] = FALSE;
298 }
299 }
300 }
301
302 if (progress) {
303 return +1;
304 }
305 }
306 }
307
308 /*
309 * Binary increment: change the rightmost 0 to a 1, and
310 * change all 1s to the right of it to 0s.
311 */
312 i = n;
313 while (i > 0 && set[i-1])
314 set[--i] = 0, count--;
315 if (i > 0)
316 set[--i] = 1, count++;
317 else
318 break; /* done */
319 }
320
321 return 0;
322 }
323
/*
 * Look for forcing chains. A forcing chain is a path of
 * pairwise-exclusive squares (i.e. each pair of adjacent squares
 * in the path are in the same row or column) with the
 * following properties:
 *
 *  (a) Each square on the path has precisely two possible numbers.
 *
 *  (b) Each pair of squares which are adjacent on the path share
 *      at least one possible number in common.
 *
 *  (c) Each square in the middle of the path shares _both_ of its
 *      numbers with at least one of its neighbours (not the same
 *      one with both neighbours).
 *
 * These together imply that at least one of the possible number
 * choices at one end of the path forces _all_ the rest of the
 * numbers along the path. In order to make real use of this, we
 * need further properties:
 *
 *  (d) Ruling out some number N from the square at one end
 *      of the path forces the square at the other end to
 *      take number N.
 *
 *  (e) The two end squares are both in line with some third
 *      square.
 *
 *  (f) That third square currently has N as a possibility.
 *
 * If we can find all of that lot, we can deduce that at least one
 * of the two ends of the forcing chain has number N, and that
 * therefore the mutually adjacent third square does not.
 *
 * To find forcing chains, we're going to start a bfs at each
 * suitable square, once for each of its two possible numbers.
 *
 * Returns 1 as soon as one candidate has been ruled out this way
 * (at most one deduction is made per call), or 0 if none was found.
 * Uses scratch->grid, scratch->bfsqueue and scratch->neighbours (and
 * scratch->bfsprev in the standalone solver) as workspace.
 */
int latin_solver_forcing(struct latin_solver *solver,
                         struct latin_solver_scratch *scratch)
{
    int o = solver->o;
    int *bfsqueue = scratch->bfsqueue;
#ifdef STANDALONE_SOLVER
    int *bfsprev = scratch->bfsprev;
#endif
    /* number[pos]: the value square pos is forced to take along the chain. */
    unsigned char *number = scratch->grid;
    int *neighbours = scratch->neighbours;
    int x, y;

    for (y = 0; y < o; y++)
        for (x = 0; x < o; x++) {
            int count, t, n;

            /*
             * If this square doesn't have exactly two candidate
             * numbers, don't try it.
             *
             * In this loop we also sum the candidate numbers,
             * which is a nasty hack to allow us to quickly find
             * `the other one' (since we will shortly know there
             * are exactly two).
             */
            for (count = t = 0, n = 1; n <= o; n++)
                if (cube(x, y, n))
                    count++, t += n;
            if (count != 2)
                continue;

            /*
             * Now attempt a bfs for each candidate.
             */
            for (n = 1; n <= o; n++)
                if (cube(x, y, n)) {
                    int orign, currn, head, tail;

                    /*
                     * Begin a bfs. orign is the number we suppose is
                     * ruled out at (x,y).
                     */
                    orign = n;

                    /* o+1 marks a square as unvisited (real values are 1..o). */
                    memset(number, o+1, o*o);
                    head = tail = 0;
                    bfsqueue[tail++] = y*o+x;
#ifdef STANDALONE_SOLVER
                    bfsprev[y*o+x] = -1;
#endif
                    /* Without orign, (x,y) must hold its other candidate. */
                    number[y*o+x] = t - n;

                    while (head < tail) {
                        int xx, yy, nneighbours, xt, yt, i;

                        xx = bfsqueue[head++];
                        yy = xx / o;
                        xx %= o;

                        currn = number[yy*o+xx];

                        /*
                         * Find neighbours of yy,xx: every square in
                         * its column, then every square in its row.
                         */
                        nneighbours = 0;
                        for (yt = 0; yt < o; yt++)
                            neighbours[nneighbours++] = yt*o+xx;
                        for (xt = 0; xt < o; xt++)
                            neighbours[nneighbours++] = yy*o+xt;

                        /*
                         * Try visiting each of those neighbours.
                         */
                        for (i = 0; i < nneighbours; i++) {
                            int cc, tt, nn;

                            xt = neighbours[i] % o;
                            yt = neighbours[i] / o;

                            /*
                             * We need this square to not be
                             * already visited, and to include
                             * currn as a possible number.
                             */
                            if (number[yt*o+xt] <= o)
                                continue;
                            if (!cube(xt, yt, currn))
                                continue;

                            /*
                             * Don't visit _this_ square a second
                             * time!
                             */
                            if (xt == xx && yt == yy)
                                continue;

                            /*
                             * To continue with the bfs, we need
                             * this square to have exactly two
                             * possible numbers.
                             */
                            for (cc = tt = 0, nn = 1; nn <= o; nn++)
                                if (cube(xt, yt, nn))
                                    cc++, tt += nn;
                            if (cc == 2) {
                                bfsqueue[tail++] = yt*o+xt;
#ifdef STANDALONE_SOLVER
                                bfsprev[yt*o+xt] = yy*o+xx;
#endif
                                /*
                                 * It shares a line with (xx,yy), which
                                 * takes currn, so it takes the other.
                                 */
                                number[yt*o+xt] = tt - currn;
                            }

                            /*
                             * One other possibility is that this
                             * might be the square in which we can
                             * make a real deduction: if it's
                             * adjacent to x,y, and currn is equal
                             * to the original number we ruled out.
                             */
                            if (currn == orign &&
                                (xt == x || yt == y)) {
#ifdef STANDALONE_SOLVER
                                if (solver_show_working) {
                                    char *sep = "";
                                    int xl, yl;
                                    printf("%*sforcing chain, %d at ends of ",
                                           solver_recurse_depth*4, "", orign);
                                    xl = xx;
                                    yl = yy;
                                    while (1) {
                                        printf("%s(%d,%d)", sep, xl,
                                               YUNTRANS(yl));
                                        xl = bfsprev[yl*o+xl];
                                        if (xl < 0)
                                            break;
                                        yl = xl / o;
                                        xl %= o;
                                        sep = "-";
                                    }
                                    printf("\n%*s ruling out %d at (%d,%d)\n",
                                           solver_recurse_depth*4, "",
                                           orign, xt, YUNTRANS(yt));
                                }
#endif
                                cube(xt, yt, orign) = FALSE;
                                return 1;
                            }
                        }
                    }
                }
        }

    return 0;
}
513
514 struct latin_solver_scratch *latin_solver_new_scratch(struct latin_solver *solver)
515 {
516 struct latin_solver_scratch *scratch = snew(struct latin_solver_scratch);
517 int o = solver->o;
518 scratch->grid = snewn(o*o, unsigned char);
519 scratch->rowidx = snewn(o, unsigned char);
520 scratch->colidx = snewn(o, unsigned char);
521 scratch->set = snewn(o, unsigned char);
522 scratch->neighbours = snewn(3*o, int);
523 scratch->bfsqueue = snewn(o*o, int);
524 #ifdef STANDALONE_SOLVER
525 scratch->bfsprev = snewn(o*o, int);
526 #endif
527 return scratch;
528 }
529
530 void latin_solver_free_scratch(struct latin_solver_scratch *scratch)
531 {
532 #ifdef STANDALONE_SOLVER
533 sfree(scratch->bfsprev);
534 #endif
535 sfree(scratch->bfsqueue);
536 sfree(scratch->neighbours);
537 sfree(scratch->set);
538 sfree(scratch->colidx);
539 sfree(scratch->rowidx);
540 sfree(scratch->grid);
541 sfree(scratch);
542 }
543
544 void latin_solver_alloc(struct latin_solver *solver, digit *grid, int o)
545 {
546 int x, y;
547
548 solver->o = o;
549 solver->cube = snewn(o*o*o, unsigned char);
550 solver->grid = grid; /* write straight back to the input */
551 memset(solver->cube, TRUE, o*o*o);
552
553 solver->row = snewn(o*o, unsigned char);
554 solver->col = snewn(o*o, unsigned char);
555 memset(solver->row, FALSE, o*o);
556 memset(solver->col, FALSE, o*o);
557
558 for (x = 0; x < o; x++)
559 for (y = 0; y < o; y++)
560 if (grid[y*o+x])
561 latin_solver_place(solver, x, YTRANS(y), grid[y*o+x]);
562 }
563
564 void latin_solver_free(struct latin_solver *solver)
565 {
566 sfree(solver->cube);
567 sfree(solver->row);
568 sfree(solver->col);
569 }
570
571 int latin_solver_diff_simple(struct latin_solver *solver)
572 {
573 int x, y, n, ret, o = solver->o;
574 /*
575 * Row-wise positional elimination.
576 */
577 for (y = 0; y < o; y++)
578 for (n = 1; n <= o; n++)
579 if (!solver->row[y*o+n-1]) {
580 ret = latin_solver_elim(solver, cubepos(0,y,n), o*o
581 #ifdef STANDALONE_SOLVER
582 , "positional elimination,"
583 " %d in row %d", n, YUNTRANS(y)
584 #endif
585 );
586 if (ret != 0) return ret;
587 }
588 /*
589 * Column-wise positional elimination.
590 */
591 for (x = 0; x < o; x++)
592 for (n = 1; n <= o; n++)
593 if (!solver->col[x*o+n-1]) {
594 ret = latin_solver_elim(solver, cubepos(x,0,n), o
595 #ifdef STANDALONE_SOLVER
596 , "positional elimination,"
597 " %d in column %d", n, x
598 #endif
599 );
600 if (ret != 0) return ret;
601 }
602
603 /*
604 * Numeric elimination.
605 */
606 for (x = 0; x < o; x++)
607 for (y = 0; y < o; y++)
608 if (!solver->grid[YUNTRANS(y)*o+x]) {
609 ret = latin_solver_elim(solver, cubepos(x,y,1), 1
610 #ifdef STANDALONE_SOLVER
611 , "numeric elimination at (%d,%d)", x,
612 YUNTRANS(y)
613 #endif
614 );
615 if (ret != 0) return ret;
616 }
617 return 0;
618 }
619
620 int latin_solver_diff_set(struct latin_solver *solver,
621 struct latin_solver_scratch *scratch,
622 int extreme)
623 {
624 int x, y, n, ret, o = solver->o;
625
626 if (!extreme) {
627 /*
628 * Row-wise set elimination.
629 */
630 for (y = 0; y < o; y++) {
631 ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
632 #ifdef STANDALONE_SOLVER
633 , "set elimination, row %d", YUNTRANS(y)
634 #endif
635 );
636 if (ret != 0) return ret;
637 }
638 /*
639 * Column-wise set elimination.
640 */
641 for (x = 0; x < o; x++) {
642 ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
643 #ifdef STANDALONE_SOLVER
644 , "set elimination, column %d", x
645 #endif
646 );
647 if (ret != 0) return ret;
648 }
649 } else {
650 /*
651 * Row-vs-column set elimination on a single number
652 * (much tricker for a human to do!)
653 */
654 for (n = 1; n <= o; n++) {
655 ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
656 #ifdef STANDALONE_SOLVER
657 , "positional set elimination, number %d", n
658 #endif
659 );
660 if (ret != 0) return ret;
661 }
662 }
663 return 0;
664 }
665
666 /*
667 * Returns:
668 * 0 for 'didn't do anything' implying it was already solved.
669 * -1 for 'impossible' (no solution)
670 * 1 for 'single solution'
671 * >1 for 'multiple solutions' (you don't get to know how many, and
672 * the first such solution found will be set.
673 *
674 * and this function may well assert if given an impossible board.
675 */
676 static int latin_solver_recurse
677 (struct latin_solver *solver, int diff_simple, int diff_set_0,
678 int diff_set_1, int diff_forcing, int diff_recursive,
679 usersolver_t const *usersolvers, void *ctx,
680 ctxnew_t ctxnew, ctxfree_t ctxfree)
681 {
682 int best, bestcount;
683 int o = solver->o, x, y, n;
684
685 best = -1;
686 bestcount = o+1;
687
688 for (y = 0; y < o; y++)
689 for (x = 0; x < o; x++)
690 if (!solver->grid[y*o+x]) {
691 int count;
692
693 /*
694 * An unfilled square. Count the number of
695 * possible digits in it.
696 */
697 count = 0;
698 for (n = 1; n <= o; n++)
699 if (cube(x,YTRANS(y),n))
700 count++;
701
702 /*
703 * We should have found any impossibilities
704 * already, so this can safely be an assert.
705 */
706 assert(count > 1);
707
708 if (count < bestcount) {
709 bestcount = count;
710 best = y*o+x;
711 }
712 }
713
714 if (best == -1)
715 /* we were complete already. */
716 return 0;
717 else {
718 int i, j;
719 digit *list, *ingrid, *outgrid;
720 int diff = diff_impossible; /* no solution found yet */
721
722 /*
723 * Attempt recursion.
724 */
725 y = best / o;
726 x = best % o;
727
728 list = snewn(o, digit);
729 ingrid = snewn(o*o, digit);
730 outgrid = snewn(o*o, digit);
731 memcpy(ingrid, solver->grid, o*o);
732
733 /* Make a list of the possible digits. */
734 for (j = 0, n = 1; n <= o; n++)
735 if (cube(x,YTRANS(y),n))
736 list[j++] = n;
737
738 #ifdef STANDALONE_SOLVER
739 if (solver_show_working) {
740 char *sep = "";
741 printf("%*srecursing on (%d,%d) [",
742 solver_recurse_depth*4, "", x, y);
743 for (i = 0; i < j; i++) {
744 printf("%s%d", sep, list[i]);
745 sep = " or ";
746 }
747 printf("]\n");
748 }
749 #endif
750
751 /*
752 * And step along the list, recursing back into the
753 * main solver at every stage.
754 */
755 for (i = 0; i < j; i++) {
756 int ret;
757 void *newctx;
758
759 memcpy(outgrid, ingrid, o*o);
760 outgrid[y*o+x] = list[i];
761
762 #ifdef STANDALONE_SOLVER
763 if (solver_show_working)
764 printf("%*sguessing %d at (%d,%d)\n",
765 solver_recurse_depth*4, "", list[i], x, y);
766 solver_recurse_depth++;
767 #endif
768
769 if (ctxnew) {
770 newctx = ctxnew(ctx);
771 } else {
772 newctx = ctx;
773 }
774 ret = latin_solver(outgrid, o, diff_recursive,
775 diff_simple, diff_set_0, diff_set_1,
776 diff_forcing, diff_recursive,
777 usersolvers, newctx, ctxnew, ctxfree);
778 if (ctxnew)
779 ctxfree(newctx);
780
781 #ifdef STANDALONE_SOLVER
782 solver_recurse_depth--;
783 if (solver_show_working) {
784 printf("%*sretracting %d at (%d,%d)\n",
785 solver_recurse_depth*4, "", list[i], x, y);
786 }
787 #endif
788 /* we recurse as deep as we can, so we should never find
789 * find ourselves giving up on a puzzle without declaring it
790 * impossible. */
791 assert(ret != diff_unfinished);
792
793 /*
794 * If we have our first solution, copy it into the
795 * grid we will return.
796 */
797 if (diff == diff_impossible && ret != diff_impossible)
798 memcpy(solver->grid, outgrid, o*o);
799
800 if (ret == diff_ambiguous)
801 diff = diff_ambiguous;
802 else if (ret == diff_impossible)
803 /* do not change our return value */;
804 else {
805 /* the recursion turned up exactly one solution */
806 if (diff == diff_impossible)
807 diff = diff_recursive;
808 else
809 diff = diff_ambiguous;
810 }
811
812 /*
813 * As soon as we've found more than one solution,
814 * give up immediately.
815 */
816 if (diff == diff_ambiguous)
817 break;
818 }
819
820 sfree(outgrid);
821 sfree(ingrid);
822 sfree(list);
823
824 if (diff == diff_impossible)
825 return -1;
826 else if (diff == diff_ambiguous)
827 return 2;
828 else {
829 assert(diff == diff_recursive);
830 return 1;
831 }
832 }
833 }
834
835 int latin_solver_main(struct latin_solver *solver, int maxdiff,
836 int diff_simple, int diff_set_0, int diff_set_1,
837 int diff_forcing, int diff_recursive,
838 usersolver_t const *usersolvers, void *ctx,
839 ctxnew_t ctxnew, ctxfree_t ctxfree)
840 {
841 struct latin_solver_scratch *scratch = latin_solver_new_scratch(solver);
842 int ret, diff = diff_simple;
843
844 assert(maxdiff <= diff_recursive);
845 /*
846 * Now loop over the grid repeatedly trying all permitted modes
847 * of reasoning. The loop terminates if we complete an
848 * iteration without making any progress; we then return
849 * failure or success depending on whether the grid is full or
850 * not.
851 */
852 while (1) {
853 int i;
854
855 cont:
856
857 latin_solver_debug(solver->cube, solver->o);
858
859 for (i = 0; i <= maxdiff; i++) {
860 if (usersolvers[i])
861 ret = usersolvers[i](solver, ctx);
862 else
863 ret = 0;
864 if (ret == 0 && i == diff_simple)
865 ret = latin_solver_diff_simple(solver);
866 if (ret == 0 && i == diff_set_0)
867 ret = latin_solver_diff_set(solver, scratch, 0);
868 if (ret == 0 && i == diff_set_1)
869 ret = latin_solver_diff_set(solver, scratch, 1);
870 if (ret == 0 && i == diff_forcing)
871 ret = latin_solver_forcing(solver, scratch);
872
873 if (ret < 0) {
874 diff = diff_impossible;
875 goto got_result;
876 } else if (ret > 0) {
877 diff = max(diff, i);
878 goto cont;
879 }
880 }
881
882 /*
883 * If we reach here, we have made no deductions in this
884 * iteration, so the algorithm terminates.
885 */
886 break;
887 }
888
889 /*
890 * Last chance: if we haven't fully solved the puzzle yet, try
891 * recursing based on guesses for a particular square. We pick
892 * one of the most constrained empty squares we can find, which
893 * has the effect of pruning the search tree as much as
894 * possible.
895 */
896 if (maxdiff == diff_recursive) {
897 int nsol = latin_solver_recurse(solver,
898 diff_simple, diff_set_0, diff_set_1,
899 diff_forcing, diff_recursive,
900 usersolvers, ctx, ctxnew, ctxfree);
901 if (nsol < 0) diff = diff_impossible;
902 else if (nsol == 1) diff = diff_recursive;
903 else if (nsol > 1) diff = diff_ambiguous;
904 /* if nsol == 0 then we were complete anyway
905 * (and thus don't need to change diff) */
906 } else {
907 /*
908 * We're forbidden to use recursion, so we just see whether
909 * our grid is fully solved, and return diff_unfinished
910 * otherwise.
911 */
912 int x, y, o = solver->o;
913
914 for (y = 0; y < o; y++)
915 for (x = 0; x < o; x++)
916 if (!solver->grid[y*o+x])
917 diff = diff_unfinished;
918 }
919
920 got_result:
921
922 #ifdef STANDALONE_SOLVER
923 if (solver_show_working)
924 printf("%*s%s found\n",
925 solver_recurse_depth*4, "",
926 diff == diff_impossible ? "no solution (impossible)" :
927 diff == diff_unfinished ? "no solution (unfinished)" :
928 diff == diff_ambiguous ? "multiple solutions" :
929 "one solution");
930 #endif
931
932 latin_solver_free_scratch(scratch);
933
934 return diff;
935 }
936
937 int latin_solver(digit *grid, int o, int maxdiff,
938 int diff_simple, int diff_set_0, int diff_set_1,
939 int diff_forcing, int diff_recursive,
940 usersolver_t const *usersolvers, void *ctx,
941 ctxnew_t ctxnew, ctxfree_t ctxfree)
942 {
943 struct latin_solver solver;
944 int diff;
945
946 latin_solver_alloc(&solver, grid, o);
947 diff = latin_solver_main(&solver, maxdiff,
948 diff_simple, diff_set_0, diff_set_1,
949 diff_forcing, diff_recursive,
950 usersolvers, ctx, ctxnew, ctxfree);
951 latin_solver_free(&solver);
952 return diff;
953 }
954
/*
 * In the standalone solver with solver_show_working > 1, print the
 * candidate cube: one line per y, and for each x a group of o
 * characters whose i-th is the character i+'0' if i is still a
 * candidate and '.' otherwise. Does nothing in other builds.
 */
void latin_solver_debug(unsigned char *cube, int o)
{
#ifdef STANDALONE_SOLVER
    if (solver_show_working > 1) {
        struct latin_solver ls, *solver = &ls;
        char *dbg;
        int x, y, i, c = 0;

        ls.cube = cube; ls.o = o; /* for cube() to work */

        /*
         * Each of the o lines holds o groups of (o candidate chars
         * + a space) and a newline; then a final newline and the
         * NUL. (3*o*o*o bytes is too small when o == 1.)
         */
        dbg = snewn(o*o*o + o*o + o + 2, char);
        for (y = 0; y < o; y++) {
            for (x = 0; x < o; x++) {
                for (i = 1; i <= o; i++) {
                    if (cube(x,y,i))
                        dbg[c++] = i + '0';
                    else
                        dbg[c++] = '.';
                }
                dbg[c++] = ' ';
            }
            dbg[c++] = '\n';
        }
        dbg[c++] = '\n';
        dbg[c++] = '\0';

        printf("%s", dbg);
        sfree(dbg);
    }
#endif
}
986
987 void latin_debug(digit *sq, int o)
988 {
989 #ifdef STANDALONE_SOLVER
990 if (solver_show_working) {
991 int x, y;
992
993 for (y = 0; y < o; y++) {
994 for (x = 0; x < o; x++) {
995 printf("%2d ", sq[y*o+x]);
996 }
997 printf("\n");
998 }
999 printf("\n");
1000 }
1001 #endif
1002 }
1003
1004 /* --------------------------------------------------------
1005 * Generation.
1006 */
1007
/*
 * Generate a random o-by-o latin square using rs. The result is a
 * freshly allocated array of o*o digits in row-major order, each in
 * the range 1..o; the caller owns it (the test harness in this file
 * releases it with sfree()).
 */
digit *latin_generate(int o, random_state *rs)
{
    digit *sq;
    int *edges, *backedges, *capacity, *flow;
    void *scratch;
    int ne, scratchsize;
    int i, j, k;
    digit *row, *col, *numinv, *num;

    /*
     * To efficiently generate a latin square in such a way that
     * all possible squares are possible outputs from the function,
     * we make use of a theorem which states that any r x n latin
     * rectangle, with r < n, can be extended into an (r+1) x n
     * latin rectangle. In other words, we can reliably generate a
     * latin square row by row, by at every stage writing down any
     * row at all which doesn't conflict with previous rows, and
     * the theorem guarantees that we will never have to backtrack.
     *
     * To find a viable row at each stage, we can make use of the
     * support functions in maxflow.c.
     */

    sq = snewn(o*o, digit);

    /*
     * In case this method of generation introduces a really subtle
     * top-to-bottom directional bias, we'll generate the rows in
     * random order.
     */
    row = snewn(o, digit);
    col = snewn(o, digit);
    numinv = snewn(o, digit);
    num = snewn(o, digit);
    for (i = 0; i < o; i++)
        row[i] = i;
    shuffle(row, i, sizeof(*row), rs);     /* i == o here: shuffle all rows */

    /*
     * Set up the infrastructure for the maxflow algorithm.
     *
     * Vertex numbering: 0..o-1 are the (permuted) columns of the row
     * being built, o..2*o-1 are the (permuted) numbers, 2*o is the
     * source and 2*o+1 is the sink. Edges 0..o*o-1 run from column
     * vertex i to number vertex j+o (edge index i*o+j); they are
     * followed by o number-to-sink edges and o source-to-column edges.
     */
    scratchsize = maxflow_scratch_size(o * 2 + 2);
    scratch = smalloc(scratchsize);
    backedges = snewn(o*o + 2*o, int);
    edges = snewn((o*o + 2*o) * 2, int);
    capacity = snewn(o*o + 2*o, int);
    flow = snewn(o*o + 2*o, int);
    /* Set up the edge array, and the initial capacities. */
    ne = 0;
    for (i = 0; i < o; i++) {
        /* Each LHS vertex is connected to all RHS vertices. */
        for (j = 0; j < o; j++) {
            edges[ne*2] = i;
            edges[ne*2+1] = j+o;
            /* capacity for this edge is set later on */
            ne++;
        }
    }
    for (i = 0; i < o; i++) {
        /* Each RHS vertex is connected to the distinguished sink vertex. */
        edges[ne*2] = i+o;
        edges[ne*2+1] = o*2+1;
        capacity[ne] = 1;
        ne++;
    }
    for (i = 0; i < o; i++) {
        /* And the distinguished source vertex connects to each LHS vertex. */
        edges[ne*2] = o*2;
        edges[ne*2+1] = i;
        capacity[ne] = 1;
        ne++;
    }
    assert(ne == o*o + 2*o);
    /* Now set up backedges. */
    maxflow_setup_backedges(ne, edges, backedges);

    /*
     * Now generate each row of the latin square.
     */
    for (i = 0; i < o; i++) {
        /*
         * To prevent maxflow from behaving deterministically, we
         * separately permute the columns and the digits for the
         * purposes of the algorithm, differently for every row.
         */
        for (j = 0; j < o; j++)
            col[j] = num[j] = j;
        shuffle(col, j, sizeof(*col), rs);
        shuffle(num, j, sizeof(*num), rs);
        /* We need the num permutation in both forward and inverse forms. */
        for (j = 0; j < o; j++)
            numinv[num[j]] = j;

        /*
         * Set up the capacities for the maxflow run, by examining
         * the existing latin square: the edge from column vertex k
         * to number vertex n is closed (capacity 0) if that number
         * already appears in that column in one of the i rows
         * written so far.
         */
        for (j = 0; j < o*o; j++)
            capacity[j] = 1;
        for (j = 0; j < i; j++)
            for (k = 0; k < o; k++) {
                int n = num[sq[row[j]*o + col[k]] - 1];
                capacity[k*o+n] = 0;
            }

        /*
         * Run maxflow.
         */
        j = maxflow_with_scratch(scratch, o*2+2, 2*o, 2*o+1, ne,
                                 edges, backedges, capacity, flow, NULL);
        assert(j == o); /* by the above theorem, this must have succeeded */

        /*
         * And examine the flow array to pick out the new row of
         * the latin square: for column vertex j, take the first
         * number vertex k it sends flow to, and undo both
         * permutations to write the real entry.
         */
        for (j = 0; j < o; j++) {
            for (k = 0; k < o; k++) {
                if (flow[j*o+k])
                    break;
            }
            assert(k < o);
            sq[row[i]*o + col[j]] = numinv[k] + 1;
        }
    }

    /*
     * Done. Free our internal workspaces...
     */
    sfree(flow);
    sfree(capacity);
    sfree(edges);
    sfree(backedges);
    sfree(scratch);
    sfree(numinv);
    sfree(num);
    sfree(col);
    sfree(row);

    /*
     * ... and return our completed latin square.
     */
    return sq;
}
1152
1153 /* --------------------------------------------------------
1154 * Checking.
1155 */
1156
1157 typedef struct lcparams {
1158 digit elt;
1159 int count;
1160 } lcparams;
1161
1162 static int latin_check_cmp(void *v1, void *v2)
1163 {
1164 lcparams *lc1 = (lcparams *)v1;
1165 lcparams *lc2 = (lcparams *)v2;
1166
1167 if (lc1->elt < lc2->elt) return -1;
1168 if (lc1->elt > lc2->elt) return 1;
1169 return 0;
1170 }
1171
/*
 * Element (x,y) of the row-major square sq. Expects a variable named
 * `order' (the side length) to be in scope wherever it is used. Every
 * argument is parenthesized so expressions such as ELT(p+3, ...) work.
 */
#define ELT(sq,x,y) ((sq)[((y)*(order))+(x)])
1173
1174 /* returns non-zero if sq is not a latin square. */
1175 int latin_check(digit *sq, int order)
1176 {
1177 tree234 *dict = newtree234(latin_check_cmp);
1178 int c, r;
1179 int ret = 0;
1180 lcparams *lcp, lc, *aret;
1181
1182 /* Use a tree234 as a simple hash table, go through the square
1183 * adding elements as we go or incrementing their counts. */
1184 for (c = 0; c < order; c++) {
1185 for (r = 0; r < order; r++) {
1186 lc.elt = ELT(sq, c, r); lc.count = 0;
1187 lcp = find234(dict, &lc, NULL);
1188 if (!lcp) {
1189 lcp = snew(lcparams);
1190 lcp->elt = ELT(sq, c, r);
1191 lcp->count = 1;
1192 aret = add234(dict, lcp);
1193 assert(aret == lcp);
1194 } else {
1195 lcp->count++;
1196 }
1197 }
1198 }
1199
1200 /* There should be precisely 'order' letters in the alphabet,
1201 * each occurring 'order' times (making the OxO tree) */
1202 if (count234(dict) != order) ret = 1;
1203 else {
1204 for (c = 0; (lcp = index234(dict, c)) != NULL; c++) {
1205 if (lcp->count != order) ret = 1;
1206 }
1207 }
1208 for (c = 0; (lcp = index234(dict, c)) != NULL; c++)
1209 sfree(lcp);
1210 freetree234(dict);
1211
1212 return ret;
1213 }
1214
1215
1216 /* --------------------------------------------------------
1217 * Testing (and printing).
1218 */
1219
1220 #ifdef STANDALONE_LATIN_TEST
1221
1222 #include <stdio.h>
1223 #include <time.h>
1224
/* Program name (set from argv[0] in main), used in usage messages. */
const char *quis;
1226
1227 static void latin_print(digit *sq, int order)
1228 {
1229 int x, y;
1230
1231 for (y = 0; y < order; y++) {
1232 for (x = 0; x < order; x++) {
1233 printf("%2u ", ELT(sq, x, y));
1234 }
1235 printf("\n");
1236 }
1237 printf("\n");
1238 }
1239
1240 static void gen(int order, random_state *rs, int debug)
1241 {
1242 digit *sq;
1243
1244 solver_show_working = debug;
1245
1246 sq = latin_generate(order, rs);
1247 latin_print(sq, order);
1248 if (latin_check(sq, order)) {
1249 fprintf(stderr, "Square is not a latin square!");
1250 exit(1);
1251 }
1252
1253 sfree(sq);
1254 }
1255
1256 void test_soak(int order, random_state *rs)
1257 {
1258 digit *sq;
1259 int n = 0;
1260 time_t tt_start, tt_now, tt_last;
1261
1262 solver_show_working = 0;
1263 tt_now = tt_start = time(NULL);
1264
1265 while(1) {
1266 sq = latin_generate(order, rs);
1267 sfree(sq);
1268 n++;
1269
1270 tt_last = time(NULL);
1271 if (tt_last > tt_now) {
1272 tt_now = tt_last;
1273 printf("%d total, %3.1f/s\n", n,
1274 (double)n / (double)(tt_now - tt_start));
1275 }
1276 }
1277 }
1278
1279 void usage_exit(const char *msg)
1280 {
1281 if (msg)
1282 fprintf(stderr, "%s: %s\n", quis, msg);
1283 fprintf(stderr, "Usage: %s [--seed SEED] --soak <params> | [game_id [game_id ...]]\n", quis);
1284 exit(1);
1285 }
1286
1287 int main(int argc, char *argv[])
1288 {
1289 int i, soak = 0;
1290 random_state *rs;
1291 time_t seed = time(NULL);
1292
1293 quis = argv[0];
1294 while (--argc > 0) {
1295 const char *p = *++argv;
1296 if (!strcmp(p, "--soak"))
1297 soak = 1;
1298 else if (!strcmp(p, "--seed")) {
1299 if (argc == 0)
1300 usage_exit("--seed needs an argument");
1301 seed = (time_t)atoi(*++argv);
1302 argc--;
1303 } else if (*p == '-')
1304 usage_exit("unrecognised option");
1305 else
1306 break; /* finished options */
1307 }
1308
1309 rs = random_new((void*)&seed, sizeof(time_t));
1310
1311 if (soak == 1) {
1312 if (argc != 1) usage_exit("only one argument for --soak");
1313 test_soak(atoi(*argv), rs);
1314 } else {
1315 if (argc > 0) {
1316 for (i = 0; i < argc; i++) {
1317 gen(atoi(*argv++), rs, 1);
1318 }
1319 } else {
1320 while (1) {
1321 i = random_upto(rs, 20) + 1;
1322 gen(i, rs, 0);
1323 }
1324 }
1325 }
1326 random_free(rs);
1327 return 0;
1328 }
1329
1330 #endif
1331
1332 /* vim: set shiftwidth=4 tabstop=8: */