Add commentary and licence notices.
[rhodes] / rhodes
1 #! /usr/bin/python
2 ### -*-python-*-
3 ###
4 ### Calculate discrete logs in groups
5 ###
6 ### (c) 2017 Mark Wooding
7 ###
8
9 ###----- Licensing notice ---------------------------------------------------
10 ###
11 ### This file is part of Rhodes, a distributed discrete-log finder.
12 ###
13 ### Rhodes is free software; you can redistribute it and/or modify
14 ### it under the terms of the GNU General Public License as published by
15 ### the Free Software Foundation; either version 2 of the License, or
16 ### (at your option) any later version.
17 ###
18 ### Rhodes is distributed in the hope that it will be useful,
19 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
20 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 ### GNU General Public License for more details.
22 ###
23 ### You should have received a copy of the GNU General Public License
24 ### along with Rhodes; if not, write to the Free Software Foundation,
25 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
26
27 from sys import argv, stdout, stderr, exit
28 import errno as E
29 import fcntl as F
30 import os as OS
31 import subprocess as S
32 import select as SEL
33 import signal as SIG
34
35 import catacomb as C
36 import sqlite3 as SQL
37
38 ###--------------------------------------------------------------------------
39 ### Miscellaneous utilities.
40
class ExpectedError (Exception):
  """An anticipated failure, to be reported as a plain message.

  The top-level program catches this exception, prints its message
  prefixed by the program name, and exits with status 3.
  """
43
44 ###--------------------------------------------------------------------------
45 ### Database handling.
46
## SQL script which `connect_db' runs on every connection it opens: it
## switches on enforcement of the foreign-key constraints declared in the
## schema.
CONNINIT_SQL = """
PRAGMA foreign_keys = on;
"""
50
## SQL script which the `setup' command runs once to create the schema of a
## fresh job database.
##
##   * `top' describes the overall problem: the group, the elements g and
##     x, the order m, and (once known) the answer n.
##   * `progress' has one row per prime p recorded by `setup', tracking how
##     far the logarithm modulo the power of p has been lifted.
##   * `workers' lists the processes registered as working on a (p, k)
##     component.
##   * `points' holds the distinguished points reported so far for a (p, k)
##     component.
SETUP_SQL = """
PRAGMA journal_mode = wal;

CREATE TABLE top
        (kind TEXT NOT NULL, -- `gf2x'
         groupdesc TEXT NOT NULL,
         g TEXT NOT NULL,
         x TEXT NOT NULL,
         m TEXT NOT NULL, -- g^m = 1
         n TEXT DEFAULT NULL); -- g^n = x

CREATE TABLE progress
        (p TEXT PRIMARY KEY NOT NULL, -- p|m, p prime
         e INT NOT NULL, -- e = v_p(m)
         k INT NOT NULL DEFAULT(0), -- 0 <= k <= e
         n TEXT NOT NULL DEFAULT(0), -- (g^{m/p^k})^n = x^{m/p^k}
         dpbits INT NOT NULL); -- 0 for sequential
CREATE UNIQUE INDEX progress_by_p_k ON progress (p, k);

CREATE TABLE workers
        (pid INT PRIMARY KEY NOT NULL,
         p TEXT NOT NULL,
         k INT NOT NULL,
         FOREIGN KEY (p, k) REFERENCES progress (p, k));
CREATE INDEX workers_by_p ON workers (p, k);

CREATE TABLE points
        (p TEXT NOT NULL,
         k INT NOT NULL,
         z TEXT NOT NULL, -- g^a x^b = z
         a TEXT NOT NULL,
         b TEXT NOT NULL,
         PRIMARY KEY (p, k, z),
         FOREIGN KEY (p, k) REFERENCES progress (p, k));
"""
86
def connect_db(dir):
  """Open the job database kept in directory DIR; return the connection.

  The database is the file `db' inside DIR.  Text values are delivered as
  plain `str' objects, and the `CONNINIT_SQL' script is run on the new
  connection before it is handed back.
  """
  path = OS.path.join(dir, 'db')
  conn = SQL.connect(path)
  conn.text_factory = str
  conn.cursor().executescript(CONNINIT_SQL)
  return conn
