mm.c (rate): Simplify the algorithm.
[mm] / mm.c
CommitLineData
7a869046
MW
1/* -*-c-*-
2 *
7a869046
MW
3 * Simple mastermind game
4 *
5 * (c) 2006 Mark Wooding
6 */
7
1c91cf29 8/*----- Licensing notice --------------------------------------------------*
7a869046
MW
9 *
10 * This file is part of mm: a simple Mastermind game.
11 *
12 * mm is free software; you can redistribute it and/or modify
13 * it under the terms of the GNU General Public License as published by
14 * the Free Software Foundation; either version 2 of the License, or
15 * (at your option) any later version.
1c91cf29 16 *
7a869046
MW
17 * mm is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty of
19 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 * GNU General Public License for more details.
1c91cf29 21 *
7a869046
MW
22 * You should have received a copy of the GNU General Public License
23 * along with mm; if not, write to the Free Software Foundation,
24 * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25 */
26
27/*----- Header files ------------------------------------------------------*/
28
0fd4f375 29#include <assert.h>
7a869046
MW
30#include <ctype.h>
31#include <stdio.h>
32#include <stdlib.h>
33#include <string.h>
34#include <time.h>
35
36#include <mLib/alloc.h>
0fd4f375 37#include <mLib/darray.h>
7a869046
MW
38#include <mLib/mdwopt.h>
39#include <mLib/quis.h>
40#include <mLib/report.h>
41#include <mLib/sub.h>
42
43/*----- Data structures ---------------------------------------------------*/
44
45/* --- Digits --- *
46 *
47 * The symbols which make up the code to be guessed.
48 */
49
/* A single code symbol, stored in the smallest unsigned type. */
typedef unsigned char dig;

/* --- The game parameters --- *
 *
 * Both values are held as @dig@s, so neither can exceed the range of
 * @unsigned char@.
 */

