Overhaul `math' representation machinery.
[u/mdw/catacomb] / math / mpgen
CommitLineData
1c3d4cf5
MW
1#! @PYTHON@
2###
3### Generate multiprecision integer representations
4###
5### (c) 2013 Straylight/Edgeware
6###
7
8###----- Licensing notice ---------------------------------------------------
9###
10### This file is part of Catacomb.
11###
12### Catacomb is free software; you can redistribute it and/or modify
13### it under the terms of the GNU Library General Public License as
14### published by the Free Software Foundation; either version 2 of the
15### License, or (at your option) any later version.
16###
17### Catacomb is distributed in the hope that it will be useful,
18### but WITHOUT ANY WARRANTY; without even the implied warranty of
19### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20### GNU Library General Public License for more details.
21###
22### You should have received a copy of the GNU Library General Public
23### License along with Catacomb; if not, write to the Free
24### Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25### MA 02111-1307, USA.
26
27from __future__ import with_statement
28
29import re as RX
30import optparse as OP
31import types as TY
32
33from sys import stdout
34
35###--------------------------------------------------------------------------
36### Random utilities.
37
def write_header(mode, name):
  """Write the comment which heads every generated file to standard output.

  MODE is the mpgen mode producing the output; NAME is the name of the file
  being generated.  Both are interpolated into the comment.
  """
  text = ("/* -*-c-*- GENERATED by mpgen (%s)\n"
          " *\n"
          " * %s\n"
          " */\n"
          "\n") % (mode, name)
  stdout.write(text)
46
def write_banner(text):
  """Write a C section-banner comment containing TEXT, padded out with
  dashes to a fixed width, to standard output."""
  dashes = '-' * (66 - len(text))
  stdout.write("/*----- %s %s*/\n" % (text, dashes))
49
class struct (object):
  """An empty object to which arbitrary attributes can be attached."""
51
## Matches any single character other than an ASCII letter or digit.
R_IDBAD = RX.compile(r'[^0-9A-Za-z]')

def fix_name(name):
  """Return NAME with every character that is not an ASCII letter or digit
  replaced by `_', so that it can be pasted into a C identifier."""
  return R_IDBAD.sub('_', name)
54
55###--------------------------------------------------------------------------
56### Determining the appropriate types.
57
## Map from type tags (as used in the typeinfo file) to the classes which
## describe them; filled in by the IntClass metaclass below.
TYPEMAP = {}

class IntClass (type):
  """Metaclass which records every class defining a `tag' in TYPEMAP."""
  def __new__(mcs, name, bases, ns):
    cls = super(IntClass, mcs).__new__(mcs, name, bases, ns)
    if hasattr(cls, 'tag'):
      TYPEMAP[cls.tag] = cls
    return cls