93
94 ###--------------------------------------------------------------------------
95 ### Group support.
96
## Registry of group implementations, keyed by each class's `NAME'
## attribute.  The `GroupClass' metaclass fills it in, and `getgroup' looks
## things up in it.
GROUPMAP = {}
98
class GroupClass (type):
  """Metaclass which records named group classes in `GROUPMAP'.

  Every class created through this metaclass which has a `NAME' attribute
  (whether its own or inherited) is entered in `GROUPMAP' under that name.
  Classes without a `NAME' are created normally but not registered.
  """
  def __new__(cls, name, supers, attrs):
    ty = super(GroupClass, cls).__new__(cls, name, supers, attrs)

    ## Use a private marker so that any `NAME' value at all, even a false
    ## one, counts as present.
    missing = object()
    tag = getattr(ty, 'NAME', missing)
    if tag is not missing: GROUPMAP[tag] = ty
    return ty
106
class BaseGroup (object):
  """Common base class for group implementations.

  It is created through the `GroupClass' metaclass.  `div' is built here
  from the `mul' and `inv' methods, which this class does not itself
  define.
  """
  __metaclass__ = GroupClass

  def __init__(me, desc):
    """Remember the group description DESC as `me.desc'."""
    me.desc = desc

  def div(me, x, y):
    """Return X divided by Y, i.e., X times the inverse of Y."""
    y_inv = me.inv(y)
    return me.mul(x, y_inv)
113
class BinaryFieldUnitGroup (BaseGroup):
  """Group kind `gf2x', built on a `C.BinPolyField'.

  The description is handed to `C.GF' to obtain the field polynomial,
  which must be irreducible.  The group order is taken to be one less than
  the field's `q'.
  """
  NAME = 'gf2x'

  def __init__(me, desc):
    """Set up the group described by DESC.

    Raises `ExpectedError' if the polynomial is not irreducible.
    """
    super(BinaryFieldUnitGroup, me).__init__(desc)
    poly = C.GF(desc)
    if not poly.irreduciblep():
      raise ExpectedError('not irreducible')
    field = C.BinPolyField(poly)
    me._k = field
    me.order = field.q - 1

  def elt(me, x):
    """Convert X, in the form accepted by `C.GF', to a field element."""
    poly = C.GF(x)
    return me._k(poly)

  def pow(me, x, n):
    """Return X raised to the power N."""
    return x**n

  def mul(me, x, y):
    """Return the product of X and Y."""
    return x*y

  def inv(me, x):
    """Return the inverse of X."""
    return x.inv()

  def idp(me, x):
    """Return whether X is the identity, i.e., the field's `one'."""
    return x == me._k.one

  def eq(me, x, y):
    """Return whether X and Y are equal."""
    return x == y

  def str(me, x):
    """Return the string form of element X."""
    return str(x)
136
def getgroup(kind, desc):
  """Return a group object of the given KIND, built from description DESC.

  KIND is looked up in `GROUPMAP', and the class found there is called
  with DESC.  Raises `ExpectedError' if no group class is registered under
  that name, so that a mistyped kind is reported as an ordinary error
  message rather than as a `KeyError' traceback.
  """
  try: cls = GROUPMAP[kind]
  except KeyError:
    raise ExpectedError("unknown group kind `%s'" % kind)
  return cls(desc)
138
139 ###--------------------------------------------------------------------------
140 ### Number-theoretic utilities.
141
def factor(n):
  """Factor N by running the external `./factor' program.

  The program is given N as its only argument.  Each line of its output
  must consist of two whitespace-separated fields, a factor and its
  exponent.  Returns a list of `(p, e)' pairs, with P converted by `C.MP'
  and E by `int', in the order the program printed them.

  Raises `ExpectedError' if the program exits with nonzero status.
  """
  child = S.Popen(['./factor', str(n)], stdout = S.PIPE)
  fields = (line.split() for line in child.stdout)
  factors = [(C.MP(pstr), int(estr)) for pstr, estr in fields]
  status = child.wait()
  if status:
    raise ExpectedError('factor failed: rc = %d' % status)
  return factors