struct mm {
  dig k;				/* Number of symbols in the code */
  dig n;				/* Number of distinct symbols */
};
typedef struct mm mm;
58
59/*----- Rating guesses ----------------------------------------------------*/
60
61/* --- Algorithm --- *
62 *
63 * Rating guesses efficiently is quite important.
64 *
65 * The rate context structure contains a copy of the game parameters, and
66 * three arrays, @v@, @s@ and @t@ allocated from the same @malloc@ed block:
67 *
68 * * %$v_i$% counts occurrences of the symbol %$i$% in the code.
69 * * %$s$% is a copy of the code.
70 * * %$t$% is temporary work space for the rating function.
71 *
72 * The rating function takes one pass over the guess, counting positions
73 * which match the code exactly (the black count %$b$%) and building a
74 * symbol-count table %$v'$% for the guess.  The white count is then given by
75 *
76 * %$w = \displaystyle \sum_{0\le i<n} \min(v_i, v'_i) - b.$%
7a869046
MW
77 *
78 * Thus, the work is %$O(k + n)$%, rather than the %$O(k^2)$% for a
79 * %%na\"\i{}ve%% implementation.
80 */
81
82typedef struct ratectx {
83 mm m;
84 dig *v;
85 dig *s;
86 dig *t;
87} ratectx;
88
89static ratectx *rate_alloc(const mm *m)
90{
91 ratectx *r;
92 dig *v;
93
94 r = CREATE(ratectx);
44e8ae5a 95 v = xmalloc((2 * m->n + m->k) * sizeof(dig));
7a869046
MW
96 r->m = *m;
97 r->v = v;
98 r->s = r->v + m->n;
99 r->t = r->s + m->k;
100 return (r);
101}
102
103static void rate_init(ratectx *r, const dig *s)
104{
105 unsigned i;
106
107 memset(r->v, 0, r->m.n * sizeof(dig));
108 for (i = 0; i < r->m.k; i++)
109 r->v[s[i]]++;
110 memcpy(r->s, s, r->m.k * sizeof(dig));
111}
112
113static ratectx *rate_new(const mm *m, const dig *s)
114{
115 ratectx *r = rate_alloc(m);
409fd535 116
7a869046
MW
117 rate_init(r, s);
118 return (r);
119}
120
121static void rate(const ratectx *r, const dig *g, unsigned *b, unsigned *w)
122{
123 unsigned i;
124 unsigned k = r->m.k, n = r->m.n;
125 dig *v = r->t;
7a869046
MW
126 const dig *s = r->s;
127 unsigned bb = 0, ww = 0;
128
129 memset(v, 0, n * sizeof(dig));
7a869046 130 for (i = 0; i < k; i++) {
44e8ae5a
MW
131 if (g[i] == s[i]) bb++;
132 v[g[i]]++;
7a869046
MW
133 }
134 for (i = 0; i < n; i++)
44e8ae5a 135 ww += v[i] < r->v[i] ? v[i] : r->v[i];
7a869046 136 *b = bb;
44e8ae5a 137 *w = ww - bb;
7a869046
MW
138}
139
409fd535 140static void rate_free(ratectx *r) { xfree(r->v); DESTROY(r); }
7a869046
MW
141
142/*----- Computer player ---------------------------------------------------*/
143
144/* --- About the algorithms --- *
145 *
146 * At each stage, we attempt to choose the guess which will give us the most
147 * information, regardless of the outcome. For each guess candidate, we
148 * count the remaining possible codes for each outcome, and choose the
149 * candidate with the least square sum. There are wrinkles.
150 *
151 * Firstly the number of possible guesses is large, and the number of
152 * possible codes is huge too; and our algorithm takes time proportional to
153 * the product of the two. However, a symbol we've never tried before is as
154 * good as any other, so we can narrow the list of candidate guesses by
155 * considering only %%\emph{prototypes}%% where we use only the smallest
156 * untried symbol at any point to represent introducing any new symbol. The
157 * number of initial prototypes is quite small. For the four-symbol game,
158 * they are 0000, 0001, 0011, 0012, 0111, 0112, 0122, and 0123.
159 *
160 * Secondly, when the number of possible codes becomes small, we want to bias
161 * the guess selection algorithm towards possible codes (so that we win if
162 * we're lucky). Since the algorithm chooses the first guess with the lowest
163 * sum-of-squares value, we simply run through the possible codes before
164 * enumerating the prototype guesses.
165 */
166
167typedef struct cpc {
168 mm m; /* Game parameters */
9ea0417a
MW
169 unsigned f; /* Various flags */
170#define CPCF_QUIET 1u /* Don't produce lots of output */
7a869046
MW
171 dig *s; /* n^k * k */ /* Remaining guesses */
172 size_t ns; /* Number of remaining guesses */
173 dig *bg; /* k */ /* Current best guess */
174 dig *t; /* k */ /* Scratch-space for prototype */
175 double bmax; /* Best guess least-squares score */
176 dig x, bx; /* Next unused symbol index */
177 size_t *v; /* (k + 1)^2 */ /* Bin counts for least-squares */
178 ratectx *r; /* Rate context for search */
179} cpc;
180
181static void print_guess(const mm *m, const dig *d)
182{
183 unsigned k = m->k, i;
184
185 for (i = 0; i < k; i++) {
186 if (i) putchar(' ');
187 printf("%u", d[i]);
188 }
189}
190
191static void dofep(cpc *c, void (*fn)(cpc *c, const dig *g, unsigned x),
192 unsigned k, unsigned n, unsigned i, unsigned x)
193{
194 unsigned j;
195 dig *t = c->t;
196
197 if (i == k)
198 fn(c, c->t, x);
199 else {
200 for (j = 0; j < x; j++) {
201 t[i] = j;
202 dofep(c, fn, k, n, i + 1, x);
203 }
204 if (x < n) {
205 t[i] = x;
206 dofep(c, fn, k, n, i + 1, x + 1);
207 }
208 }
209}
210
211static void foreach_proto(cpc *c, void (*fn)(cpc *c,
212 const dig *g,
213 unsigned x))
1c91cf29 214{
7a869046 215 unsigned k = c->m.k, n = c->m.n;
409fd535 216
7a869046
MW
217 dofep(c, fn, k, n, 0, c->x);
218}
219
220static void try_guess(cpc *c, const dig *g, unsigned x)
221{
222 size_t i;
223 unsigned b, w;
224 const dig *s;
225 unsigned k = c->m.k;
226 size_t *v = c->v;
227 size_t *vp;
228 double max;
229
230 rate_init(c->r, g);
231 memset(v, 0, (k + 1) * (k + 1) * sizeof(size_t));
232 for (i = c->ns, s = c->s; i; i--, s += k) {
233 rate(c->r, s, &b, &w);
234 v[b * (k + 1) + w]++;
235 }
236 max = 0;
237 for (i = (k + 1) * (k + 1), vp = v; i; i--, vp++)
238 max += (double)*vp * (double)*vp;
239 if (c->bmax < 0 || max < c->bmax) {
240 memcpy(c->bg, g, k * sizeof(dig));
241 c->bmax = max;
242 c->bx = x;
243 }
244}
245
246static void best_guess(cpc *c)
247{
248 c->bmax = -1;
249 if (c->ns < 1024) {
250 unsigned k = c->m.k;
251 const dig *s;
252 size_t i;
253
254 for (i = c->ns, s = c->s; i; i--, s += k)
255 try_guess(c, s, c->x);
256 }
257 foreach_proto(c, try_guess);
258 c->x = c->bx;
259}
260
261static void filter_guesses(cpc *c, const dig *g, unsigned b, unsigned w)
262{
263 unsigned k = c->m.k;
264 size_t i;
265 const dig *s;
266 unsigned bb, ww;
267 dig *ss;
268
269 rate_init(c->r, g);
270 for (i = c->ns, s = ss = c->s; i; i--, s += k) {
271 rate(c->r, s, &bb, &ww);
272 if (b == bb && w == ww) {
273 memmove(ss, s, k * sizeof(dig));
274 ss += k;
275 }
276 }
277 c->ns = (ss - c->s) / k;
278}
279
/* --- @ipow@ --- *
 *
 * Arguments:	@size_t b@ = base
 *		@size_t x@ = exponent
 *
 * Returns:	@b@ raised to the power @x@, reduced modulo the range of
 *		@size_t@; anything to the power zero is one.
 *
 * Use:		Integer exponentiation by repeated squaring.
 */