66
class BasicIntType (object):
  """Description of a C unsigned integer type which might hold limbs.

  Subclasses supply `tag' (the key used for the type in the typeinfo file)
  and `name' (its C spelling).  They may also override `preamble' (C text
  written before the typedef), `typedef_prefix' (text placed before
  `typedef'), and `literalfmt' (a `%'-format which decorates a bare number
  to make a literal of this type).
  """
  __metaclass__ = IntClass
  preamble = ''
  typedef_prefix = ''
  literalfmt = '%su'

  def __init__(me, bits, rank):
    me.bits = bits                  # width of the type, in bits
    me.rank = rank                  # index of the type in the typeinfo list
    me.litwd = len(me.literal(0))   # length of a default-format literal

  def literal(me, value, fmt = None):
    """Return VALUE formatted as a C literal of this type.

    If FMT is omitted, VALUE is written in hex, zero-padded to the full
    width of the type; otherwise FMT is the `%'-format for the bare number.
    """
    if fmt is None:
      fmt = '0x%%0%dx' % ((me.bits + 3)//4)
    return me.literalfmt % (fmt % value)
74 me.rank = rank
75 me.litwd = len(me.literal(0))
76 def literal(me, value, fmt = None):
77 if fmt is None: fmt = '0x%0' + str((me.bits + 3)//4) + 'x'
78 return me.literalfmt % (fmt % value)
79
80class UnsignedCharType (BasicIntType):
81 tag = 'uchar'
82 name = 'unsigned char'
83
84class UnsignedShortType (BasicIntType):
85 tag = 'ushort'
86 name = 'unsigned short'
87
88class UnsignedIntType (BasicIntType):
89 tag = 'uint'
90 name = 'unsigned int'
91
92class UnsignedLongType (BasicIntType):
93 tag = 'ulong'
94 name = 'unsigned long'
95 literalfmt = '%sul'
96
97class UnsignedLongLongType (BasicIntType):
98 tag = 'ullong'
99 name = 'unsigned long long'
100 preamble = """
101#if __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 91)
102# define CATACOMB_GCC_EXTENSION __extension__
103#else
104# define CATACOMB_GCC_EXTENSION
105#endif
106"""
107 typedef_prefix = 'CATACOMB_GCC_EXTENSION '
108 literalfmt = 'CATACOMB_GCC_EXTENSION %sull'
109
110class UIntMaxType (BasicIntType):
111 tag = 'uintmax'
112 name = 'uintmax_t'
113 preamble = "\n#include <stdint.h>\n"
114
115class TypeChoice (object):
116 def __init__(me, tifile):
117
118 ## Load the captured type information.
119 me.ti = TY.ModuleType('typeinfo')
120 execfile(opts.typeinfo, me.ti.__dict__)
121
122 ## Build a map of the available types.
123 tymap = {}
124 byrank = []
125 for tag, bits in me.ti.TYPEINFO:
126 rank = len(byrank)
127 tymap[tag] = rank
128 byrank.append(TYPEMAP[tag](bits, rank))
129
130 ## First pass: determine a suitable word size. The criteria are (a)
131 ## there exists another type at least twice as long (so that we can do a
132 ## single x single -> double multiplication), and (b) operations on a
133 ## word are efficient (so we'd prefer a plain machine word). We'll start
134 ## at `int' and work down. Maybe this won't work: there's a plan B.
135 mpwbits = 0
136 i = tymap['uint']
137 while not mpwbits and i >= 0:
138 ibits = byrank[i].bits
139 for j in xrange(i + 1, len(byrank)):
140 if byrank[j].bits >= 2*ibits:
141 mpwbits = ibits
142 break
143
144 ## If that didn't work, then we'll start with the largest type available
145 ## and go with half its size.
146 if not mpwbits:
147 mpwbits = byrank[-1].bits//2
148
149 ## Make sure we've not ended up somewhere really silly.
150 if mpwbits < 16:
151 raise Exception, "`mpw' type is too small: your C environment is weird"
152
153 ## Now figure out suitable types for `mpw' and `mpd'.
154 def find_type(bits, what):
155 for ty in byrank:
156 if ty.bits >= bits: return ty
157 raise Exception, \
158 "failed to find suitable %d-bit type, for %s" % (bits, what)
159
160 ## Store our decisions.
161 me.mpwbits = mpwbits
162 me.mpw = find_type(mpwbits, 'mpw')
163 me.mpd = find_type(mpwbits*2, 'mpd')
164
165###--------------------------------------------------------------------------
166### Outputting constant multiprecision integers.
167
168MARGIN = 72
169
170def write_preamble():
171 stdout.write("""
172#include <mLib/macros.h>
173#define MP_(name, flags) \\
174 { (/*unconst*/ mpw *)name##__mpw, \\
175 (/*unconst*/ mpw *)name##__mpw + N(name##__mpw), \\
176 N(name##__mpw), 0, MP_CONST | flags, 0 }
177#define ZERO_MP { 0, 0, 0, 0, MP_CONST, 0 }
178#define POS_MP(name) MP_(name, 0)
179#define NEG_MP(name) MP_(name, MP_NEG)
180""")
181
182def write_limbs(name, x):
183 if not x: return
184 stdout.write("\nstatic const mpw %s__mpw[] = {" % name)
185 sep = ''
186 pos = MARGIN
187 if x < 0: x = -x
188 mask = (1 << TC.mpwbits) - 1
189
190 while x > 0:
191 w, x = x & mask, x >> TC.mpwbits
192 f = TC.mpw.literal(w)
193 if pos + 2 + len(f) <= MARGIN:
194 stdout.write(sep + ' ' + f)
195 else:
196 pos = 2
197 stdout.write(sep + '\n ' + f)
198 pos += len(f) + 2
199 sep = ','
200
201 stdout.write("\n};\n")
202
203def mp_body(name, x):
204 return "%s_MP(%s)" % (x >= 0 and "POS" or "NEG", name)
205
206###--------------------------------------------------------------------------
207### Mode definition machinery.
208
209MODEMAP = {}
210
211def defmode(func):
212 name = func.func_name
213 if name.startswith('m_'): name = name[2:]
214 MODEMAP[name] = func
215 return func
216
217###--------------------------------------------------------------------------
218### The basic types header.
219
220@defmode
221def m_mptypes():
222 write_header("mptypes", "mptypes.h")
223 stdout.write("""\
224#ifndef CATACOMB_MPTYPES_H
225#define CATACOMB_MPTYPES_H
226""")
227
228 have = set([TC.mpw, TC.mpd])
229 for t in have:
230 stdout.write(t.preamble)
231
232 for label, t, bits in [('mpw', TC.mpw, TC.mpwbits),
233 ('mpd', TC.mpd, TC.mpwbits*2)]:
234 LABEL = label.upper()
235 stdout.write("\n%stypedef %s %s;\n" % (t.typedef_prefix, t.name, label))
236 stdout.write("#define %s_BITS %d\n" % (LABEL, bits))
237 i = 1
238 while 2*i < bits: i *= 2
239 stdout.write("#define %s_P2 %d\n" % (LABEL, i))
240 stdout.write("#define %s_MAX %s\n" % (LABEL,
241 t.literal((1 << bits) - 1, "%d")))
242
243 stdout.write("\n#endif\n")
244
245###--------------------------------------------------------------------------
246### Constant tables.
247
248@defmode
249def m_mplimits_c():
250 write_header("mplimits_c", "mplimits.c")
251 stdout.write('#include "mplimits.h"\n')
252 write_preamble()
253 seen = {}
254 v = []
255 def write(x):
256 if not x or x in seen: return
257 seen[x] = 1
258 write_limbs('limits_%d' % len(v), x)
259 v.append(x)
260 for tag, lo, hi in TC.ti.LIMITS:
261 write(lo)
262 write(hi)
263
264 stdout.write("\nmp mp_limits[] = {")
265 i = 0
266 sep = "\n "
267 for x in v:
268 stdout.write("%s%s_MP(limits_%d)\n" % (sep, x < 0 and "NEG" or "POS", i))
269 i += 1
270 sep = ",\n "
271 stdout.write("\n};\n");
272
273@defmode
274def m_mplimits_h():
275 write_header("mplimits_h", "mplimits.h")
276 stdout.write("""\
277#ifndef CATACOMB_MPLIMITS_H
278#define CATACOMB_MPLIMITS_H
279
280#ifndef CATACOMB_MP_H
281# include "mp.h"
282#endif
283
284extern mp mp_limits[];
285
286""")
287
288 seen = { 0: "MP_ZERO" }
289 slot = [0]
290 def find(x):
291 try:
292 r = seen[x]
293 except KeyError:
294 r = seen[x] = '(&mp_limits[%d])' % slot[0]
295 slot[0] += 1
296 return r
297 for tag, lo, hi in TC.ti.LIMITS:
298 stdout.write("#define MP_%s_MIN %s\n" % (tag, find(lo)))
299 stdout.write("#define MP_%s_MAX %s\n" % (tag, find(hi)))
300
301 stdout.write("\n#endif\n")
302
303###--------------------------------------------------------------------------
304### Group tables.
305
306class GroupTableClass (type):
307 def __new__(cls, name, supers, dict):
308 c = type.__new__(cls, name, supers, dict)
309 try: mode = c.mode
310 except AttributeError: pass
311 else: MODEMAP[c.mode] = c.run
312 return c
313
314class GroupTable (object):
315 __metaclass__ = GroupTableClass
316 keyword = 'group'
317 slots = []
318 def __init__(me):
319 me.st = st = struct()
320 st.nextmp = 0
321 st.mpmap = { None: 'NO_MP', 0: 'ZERO_MP' }
322 st.d = {}
323 me.st.name = None
324 me._names = []
325 me._defs = set()
326 me._slotmap = dict([(s.name, s) for s in me.slots])
327 me._headslots = [s for s in me.slots if s.headline]
328 def _flush(me):
329 if me.st.name is None: return
330 stdout.write("/* --- %s --- */\n" % me.st.name)
331 for s in me.slots: s.setup(me.st)
332 stdout.write("\nstatic %s c_%s = {" % (me.data_t, fix_name(me.st.name)))
333 sep = "\n "
334 for s in me.slots:
335 stdout.write(sep)
336 s.write(me.st)
337 sep = ",\n "
338 stdout.write("\n};\n\n")
339 me.st.d = {}
340 me.st.name = None
341 @classmethod
342 def run(cls, input):
343 me = cls()
344 write_header(me.mode, me.filename)
345 stdout.write('#include "%s"\n' % me.header)
346 write_preamble()
347 stdout.write("#define NO_MP { 0, 0, 0, 0, 0, 0 }\n\n")
348 write_banner("Group data")
349 stdout.write('\n')
350 with open(input) as file:
351 for line in file:
352 ff = line.split()
353 if not ff or ff[0].startswith('#'): continue
354 if ff[0] == 'alias':
355 if len(ff) != 3: raise Exception, "wrong number of alias arguments"
356 me._flush()
357 me._names.append((ff[1], ff[2]))
358 elif ff[0] == me.keyword:
359 if len(ff) < 2 or len(ff) > 2 + len(me._headslots):
360 raise Exception, "bad number of headline arguments"
361 me._flush()
362 me.st.name = name = ff[1]
363 me._defs.add(name)
364 me._names.append((name, name))
365 for f, s in zip(ff[2:], me._headslots): s.set(me.st, f)
366 elif ff[0] in me._slotmap:
367 if len(ff) != 2:
368 raise Exception, "bad number of values for slot `%s'" % ff[0]
369 me._slotmap[ff[0]].set(me.st, ff[1])
370 else:
371 raise Exception, "unknown keyword `%s'" % ff[0]
372 me._flush()
373 write_banner("Main table")
374 stdout.write("\nconst %s %s[] = {\n" % (me.entry_t, me.tabname))
375 for a, n in me._names:
376 if n not in me._defs:
377 raise Exception, "alias `%s' refers to unknown group `%s'" % (a, n)
378 stdout.write(' { "%s", &c_%s },\n' % (a, fix_name(n)))
379 stdout.write(" { 0, 0 }\n};\n\n")
380 write_banner("That's all, folks")
381
382class BaseSlot (object):
383 def __init__(me, name, headline = False, omitp = None, allowp = None):
384 me.name = name
385 me.headline = headline
386 me._omitp = None
387 me._allowp = None
388 def set(me, st, value):
389 if me._allowp and not me._allowp(st, value):
390 raise Exception, "slot `%s' not allowed here" % me.name
391 st.d[me] = value
392 def setup(me, st):
393 if me not in st.d and (not me._omitp or not me._omitp(st)):
394 raise Exception, "missing slot `%s'" % me.name
395
396class EnumSlot (BaseSlot):
397 def __init__(me, name, prefix, values, **kw):
398 super(EnumSlot, me).__init__(name, **kw)
399 me._values = set(values)
400 me._prefix = prefix
401 def set(me, st, value):
402 if value not in me._values:
403 raise Exception, "invalid %s value `%s'" % (me.name, value)
404 super(EnumSlot, me).set(st, value)
405 def write(me, st):
406 try: stdout.write('%s_%s' % (me._prefix, st.d[me].upper()))
407 except KeyError: stdout.write('0')
408
409class MPSlot (BaseSlot):
410 def setup(me, st):
411 v = st.d.get(me)
412 if v not in st.mpmap:
413 write_limbs('v%d' % st.nextmp, v)
414 st.mpmap[v] = mp_body('v%d' % st.nextmp, v)
415 st.nextmp += 1
416 def set(me, st, value):
417 super(MPSlot, me).set(st, long(value, 0))
418 def write(me, st):
419 stdout.write(st.mpmap[st.d.get(me)])
420
421class BinaryGroupTable (GroupTable):
422 mode = 'bintab'
423 filename = 'bintab.c'
424 header = 'bintab.h'
425 data_t = 'bindata'
426 entry_t = 'binentry'
427 tabname = 'bintab'
428 slots = [MPSlot('p'), MPSlot('q'), MPSlot('g')]
429
430class EllipticCurveTable (GroupTable):
431 mode = 'ectab'
432 filename = 'ectab.c'
433 header = 'ectab.h'
434 keyword = 'curve'
435 data_t = 'ecdata'
436 entry_t = 'ecentry'
437 tabname = 'ectab'
438 slots = [EnumSlot('type', 'FTAG',
439 ['prime', 'niceprime', 'binpoly', 'binnorm'],
440 headline = True),
441 MPSlot('p'),
442 MPSlot('beta',
443 allowp = lambda st, _: st.d['type'] == 'binnorm',
444 omitp = lambda st: st.d['type'] != 'binnorm'),
445 MPSlot('a'), MPSlot('b'), MPSlot('r'), MPSlot('h'),
446 MPSlot('gx'), MPSlot('gy')]
447
448class PrimeGroupTable (GroupTable):
449 mode = 'ptab'
450 filename = 'ptab.c'
451 header = 'ptab.h'
452 data_t = 'pdata'
453 entry_t = 'pentry'
454 tabname = 'ptab'
455 slots = [MPSlot('p'), MPSlot('q'), MPSlot('g')]
456
457###--------------------------------------------------------------------------
458### Main program.
459
460op = OP.OptionParser(
461 description = 'Generate multiprecision integer representations',
462 usage = 'usage: %prog [-t TYPEINFO] MODE [ARGS ...]',
463 version = 'Catacomb, version @VERSION@')
464for shortopt, longopt, kw in [
465 ('-t', '--typeinfo', dict(
466 action = 'store', metavar = 'PATH', dest = 'typeinfo',
467 help = 'alternative typeinfo file'))]:
468 op.add_option(shortopt, longopt, **kw)
469op.set_defaults(typeinfo = './typeinfo.py')
470opts, args = op.parse_args()
471
472if len(args) < 1: op.error('missing MODE')
473mode = args[0]
474
475TC = TypeChoice(opts.typeinfo)
476
477try: modefunc = MODEMAP[mode]
478except KeyError: op.error("unknown mode `%s'" % mode)
479modefunc(*args[1:])
480
481###----- That's all, folks --------------------------------------------------