151
152 ###--------------------------------------------------------------------------
153 ### Command dispatch.
154
## Table of commands, mapping each command name to the function which
## implements it.  `defcommand' adds the entries; the top-level program
## looks up its first argument here.
CMDMAP = {}
156
def defcommand(f, name = None):
  """Register F as a command in `CMDMAP'; intended for use as a decorator.

  Used bare (`@defcommand'), the function is registered under its own
  `__name__', unless NAME is given, in which case NAME is used.  If F is
  actually a string (`@defcommand('NAME')'), return a decorator which
  registers the function it is applied to under that string.

  In the non-string case, returns F unchanged.
  """
  if isinstance(f, basestring):
    cmdname = f
    def register(g): return defcommand(g, cmdname)
    return register
  key = f.__name__ if name is None else name
  CMDMAP[key] = f
  return f
164
165 ###--------------------------------------------------------------------------
166 ### Job status utilities.
167
def get_top(db):
  """Fetch the overall problem description from the `top' table of DB.

  Returns a tuple `(G, g, x, m, n)': G is the group object built by
  `getgroup' from the stored kind and description; g and x are the stored
  elements converted by `G.elt'; m is the stored order as a `C.MP'; and n
  is the stored logarithm as a `C.MP', or `None' if the `n' column is
  still null.
  """
  c = db.cursor()
  c.execute("""SELECT kind, groupdesc, g, x, m, n FROM top""")
  kind, groupdesc, gstr, xstr, mstr, nstr = c.fetchone()
  G = getgroup(kind, groupdesc)
  g, x, m = G.elt(gstr), G.elt(xstr), C.MP(mstr)

  ## Test for null explicitly.  A stored logarithm of zero is a perfectly
  ## good answer, and an `and ... or' chain would turn any false-valued
  ## result into `None', making the job look permanently unfinished.
  if nstr is None: n = None
  else: n = C.MP(nstr)
  return G, g, x, m, n
176
def get_job(db):
  """Pick a component of the problem which can be worked on now.

  Selects one row of `progress' whose `k' is still less than `e', and
  which either has nonzero `dpbits' or has no matching row in `workers'.

  Returns `(p, e, k, n, dpbits)' for the chosen row, with p and n
  converted by `C.MP'; or a tuple of five `None's if there is no such
  row.
  """
  c = db.cursor()
  c.execute("""SELECT p.p, p.e, p.k, p.n, p.dpbits
               FROM progress AS p LEFT OUTER JOIN workers AS w
                    ON p.p = w.p and p.k = w.k
               WHERE p.k < p.e AND (p.dpbits > 0 OR w.pid IS NULL)
               LIMIT 1""")
  found = c.fetchone()
  if found is None:
    return (None,)*5
  pstr, e, k, nstr, dpbits = found
  return C.MP(pstr), e, k, C.MP(nstr), dpbits
190
def maybe_cleanup_worker(dir, db, pid):
  """Forget about worker PID if it no longer holds its lockfile.

  The worker's lockfile is `lk.PID' in directory DIR.  If the file is
  missing, or we can take an exclusive lock on it ourselves, the worker is
  considered stale: the lockfile is removed (best effort) and the worker's
  row is deleted from the `workers' table of DB.  If somebody else holds
  the lock, the worker is considered live and nothing is changed.

  The deletion is not committed here: that is left to the caller.

  Raises `ExpectedError' if the lockfile can't be opened or probed for any
  other reason.
  """
  c = db.cursor()
  f = OS.path.join(dir, 'lk.%d' % pid)
  fd = None
  try:
    try: fd = OS.open(f, OS.O_WRONLY)
    except OSError as err:
      if err.errno != E.ENOENT:
        raise ExpectedError('open lockfile: %s' % err)
      stale = True
    else:
      try: F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
      except IOError as err:
        ## A lock held by someone else shows up as either `EAGAIN' or
        ## `EACCES', depending on the operating system.
        if err.errno not in (E.EAGAIN, E.EACCES):
          raise ExpectedError('check lock: %s' % err)
        stale = False
      else:
        stale = True

    if stale:
      try: OS.unlink(f)
      except OSError: pass
      c.execute("""DELETE FROM workers WHERE pid = ?""", (pid,))

  finally:
    ## Don't leak a descriptor for every worker examined.  Closing it last
    ## keeps any lock we took in place until the stale file has been dealt
    ## with.
    if fd is not None: OS.close(fd)