static size_t ipow(size_t b, size_t x)
{
  size_t acc = 1;

  for (; x; x >>= 1, b *= b)
    if (x & 1) acc *= b;
  return (acc);
}
291
292static void all_guesses(dig **ss, unsigned k, unsigned n,
293 unsigned i, const dig *b)
294{
295 unsigned j;
296
297 if (i == k) {
298 (*ss) += k;
299 return;
300 }
301 for (j = 0; j < n; j++) {
302 dig *s = *ss;
303 if (i)
304 memcpy(*ss, b, i * sizeof(dig));
305 s[i] = j;
306 all_guesses(ss, k, n, i + 1, s);
307 }
308}
309
9ea0417a
MW
/* --- @THINK@ --- *
 *
 * Arguments:	@c@ = pointer to a structure with a flags member @f@
 *		@what@ = string literal naming the activity
 *		@how@ = braced block of statements to execute
 *
 * Use:		Executes @how@.  Unless @CPCF_QUIET@ is set in the flags,
 *		announces @what@ on standard output beforehand, and prints
 *		the processor time taken (as measured by @clock@)
 *		afterwards.  The argument @c@ is evaluated more than once.
 */

#define THINK(c, what, how) do {					\
  clock_t _t0 = 0, _t1;							\
  if (!((c)->f & CPCF_QUIET)) {						\
    fputs(what "...", stdout);						\
    fflush(stdout);							\
    _t0 = clock();							\
  }									\
  do how while (0);							\
  if (!((c)->f & CPCF_QUIET)) {						\
    _t1 = clock();							\
    printf(" done (%.2fs)\n", (_t1 - _t0)/(double)CLOCKS_PER_SEC);	\
  }									\
} while (0)
323
9ea0417a 324static cpc *cpc_new(const mm *m, unsigned f)
7a869046
MW
325{
326 cpc *c = CREATE(cpc);
9ea0417a
MW
327
328 c->f = f;
7a869046
MW
329 c->m = *m;
330 c->ns = ipow(c->m.n, c->m.k);
331 c->s = xmalloc((c->ns + 2) * c->m.k * sizeof(dig));
332 c->bg = c->s + c->ns * c->m.k;
333 c->t = c->bg + c->m.k;
334 c->x = 0;
335 c->v = xmalloc((c->m.k + 1) * (c->m.k + 1) * sizeof(size_t));
336 c->r = rate_alloc(m);
9ea0417a 337 THINK(c, "Setting up", {
7a869046
MW
338 dig *ss = c->s; all_guesses(&ss, c->m.k, c->m.n, 0, 0);
339 });
340 return (c);
341}
342
343static void cpc_free(cpc *c)
344{
345 xfree(c->s);
346 xfree(c->v);
347 rate_free(c->r);
348 DESTROY(c);
349}
350
351static void cp_rate(void *r, const dig *g, unsigned *b, unsigned *w)
409fd535 352 { rate(r, g, b, w); }
7a869046
MW
353
354static const dig *cp_guess(void *cc)
355{
356 cpc *c = cc;
357
358 if (c->ns == 0) {
9ea0417a
MW
359 if (!(c->f & CPCF_QUIET))
360 printf("Liar! All solutions ruled out.\n");
7a869046
MW
361 return (0);
362 }
363 if (c->ns == 1) {
9ea0417a
MW
364 if (!(c->f & CPCF_QUIET)) {
365 fputs("Done! Solution is ", stdout);
366 print_guess(&c->m, c->s);
367 putchar('\n');
368 }
7a869046
MW
369 return (c->s);
370 }
9ea0417a
MW
371 if (!(c->f & CPCF_QUIET)) {
372 printf("(Possible solutions remaining = %lu)\n",
373 (unsigned long)c->ns);
374 if (c->ns < 32) {
375 const dig *s;
376 size_t i;
377 for (i = c->ns, s = c->s; i; i--, s += c->m.k) {
378 printf(" %2lu: ", (unsigned long)(c->ns - i + 1));
379 print_guess(&c->m, s);
380 putchar('\n');
381 }
7a869046
MW
382 }
383 }
9ea0417a 384 THINK(c, "Pondering", {
7a869046
MW
385 best_guess(c);
386 });
387 return (c->bg);
388}
389
390static void cp_update(void *cc, const dig *g, unsigned b, unsigned w)
391{
392 cpc *c = cc;
9ea0417a
MW
393
394 if (!(c->f & CPCF_QUIET)) {
395 fputs("My guess = ", stdout);
396 print_guess(&c->m, g);
397 printf("; rating = %u black, %u white\n", b, w);
398 }
399 THINK(c, "Filtering", {
7a869046
MW
400 filter_guesses(c, g, b, w);
401 });
402}
403
404/*----- Human player ------------------------------------------------------*/
405
406typedef struct hpc {
407 mm m;
408 dig *t;
409} hpc;
410
411static hpc *hpc_new(const mm *m)
412{
413 hpc *h = CREATE(hpc);
414 h->t = xmalloc(m->k * sizeof(dig));
415 h->m = *m;
416 return (h);
417}
418
419static void hpc_free(hpc *h)
420{
421 xfree(h->t);
422 DESTROY(h);
423}
424
425static void hp_rate(void *mp, const dig *g, unsigned *b, unsigned *w)
426{
427 mm *m = mp;
428 fputs("Guess = ", stdout);
429 print_guess(m, g);
430 printf("; rating: ");
431 fflush(stdout);
432 scanf("%u %u", b, w);
433}
434
435static const dig *hp_guess(void *hh)
436{
437 hpc *h = hh;
438 unsigned i;
439
440 fputs("Your guess: ", stdout);
441 fflush(stdout);
442 for (i = 0; i < h->m.k; i++) {
443 unsigned x;
444 scanf("%u", &x);
445 h->t[i] = x;
446 }
447 return (h->t);
448}
449
450static void hp_update(void *cc, const dig *g, unsigned b, unsigned w)
451{
452 printf("Rating = %u black, %u white\n", b, w);
453}
454
455/*----- Solver player -----------------------------------------------------*/
456
457typedef struct spc {
458 cpc *c;
459 hpc *h;
460 int i;
461} spc;
462
463static spc *spc_new(const mm *m)
464{
465 spc *s = CREATE(spc);
9ea0417a 466 s->c = cpc_new(m, 0);
7a869046
MW
467 s->h = hpc_new(m);
468 s->i = 0;
469 return (s);
470}
471
472static void spc_free(spc *s)
473{
474 cpc_free(s->c);
475 hpc_free(s->h);
476 DESTROY(s);
477}
478
479static const dig *sp_guess(void *ss)
480{
481 spc *s = ss;
482 hpc *h = s->h;
483 unsigned i;
484 int ch;
485
486again:
487 if (s->i)
488 return (cp_guess(s->c));
489
490 fputs("Your guess (dot for end): ", stdout);
491 fflush(stdout);
492 do ch = getchar(); while (isspace(ch));
493 if (!isdigit(ch)) { s->i = 1; goto again; }
494 ungetc(ch, stdin);
495 for (i = 0; i < h->m.k; i++) {
496 unsigned x;
497 scanf("%u", &x);
498 h->t[i] = x;
499 }
1c91cf29 500 return (h->t);
7a869046
MW
501}
502
503static void sp_update(void *ss, const dig *g, unsigned b, unsigned w)
409fd535 504 { spc *s = ss; cp_update(s->c, g, b, w); }
7a869046 505
0fd4f375
MW
506/*----- Full tournament stuff ---------------------------------------------*/
507
508DA_DECL(uint_v, unsigned);
509
510typedef struct allstats {
511 const mm *m;
512 unsigned f;
513#define AF_VERBOSE 1u
514 uint_v gmap;
515 unsigned long g;
516 unsigned long n;
517 clock_t t;
518} allstats;
519
520static void dorunone(allstats *a, dig *s)
521{
522 ratectx *r = rate_new(a->m, s);
523 clock_t t = 0, t0, t1;
524 cpc *c;
525 int n = 0;
526 const dig *g;
527 unsigned b, w;
528
529 if (a->f & AF_VERBOSE) {
530 print_guess(a->m, s);
531 fputs(": ", stdout);
532 fflush(stdout);
533 }
1c91cf29 534
0fd4f375
MW
535 c = cpc_new(a->m, CPCF_QUIET);
536 for (;;) {
537 t0 = clock();
538 g = cp_guess(c);
539 t1 = clock();
540 t += t1 - t0;
541 assert(g);
542 n++;
543 rate(r, g, &b, &w);
544 if (b == a->m->k)
545 break;
546 t0 = clock();
547 cp_update(c, g, b, w);
548 t1 = clock();
1c91cf29 549 t += t1 - t0;
0fd4f375
MW
550 }
551 a->t += t;
552 a->g += n;
553 while (DA_LEN(&a->gmap) <= n)
554 DA_PUSH(&a->gmap, 0);
555 DA(&a->gmap)[n]++;
556 rate_free(r);
557 cpc_free(c);
558
559 if (a->f & AF_VERBOSE)
c41ed91c 560 printf("%2u (%5.2fs)\n", n, t/(double)CLOCKS_PER_SEC);
0fd4f375
MW
561}
562
563static void dorunall(allstats *a, dig *s, unsigned i)
564{
565 dig j;
566
567 if (i >= a->m->k) {
568 dorunone(a, s);
569 a->n++;
570 } else {
571 for (j = 0; j < a->m->n; j++) {
572 s[i] = j;
573 dorunall(a, s, i + 1);
574 }
575 }
576}
577
578static void run_all(const mm *m)
579{
580 dig *s = xmalloc(m->k * sizeof(dig));
581 allstats a;
582 unsigned i;
583
584 a.m = m;
585 a.f = AF_VERBOSE;
586 DA_CREATE(&a.gmap);
587 a.n = 0;
588 a.g = 0;
589 a.t = 0;
590 dorunall(&a, s, 0);
591 xfree(s);
592
593 for (i = 1; i < DA_LEN(&a.gmap); i++)
594 printf("%2u guesses: %5u games\n", i, DA(&a.gmap)[i]);
595 printf("Average: %.4f (%.2fs)\n",
acc96aae 596 (double)a.g/a.n, a.t/(a.n * (double)CLOCKS_PER_SEC));
0fd4f375
MW
597}
598
7a869046
MW
599/*----- Main game logic ---------------------------------------------------*/
600
601static int play(const mm *m,
602 void (*ratefn)(void *rr, const dig *g,
603 unsigned *b, unsigned *w),
604 void *rr,
605 const dig *(*guessfn)(void *gg),
606 void (*updatefn)(void *gg, const dig *g,
607 unsigned b, unsigned w),
608 void *gg)
609{
610 unsigned b, w;
611 const dig *g;
612 unsigned i;
613
614 i = 0;
615 for (;;) {
616 i++;
617 g = guessfn(gg);
618 if (!g)
619 return (-1);
620 ratefn(rr, g, &b, &w);
621 if (b == m->k)
622 return (i);
623 updatefn(gg, g, b, w);
624 }
625}
626
627int main(int argc, char *argv[])
628{
629 unsigned h = 0;
630 void *rr = 0;
631 void (*ratefn)(void *rr, const dig *g, unsigned *b, unsigned *w) = 0;
632 mm m;
633 int n;
634
635 ego(argv[0]);
636 for (;;) {
637 static struct option opt[] = {
638 { "computer", 0, 0, 'C' },
639 { "human", 0, 0, 'H' },
640 { "solver", 0, 0, 'S' },
0fd4f375 641 { "all", 0, 0, 'a' },
7a869046
MW
642 { 0, 0, 0, 0 }
643 };
0fd4f375 644 int i = mdwopt(argc, argv, "CHSa", opt, 0, 0, 0);
7a869046
MW
645 if (i < 0)
646 break;
647 switch (i) {
648 case 'C': h = 0; break;
649 case 'H': h = 1; break;
650 case 'S': h = 2; break;
0fd4f375 651 case 'a': h = 99; break;
7a869046
MW
652 default:
653 exit(1);
654 }
655 }
656 if (argc - optind == 0) {
657 m.k = 4;
658 m.n = 6;
659 } else if (argc - optind < 2)
660 die(1, "bad parameters");
661 else {
662 m.k = atoi(argv[optind++]);
663 m.n = atoi(argv[optind++]);
664 if (argc - optind >= m.k) {
665 dig *s = xmalloc(m.k * sizeof(dig));
666 int i;
667 for (i = 0; i < m.k; i++)
668 s[i] = atoi(argv[optind++]);
669 rr = rate_new(&m, s);
670 ratefn = cp_rate;
671 xfree(s);
672 }
673 if (argc != optind)
674 die(1, "bad parameters");
675 }
676
677 switch (h) {
678 case 1: {
679 hpc *hh = hpc_new(&m);
680 if (!rr) {
681 dig *s = xmalloc(m.k * sizeof(dig));
682 int i;
683 srand(time(0));
684 for (i = 0; i < m.k; i++)
685 s[i] = rand() % m.n;
686 rr = rate_new(&m, s);
687 ratefn = cp_rate;
688 xfree(s);
689 }
690 n = play(&m, ratefn, rr, hp_guess, hp_update, hh);
691 hpc_free(hh);
692 } break;
693 case 0: {
9ea0417a 694 cpc *cc = cpc_new(&m, 0);
1c91cf29 695 if (rr)
7a869046
MW
696 n = play(&m, ratefn, rr, cp_guess, cp_update, cc);
697 else
698 n = play(&m, hp_rate, &m, cp_guess, cp_update, cc);
699 cpc_free(cc);
700 } break;
701 case 2: {
702 spc *ss = spc_new(&m);
703 n = play(&m, hp_rate, &m, sp_guess, sp_update, ss);
704 spc_free(ss);
705 } break;
0fd4f375
MW
706 case 99:
707 run_all(&m);
708 return (0);
709 break;
7a869046
MW
710 default:
711 abort();
712 }
713 if (n > 0)
714 printf("Solved in %d guesses\n", n);
715 else
716 die(1, "gave up");
717 return (0);
718}
719
720/*----- That's all, folks -------------------------------------------------*/