209
def maybe_kill_worker(dir, pid):
  """Terminate worker PID if it is still holding its lockfile.

  The worker's lockfile is `lk.PID' in directory DIR.  If the file is
  missing, or we can take an exclusive lock on it ourselves, the worker is
  taken to be gone already and nothing is done.  If the lock is held by
  somebody else, the process PID is sent `SIGTERM' and the lockfile is
  removed (best effort).

  Raises `ExpectedError' if the lockfile can't be opened or probed for any
  other reason.
  """
  f = OS.path.join(dir, 'lk.%d' % pid)
  try: fd = OS.open(f, OS.O_RDWR)
  except OSError as err:
    if err.errno != E.ENOENT:
      raise ExpectedError('open lockfile: %s' % err)
    return

  try:
    try: F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
    except IOError as err:
      ## A lock held by someone else shows up as either `EAGAIN' or
      ## `EACCES', depending on the operating system.
      if err.errno not in (E.EAGAIN, E.EACCES):
        raise ExpectedError('check lock: %s' % err)
    else:
      ## We got the lock, so nobody is holding it: nothing to kill.
      return

    try: OS.kill(pid, SIG.SIGTERM)
    except OSError as err:
      ## The worker may have exited since we probed the lock, which is the
      ## outcome we wanted anyway.  Anything else is a real problem.
      if err.errno != E.ESRCH: raise

    try: OS.unlink(f)
    except OSError: pass

  finally:
    ## Don't leak the descriptor, however we leave.
    OS.close(fd)
223
224 ###--------------------------------------------------------------------------
225 ### Setup.
226
@defcommand
def setup(dir, kind, groupdesc, gstr, xstr):
  """Create job directory DIR for finding the logarithm of x to base g.

  KIND and GROUPDESC select the group (see `getgroup'); GSTR and XSTR are
  the elements g and x in the form accepted by the group's `elt' method.

  The order m of g is worked out by factoring the group order and
  stripping off prime factors for as long as g still has order dividing
  what remains.  Then DIR is created, the database schema is set up, and
  the `top' row and one `progress' row per remaining prime are written.
  Primes of at most 48 bits get `dpbits' 0; larger ones get two-fifths of
  their bit length.

  Raises `ExpectedError' if x fails the order check or DIR can't be
  created.
  """

  ## Get the group.  This will also figure out the group order.
  G = getgroup(kind, groupdesc)

  ## Figure out the generator order.
  g = G.elt(gstr)
  x = G.elt(xstr)
  ff = []
  m = G.order
  for p, e in factor(m):
    ee = 0
    for _ in range(e):
      mm = m/p
      t = G.pow(g, mm)
      if not G.idp(t): break
      ee += 1; m = mm
    if ee < e: ff.append((p, e - ee))

  ## Check that x at least has the right order.  This check is imperfect.
  ## Report failure as an `ExpectedError', like every other complaint about
  ## the user's input, so that it comes out as a message rather than as a
  ## traceback.
  if not G.idp(G.pow(x, m)): raise ExpectedError('x not in <g>')

  ## Prepare the directory.
  try: OS.mkdir(dir)
  except OSError as err: raise ExpectedError('mkdir: %s' % err)

  ## Prepare the database.
  db = connect_db(dir)
  c = db.cursor()
  c.executescript(SETUP_SQL)

  ## Populate the general information.
  with db:
    c.execute("""INSERT INTO top (kind, groupdesc, g, x, m)
                 VALUES (?, ?, ?, ?, ?)""",
              (kind, groupdesc, G.str(g), G.str(x), str(m)))
    for p, e in ff:
      if p.nbits <= 48: dpbits = 0
      else: dpbits = p.nbits*2/5
      c.execute("""INSERT INTO progress (p, e, dpbits) VALUES (?, ?, ?)""",
                (str(p), e, dpbits))
269
270 ###--------------------------------------------------------------------------
271 ### Check.
272
@defcommand
def check(dir):
  """Verify the state of the job in DIR and print a progress summary.

  Prints the group and the elements g and x; checks that g and x are
  killed by the recorded order m; clears away workers whose lockfiles are
  no longer held; checks every `progress' row (the prime power divides m,
  g really has order divisible by p, and the partial logarithm is right),
  printing a line for each; checks that the prime powers multiply up to m;
  and, if a final answer is recorded, checks it and prints it.

  Problems are reported on stderr.  Exits with status 3 if any were
  found, and 0 otherwise.
  """
  rc = [0]                              # boxed so that `bad' can set it
  def bad(msg):
    stderr.write('%s: %s\n' % (PROG, msg))
    rc[0] = 3

  db = connect_db(dir)
  c = db.cursor()
  G, g, x, m, n = get_top(db)
  print('## group: %s %s' % (G.NAME, G.desc))
  print('## g = %s' % G.str(g))
  print('## x = %s' % G.str(x))

  if not G.idp(G.pow(g, m)):
    bad('bad generator/order: %s^%d /= 1' % (G.str(g), m))
  if not G.idp(G.pow(x, m)):
    bad('x not in group: %s^%d /= 1' % (G.str(x), m))

  ## Clear away old workers that aren't doing anything useful any more.
  ## For each worker pid, check that its lockfile is still locked; if
  ## not, it's finished and can be disposed of.  Do this in a transaction
  ## of its own: the deletions only stick if they're committed, and nothing
  ## later in this function would do that.  Collect the pids first, so that
  ## we aren't reading the table while rows are being deleted from it.
  with db:
    c.execute("""SELECT pid FROM workers""")
    for pid, in c.fetchall():
      maybe_cleanup_worker(dir, db, pid)
    for f in OS.listdir(dir):
      if f.startswith('lk.'):
        pid = int(f[3:])
        maybe_cleanup_worker(dir, db, pid)

  c.execute("""SELECT p.p, p.e, p.k, p.n, p.dpbits, COUNT(d.z)
               FROM progress AS p LEFT OUTER JOIN points AS d
                    ON p.p = d.p AND p.k = d.k
               GROUP BY p.p, p.k
               ORDER BY LENGTH(p.p), p.p""")
  mm = 1
  for pstr, e, k, nnstr, dpbits, ndp in c:
    p, nn = C.MP(pstr), C.MP(nnstr)
    q = p**e
    if m%q:
      bad('incorrect order factorization: %d^%d /| %d' % (p, e, m))
    mm *= q
    if G.idp(G.pow(g, m/p)):
      bad('bad generator/order: %s^{%d/%d} = 1' % (G.str(g), m, p))

    ## Check the partial logarithm: (g^{m/p^k})^nn = x^{m/p^k}.
    r = m/p**k
    h = G.pow(g, r*nn)
    y = G.pow(x, r)
    if not G.eq(h, y):
      bad('bad partial log: (%s^{%d/%d^%d})^%d = %s /= %s = %s^{%d/%d^%d}' %
          (G.str(g), m, p, k, nn, G.str(h), G.str(y), G.str(x), m, p, k))

    ## The point count is only interesting while points are being
    ## collected for this prime.
    if not dpbits or k == e: dpinfo = ''
    else: dpinfo = ' [%d: %d]' % (dpbits, ndp)
    print('## %d: %d/%d%s' % (p, k, e, dpinfo))
  if mm != m:
    bad('incomplete factorization: %d /= %d' % (mm, m))

  if n is not None:
    xx = G.pow(g, n)
    if not G.eq(xx, x):
      bad('incorrect log: %s^%d = %s /= %s' %
          (G.str(g), n, G.str(xx), G.str(x)))
    print('## DONE: %d' % n)

  exit(rc[0])
336
337 ###--------------------------------------------------------------------------
338 ### Done.
339
@defcommand
def done(dir):
  """Report by exit status how far the job in DIR has got.

  If the final logarithm is recorded, print it and exit with status 0.
  Otherwise exit with status 1 if `get_job' finds a component which could
  be worked on now, or 2 if it doesn't.
  """
  db = connect_db(dir)
  _, _, _, _, answer = get_top(db)
  if answer is not None:
    print('## DONE: %d' % answer)
    exit(0)
  job = get_job(db)
  exit(2 if job[0] is None else 1)
351
352 ###--------------------------------------------------------------------------
353 ### Step.
354
def _collect_job_output(proc):
  """Look after the running job PROC; return everything it prints.

  PROC is a `Popen' object with pipes for stdin and stdout.  Both pipes
  are made nonblocking.  We then wait, for up to 30 seconds at a time, for
  output; each time round, if the job hasn't closed its stdout yet, a `.'
  is written to its stdin, so that a dead connection gets noticed.

  Returns the collected output once the job has closed its stdout and
  exited successfully.  Raises `ExpectedError' if reading or writing
  fails, or if the job exits with nonzero status.
  """
  f_in, f_out = proc.stdin.fileno(), proc.stdout.fileno()
  for pipefd in [f_in, f_out]:
    fl = F.fcntl(pipefd, F.F_GETFL)
    F.fcntl(pipefd, F.F_SETFL, fl | OS.O_NONBLOCK)

  chunks = []
  finished = False
  while not finished:
    rdy, wry, exy = SEL.select([f_out], [], [], 30.0)
    if rdy:
      ## Drain whatever is available; an empty read means end-of-file.
      while True:
        try: b = OS.read(f_out, 4096)
        except OSError as err:
          if err.errno == E.EAGAIN: break
          else: raise ExpectedError('read job: %s' % err)
        else:
          if not len(b): finished = True; break
          else: chunks.append(b)
    if not finished:
      try: OS.write(f_in, '.')
      except OSError as err: raise ExpectedError('write job: %s' % err)

  rc = proc.wait()
  if rc: raise ExpectedError('job failed: rc = %d' % rc)
  return ''.join(chunks)

@defcommand
def step(dir, cmd, *args):
  """Do one unit of work on the job in DIR, using the solver CMD.

  A component is chosen with `get_job', a lockfile `lk.PID' is taken out
  in DIR, and this process is registered in the `workers' table.  Unless
  the subproblem is trivial, CMD is run with ARGS followed by the
  distinguished-point bit count, the group name and description, the base
  and target elements of the subproblem, and the prime p.

  With no distinguished points, the solver's output is `NN NI': the
  logarithm and a count.  Otherwise it is `A B Z NI': a point Z and its
  representation.  A new point is simply stored; a collision with a stored
  point yields the logarithm.  Once a logarithm is in hand it is verified,
  the `progress' row is advanced, other workers on the same component are
  killed, and, if every component is complete, the final answer is
  assembled and recorded.

  Raises `ExpectedError' if there is nothing to do, if the solver fails,
  or if any of the consistency checks fails.  The lockfile and the
  `workers' row are removed on the way out in every case.
  """

  ## Open the database.
  db = connect_db(dir)
  c = db.cursor()
  ##db.isolation_level = 'EXCLUSIVE'

  ## Prepare our lockfile names.
  mypid = OS.getpid()
  nlk = OS.path.join(dir, 'nlk.%d' % mypid)
  lk = OS.path.join(dir, 'lk.%d' % mypid)

  ## Overall exception handling...
  try:

    ## Find out what needs doing and start doing it.  For this, we open a
    ## transaction.
    with db:
      G, g, x, m, n = get_top(db)
      if n is not None: raise ExpectedError('job done')

      ## Find something to do.  Either a job that's small enough for us to
      ## take on alone, and that nobody else has picked up yet, or one that
      ## everyone's pitching in on.
      p, e, k, n, dpbits = get_job(db)
      if p is None: raise ExpectedError('no work to do')

      ## Figure out what needs doing.  Let q = p^e, h = g^{m/q}, y = x^{m/q}.
      ## Currently we have n_0 where
      ##
      ##        h^{p^{e-k} n_0} = y^{p^{e-k}}
      ##
      ## Suppose n == n_0 + p^k n' (mod p^{k+1}).  Then p^k n' == n - n_0
      ## (mod p^{k+1}).
      ##
      ##        (h^{p^{e-1}})^{n'} = (g^{m/p})^{n'}
      ##                           = (y/h^{n_0})^{p^{e-k-1}}
      ##
      ## so this is the next discrete log to solve.
      q = p**e
      o = m/q
      h, y = G.pow(g, o), G.pow(x, o)
      hh = G.pow(h, p**(e-1))
      yy = G.pow(G.div(y, G.pow(h, n)), p**(e-k-1))

      ## Take out a lockfile.  The descriptor is never closed explicitly:
      ## the lock has to last until this process exits.  It gets a name of
      ## its own so that nothing else can be confused with it.
      lockfd = OS.open(nlk, OS.O_WRONLY | OS.O_CREAT, 0o700)
      F.lockf(lockfd, F.LOCK_EX | F.LOCK_NB)
      OS.rename(nlk, lk)

      ## Record that we're working in the database.  This completes our
      ## initial transaction.
      c.execute("""INSERT INTO workers (pid, p, k) VALUES (?, ?, ?)""",
                (mypid, str(p), k))

    ## Before we get too stuck in, check for an easy case.
    if G.idp(yy):
      dpbits = 0                        # no need for distinguished points
      nn = 0; ni = 0
    else:

      ## There's nothing else for it.  Start the job up, and look after it
      ## until it's finished.
      proc = S.Popen([cmd] + list(args) +
                     [str(dpbits), G.NAME, G.desc,
                      G.str(hh), G.str(yy), str(p)],
                     stdin = S.PIPE, stdout = S.PIPE)
      out = _collect_job_output(proc)

      ## Parse the answer.  There are two cases.
      if not dpbits:
        nnstr, nistr = out.split()
        nn, ni = C.MP(nnstr), int(nistr)
      else:
        astr, bstr, zstr, nistr = out.split()
        a, b, z, ni = C.MP(astr), C.MP(bstr), G.elt(zstr), int(nistr)

    ## We have an answer.  Start a new transaction while we think about what
    ## this means.
    with db:

      if dpbits:

        ## Check that it's a correct point.
        zz = G.mul(G.pow(hh, a), G.pow(yy, b))
        if not G.eq(zz, z):
          raise ExpectedError(
            'job incorrect distinguished point: %s^%d %s^%d = %s /= %s' %
            (G.str(hh), a, G.str(yy), b, G.str(zz), G.str(z)))

        ## Report this (partial) success.
        print('## [%d, %d/%d: %d]: %d %d -> %s [%d]' %
              (p, k, e, dpbits, a, b, G.str(z), ni))

        ## If it's already in the database then we have an answer to the
        ## problem.  Look the point up under exactly the string which is
        ## stored when a point is inserted.
        c.execute("""SELECT a, b FROM points
                     WHERE p = ? AND k = ? AND z = ?""",
                  (str(p), k, G.str(z)))
        row = c.fetchone()
        if row is None:
          nn = None
          c.execute("""INSERT INTO points (p, k, a, b, z)
                       VALUES (?, ?, ?, ?, ?)""",
                    (str(p), k, str(a), str(b), G.str(z)))
        else:
          aastr, bbstr = row
          aa, bb = C.MP(aastr), C.MP(bbstr)
          if not (b - bb)%p:
            raise ExpectedError('duplicate point :-(')

          ## We win!
          nn = ((a - aa)*p.modinv(bb - b))%p
          c.execute("""SELECT COUNT(z) FROM points WHERE p = ? AND k = ?""",
                    (str(p), k))
          ni, = c.fetchone()
          print('## [%s, %d/%d: %d] collision %d %d -> %s <- %s %s [#%d]' %
                (p, k, e, dpbits, a, b, G.str(z), aa, bb, ni))

      ## If we don't have a final answer then we're done.
      if nn is None: return

      ## Check that the log we've recovered is correct.
      yyy = G.pow(hh, nn)
      if not G.eq(yyy, yy):
        raise ExpectedError('recovered incorrect log: %s^%d = %s /= %s' %
                            (G.str(hh), nn, G.str(yyy), G.str(yy)))

      ## Update the log for this prime power.
      n += nn*p**k
      k += 1

      ## Check that this is also correct.
      yyy = G.pow(h, n*p**(e-k))
      yy = G.pow(y, p**(e-k))
      if not G.eq(yyy, yy):
        raise ExpectedError('lifted incorrect log: %s^%d = %s /= %s' %
                            (G.str(h), n, G.str(yyy), G.str(yy)))

      ## Kill off the other jobs working on this component.  If we crash now,
      ## we lose a bunch of work. :-(
      c.execute("""SELECT pid FROM workers WHERE p = ? AND k = ?""",
                (str(p), k - 1))
      for pid, in c:
        if pid != mypid: maybe_kill_worker(dir, pid)
      c.execute("""DELETE FROM workers WHERE p = ? AND k = ?""",
                (str(p), k - 1))
      c.execute("""DELETE FROM points WHERE p = ? AND k = ?""",
                (str(p), k - 1))

      ## Looks like we're good: update the progress table.
      c.execute("""UPDATE progress SET k = ?, n = ? WHERE p = ?""",
                (k, str(n), str(p)))
      print('## [%d, %d/%d]: %d [%d]' % (p, k, e, n, ni))

      ## Quick check: are we done now?
      c.execute("""SELECT p FROM progress WHERE k < e
                   LIMIT 1""")
      row = c.fetchone()
      if row is None:

        ## Wow.  Time to stitch everything together.
        c.execute("""SELECT p, e, n FROM progress""")
        qq, nn = [], []
        for pstr, e, nstr in c:
          p, n = C.MP(pstr), C.MP(nstr)
          qq.append(p**e)
          nn.append(n)
        if len(qq) == 1: n = nn[0]
        else: n = C.MPCRT(qq).solve(nn)

        ## One last check that this is the right answer.
        xx = G.pow(g, n)
        if not G.eq(x, xx):
          raise ExpectedError(
            'calculated incorrect final log: %s^%d = %s /= %s' %
            (G.str(g), n, G.str(xx), G.str(x)))

        ## We're good.
        c.execute("""UPDATE top SET n = ?""", (str(n),))
        print('## DONE: %d' % n)

  finally:

    ## Delete our lockfile.
    for f in [nlk, lk]:
      try: OS.unlink(f)
      except OSError: pass

    ## Unregister from the database.
    with db:
      c.execute("""DELETE FROM workers WHERE pid = ?""", (mypid,))
572
573 ###--------------------------------------------------------------------------
574 ### Top-level program.
575
## The program name, for use in error messages.
PROG = argv[0]

## Look up the command named by the first argument and call it with the
## remaining arguments.  Anticipated failures, including a missing or
## unrecognized command name, are reported as a one-line message, with exit
## status 3.
try:
  if len(argv) < 2:
    raise ExpectedError('missing command name')
  try: command = CMDMAP[argv[1]]
  except KeyError:
    raise ExpectedError("unknown command `%s'" % argv[1])
  command(*argv[2:])
except ExpectedError as err:
  stderr.write('%s: %s\n' % (PROG, err))
  exit(3)