#! @PYTHON@
###
### Generate multiprecision integer representations
###
### (c) 2013 Straylight/Edgeware
###

###----- Licensing notice ---------------------------------------------------
###
### This file is part of Catacomb.
###
### Catacomb is free software; you can redistribute it and/or modify
### it under the terms of the GNU Library General Public License as
### published by the Free Software Foundation; either version 2 of the
### License, or (at your option) any later version.
###
### Catacomb is distributed in the hope that it will be useful,
### but WITHOUT ANY WARRANTY; without even the implied warranty of
### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
### GNU Library General Public License for more details.
###
### You should have received a copy of the GNU Library General Public
### License along with Catacomb; if not, write to the Free
### Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
### MA 02111-1307, USA.

from __future__ import with_statement

import re as RX
import optparse as OP
import types as TY

from sys import stdout

###--------------------------------------------------------------------------
### Random utilities.

def write_header(mode, name):
    """Write the C comment which opens every generated file.

    MODE is the mpgen mode doing the generating and NAME the name of the
    file being produced; both appear in the comment.  The comment is
    followed by a blank line.
    """
    lines = ["/* -*-c-*- GENERATED by mpgen (%s)" % mode,
             " *",
             " * %s" % name,
             " */",
             "",
             ""]
    stdout.write("\n".join(lines))

def write_banner(text):
    """Write a one-line C section-divider comment announcing TEXT.

    The comment is padded with dashes so that texts of up to 66 characters
    give lines of a uniform width.
    """
    rule = '-' * (66 - len(text))
    stdout.write("/*----- " + text + " " + rule + "*/\n")

class struct (object):
    """An empty attribute bag: callers hang whatever attributes they like
    on instances."""

## Anything which may not appear in a C identifier.
R_IDBAD = RX.compile(r'[^0-9A-Za-z]')

def fix_name(name):
    """Return NAME with every non-alphanumeric character replaced by `_'.

    Used to turn group names (which may contain punctuation such as `-')
    into fragments of C identifiers.
    """
    return R_IDBAD.sub('_', name)

###--------------------------------------------------------------------------
### Determining the appropriate types.
TYPEMAP = {}

class IntClass (type):
    """Metaclass which records every integer-type class that has a `tag'.

    Registered classes are stored in `TYPEMAP', keyed by tag, so that the
    tags listed in the captured type information can be mapped back to the
    classes describing them.
    """
    def __new__(cls, name, supers, ns):
        c = type.__new__(cls, name, supers, ns)
        try: TYPEMAP[c.tag] = c
        except AttributeError: pass
        return c

## Python 3 ignores a `__metaclass__' class attribute, so derive from a base
## made by calling the metaclass directly: subclasses then get `IntClass' as
## their metaclass under both Python 2 and Python 3.
_IntClassBase = IntClass('_IntClassBase', (object,), {})

class BasicIntType (_IntClassBase):
    """Description of a C unsigned integer type.

    Subclasses set:
      * `tag' -- the key naming the type in the typeinfo file;
      * `name' -- the C spelling of the type;
      * `preamble' -- C text which must be emitted before the type is used;
      * `typedef_prefix' -- text to write before a `typedef' of the type;
      * `literalfmt' -- a `%'-format turning a digit string into a literal.

    Instances record the type's width in BITS, its RANK (its position in the
    typeinfo list), and `litwd', the length of the literal for zero.
    """
    preamble = ''
    typedef_prefix = ''
    literalfmt = '%su'

    def __init__(me, bits, rank):
        me.bits = bits
        me.rank = rank
        me.litwd = len(me.literal(0))

    def literal(me, value, fmt = None):
        """Return a C literal of this type with the given VALUE.

        By default the digits are hex, zero-padded to the full width of the
        type; FMT may supply a different `%'-format for the digits.
        """
        if fmt is None: fmt = '0x%0' + str((me.bits + 3)//4) + 'x'
        return me.literalfmt % (fmt % value)

class UnsignedCharType (BasicIntType):
    tag = 'uchar'
    name = 'unsigned char'

class UnsignedShortType (BasicIntType):
    tag = 'ushort'
    name = 'unsigned short'

class UnsignedIntType (BasicIntType):
    tag = 'uint'
    name = 'unsigned int'

class UnsignedLongType (BasicIntType):
    tag = 'ulong'
    name = 'unsigned long'
    literalfmt = '%sul'

class UnsignedLongLongType (BasicIntType):
    tag = 'ullong'
    name = 'unsigned long long'
    ## `long long' is an extension in pre-C99 GCC; mark uses of it so that
    ## pedantic compilations don't complain.
    preamble = """
#if __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 91)
#  define CATACOMB_GCC_EXTENSION __extension__
#else
#  define CATACOMB_GCC_EXTENSION
#endif
"""
    typedef_prefix = 'CATACOMB_GCC_EXTENSION '
    literalfmt = 'CATACOMB_GCC_EXTENSION %sull'

class UIntMaxType (BasicIntType):
    tag = 'uintmax'
    name = 'uintmax_t'
    ## `uintmax_t' is declared in <stdint.h>.
    preamble = "\n#include <stdint.h>\n"

def _load_typeinfo(tifile):
    """Execute the Python file TIFILE and return it as a `typeinfo' module.

    (Read and compile the file explicitly, since Python 3 has no
    `execfile'.  This is kept out of any function containing nested scopes,
    where Python 2 would reject the `exec'.)
    """
    ti = TY.ModuleType('typeinfo')
    with open(tifile) as f:
        code = compile(f.read(), tifile, 'exec')
    exec(code, ti.__dict__)
    return ti

class TypeChoice (object):
    """Choose C types for the multiprecision word types `mpw' and `mpd'.

    After construction:
      * `ti' is a module holding the executed typeinfo file;
      * `mpwbits' is the chosen width of an `mpw', in bits;
      * `mpw' and `mpd' are `BasicIntType' instances describing the single-
        and double-width word types.
    """

    def __init__(me, tifile):
        """Load type information from the Python file TIFILE, and decide.

        TIFILE must define `TYPEINFO', a sequence of (TAG, BITS) pairs; each
        type's rank is its position in that sequence, and a `uint' entry is
        required.  Raises `Exception' if no sensible choice can be made.
        """

        ## Load the captured type information.
        me.ti = _load_typeinfo(tifile)

        ## Build a map of the available types.
        tymap = {}
        byrank = []
        for tag, bits in me.ti.TYPEINFO:
            rank = len(byrank)
            tymap[tag] = rank
            byrank.append(TYPEMAP[tag](bits, rank))

        ## First pass: determine a suitable word size.  The criteria are (a)
        ## there exists another type at least twice as long (so that we can do
        ## a single x single -> double multiplication), and (b) operations on
        ## a word are efficient (so we'd prefer a plain machine word).  We'll
        ## start at `int' and work down.  Maybe this won't work: there's a
        ## plan B.
        mpwbits = 0
        i = tymap['uint']
        while not mpwbits and i >= 0:
            ibits = byrank[i].bits
            for j in range(i + 1, len(byrank)):
                if byrank[j].bits >= 2*ibits:
                    mpwbits = ibits
                    break
            ## Step down to the next smaller type; without this the loop would
            ## spin forever whenever `int' has no double-width partner.
            i -= 1

        ## If that didn't work, then we'll start with the largest type
        ## available and go with half its size.
        if not mpwbits:
            mpwbits = byrank[-1].bits//2

        ## Make sure we've not ended up somewhere really silly.
        if mpwbits < 16:
            raise Exception(
                "`mpw' type is too small: your C environment is weird")

        ## Now figure out suitable types for `mpw' and `mpd': the first type
        ## (in rank order) which is at least wide enough.
        def find_type(bits, what):
            for ty in byrank:
                if ty.bits >= bits: return ty
            raise Exception(
                "failed to find suitable %d-bit type, for %s" % (bits, what))

        ## Store our decisions.
        me.mpwbits = mpwbits
        me.mpw = find_type(mpwbits, 'mpw')
        me.mpd = find_type(mpwbits*2, 'mpd')

###--------------------------------------------------------------------------
### Outputting constant multiprecision integers.
## Right-hand column limit for limb-array initializers.
MARGIN = 72

def write_preamble():
    """Write the C macros used to build constant `mp' initializers.

    `POS_MP(NAME)' and `NEG_MP(NAME)' expand to an initializer referring to
    the limb array `NAME__mpw'; `ZERO_MP' is a constant zero needing no
    limbs.  The expansion uses the array-length macro `N'.
    """
    ## NOTE(review): the header name was lost from the original `#include';
    ## <mLib/macros.h> is assumed here as the provider of `N' -- confirm.
    stdout.write("""
#include <mLib/macros.h>
#define MP_(name, flags) \\
  { (/*unconst*/ mpw *)name##__mpw, \\
    (/*unconst*/ mpw *)name##__mpw + N(name##__mpw), \\
    N(name##__mpw), 0, MP_CONST | flags, 0 }
#define ZERO_MP { 0, 0, 0, 0, MP_CONST, 0 }
#define POS_MP(name) MP_(name, 0)
#define NEG_MP(name) MP_(name, MP_NEG)
""")

def write_limbs(name, x):
    """Write a static array `NAME__mpw' holding the limbs of |X|.

    Limbs are `TC.mpwbits' wide, written least significant first, and the
    initializer is wrapped to stay within MARGIN columns.  Nothing at all is
    written when X is zero or None.
    """
    if not x: return
    stdout.write("\nstatic const mpw %s__mpw[] = {" % name)
    x = abs(x)
    mask = (1 << TC.mpwbits) - 1
    sep = ''

    ## Pretend the current line is full, so the first limb starts a new one.
    pos = MARGIN
    while x > 0:
        w, x = x & mask, x >> TC.mpwbits
        f = TC.mpw.literal(w)
        if pos + 2 + len(f) <= MARGIN:
            stdout.write(sep + ' ' + f)
        else:
            pos = 2
            stdout.write(sep + '\n  ' + f)
        pos += len(f) + 2
        sep = ','

    stdout.write("\n};\n")

def mp_body(name, x):
    """Return the initializer for a constant `mp' with value X whose limbs
    are in the array `NAME__mpw'."""
    return "%s_MP(%s)" % ("POS" if x >= 0 else "NEG", name)

###--------------------------------------------------------------------------
### Mode definition machinery.

## Map from mode names to the functions which implement them.
MODEMAP = {}

def defmode(func):
    """Decorator: register FUNC in `MODEMAP' as a generator mode.

    The mode's name is the function's name with any `m_' prefix removed.
    Returns FUNC unchanged.
    """
    name = func.__name__
    if name.startswith('m_'): name = name[2:]
    MODEMAP[name] = func
    return func

###--------------------------------------------------------------------------
### The basic types header.
@defmode
def m_mptypes():
    """Write `mptypes.h', defining the word types `mpw' and `mpd'.

    For each type this emits a `typedef' and macros giving its width in
    bits (`_BITS'), the largest power of two less than that width (`_P2'),
    and its maximum value (`_MAX').
    """
    write_header("mptypes", "mptypes.h")
    stdout.write("""\
#ifndef CATACOMB_MPTYPES_H
#define CATACOMB_MPTYPES_H
""")

    ## Emit each distinct type's preamble once, in a fixed order.  (Iterating
    ## over a set would make the order depend on object addresses, so the
    ## generated header could differ from one run to the next.)
    done = []
    for t in [TC.mpw, TC.mpd]:
        if t not in done:
            stdout.write(t.preamble)
            done.append(t)

    for label, t, bits in [('mpw', TC.mpw, TC.mpwbits),
                           ('mpd', TC.mpd, TC.mpwbits*2)]:
        LABEL = label.upper()
        stdout.write("\n%stypedef %s %s;\n" % (t.typedef_prefix, t.name, label))
        stdout.write("#define %s_BITS %d\n" % (LABEL, bits))
        p2 = 1
        while 2*p2 < bits: p2 *= 2
        stdout.write("#define %s_P2 %d\n" % (LABEL, p2))
        stdout.write("#define %s_MAX %s\n" %
                     (LABEL, t.literal((1 << bits) - 1, "%d")))

    stdout.write("\n#endif\n")

###--------------------------------------------------------------------------
### Constant tables.

@defmode
def m_mplimits_c():
    """Write `mplimits.c', defining the array `mp_limits'.

    Each distinct nonzero bound in `TC.ti.LIMITS' gets a limb array and an
    entry in `mp_limits', in first-seen order (each type's lower bound
    before its upper bound).  `m_mplimits_h' numbers the entries the same
    way.
    """
    write_header("mplimits_c", "mplimits.c")
    stdout.write('#include "mplimits.h"\n')
    write_preamble()
    seen = set()
    v = []
    def write(x):
        if not x or x in seen: return
        seen.add(x)
        write_limbs('limits_%d' % len(v), x)
        v.append(x)
    for tag, lo, hi in TC.ti.LIMITS:
        write(lo)
        write(hi)

    ## One entry per line, separated by commas at the ends of lines.
    stdout.write("\nmp mp_limits[] = {")
    sep = "\n  "
    for i, x in enumerate(v):
        stdout.write(sep + mp_body('limits_%d' % i, x))
        sep = ",\n  "
    stdout.write("\n};\n")

@defmode
def m_mplimits_h():
    """Write `mplimits.h', defining `MP_TAG_MIN' and `MP_TAG_MAX' macros.

    Zero bounds map to `MP_ZERO'; other bounds refer to entries of
    `mp_limits', numbered exactly as `m_mplimits_c' numbers them.
    """
    write_header("mplimits_h", "mplimits.h")
    stdout.write("""\
#ifndef CATACOMB_MPLIMITS_H
#define CATACOMB_MPLIMITS_H

#ifndef CATACOMB_MP_H
#  include "mp.h"
#endif

extern mp mp_limits[];

""")

    seen = { 0: "MP_ZERO" }
    def find(x):
        try: return seen[x]
        except KeyError: pass
        ## `seen' holds zero plus one entry per slot allocated so far, so the
        ## next free slot is one less than its size.
        idx = len(seen) - 1
        seen[x] = r = '(&mp_limits[%d])' % idx
        return r
    for tag, lo, hi in TC.ti.LIMITS:
        stdout.write("#define MP_%s_MIN %s\n" % (tag, find(lo)))
        stdout.write("#define MP_%s_MAX %s\n" % (tag, find(hi)))

    stdout.write("\n#endif\n")

###--------------------------------------------------------------------------
### Group tables.
class GroupTableClass (type):
    """Metaclass for group tables: register each concrete table as a mode.

    A class which defines a `mode' attribute has its `run' method entered
    into `MODEMAP' under that mode name.
    """
    def __new__(cls, name, supers, ns):
        c = type.__new__(cls, name, supers, ns)
        try: mode = c.mode
        except AttributeError: pass
        else: MODEMAP[mode] = c.run
        return c

## Python 3 ignores a `__metaclass__' class attribute, so derive from a base
## made by calling the metaclass directly: subclasses then get
## `GroupTableClass' as their metaclass under both Python 2 and Python 3.
_GroupTableBase = GroupTableClass('_GroupTableBase', (object,), {})

class GroupTable (_GroupTableBase):
    """Generate a C table of named groups from a simple text description.

    Subclasses set:
      * `mode' -- the mpgen mode name which runs this generator;
      * `filename', `header' -- the output file's name, and the header it
        includes;
      * `keyword' -- the word introducing a new group (default `group');
      * `data_t', `entry_t', `tabname' -- the C types of a group's data and
        of a table entry, and the name of the table;
      * `slots' -- slot objects, in C initializer order.
    """
    keyword = 'group'
    slots = []

    def __init__(me):
        me.st = st = struct()
        st.nextmp = 0
        ## Integer values already written, mapped to their initializers.
        ## Not reset between groups, so each distinct value is written once.
        st.mpmap = { None: 'NO_MP', 0: 'ZERO_MP' }
        ## Values for the current group's slots, keyed by slot name.
        st.d = {}
        st.name = None
        me._names = []
        me._defs = set()
        me._slotmap = dict([(s.name, s) for s in me.slots])
        me._headslots = [s for s in me.slots if s.headline]

    def _flush(me):
        """Write out the group currently being collected, if any; reset."""
        st = me.st
        if st.name is None: return
        stdout.write("/* --- %s --- */\n" % st.name)
        ## Each slot's `setup' checks its value and may write definitions
        ## (e.g., limb arrays) which the initializer below refers to.
        for s in me.slots: s.setup(st)
        stdout.write("\nstatic %s c_%s = {" % (me.data_t, fix_name(st.name)))
        sep = "\n  "
        for s in me.slots:
            stdout.write(sep)
            s.write(st)
            sep = ",\n  "
        stdout.write("\n};\n\n")
        st.d = {}
        st.name = None

    @classmethod
    def run(cls, input):
        """Read group definitions from the file INPUT; write the C table.

        Input lines are split into words; blank lines and lines whose first
        word begins with `#' are ignored.  Other lines are:
          * `alias NAME GROUP' -- make NAME another name for GROUP;
          * KEYWORD NAME [HEADLINE ...] -- start a new group NAME, giving
            values for the headline slots in order;
          * SLOT VALUE -- set a slot of the current group.
        Raises `Exception' on malformed input.
        """
        me = cls()
        write_header(me.mode, me.filename)
        stdout.write('#include "%s"\n' % me.header)
        write_preamble()
        stdout.write("#define NO_MP { 0, 0, 0, 0, 0, 0 }\n\n")
        write_banner("Group data")
        stdout.write('\n')
        with open(input) as fp:
            for line in fp:
                ff = line.split()
                if not ff or ff[0].startswith('#'): continue
                if ff[0] == 'alias':
                    if len(ff) != 3:
                        raise Exception("wrong number of alias arguments")
                    me._flush()
                    me._names.append((ff[1], ff[2]))
                elif ff[0] == me.keyword:
                    if len(ff) < 2 or len(ff) > 2 + len(me._headslots):
                        raise Exception("bad number of headline arguments")
                    me._flush()
                    name = ff[1]
                    ## A repeated name would give duplicate C definitions.
                    if name in me._defs:
                        raise Exception("duplicate %s `%s'" %
                                        (me.keyword, name))
                    me.st.name = name
                    me._defs.add(name)
                    me._names.append((name, name))
                    for value, s in zip(ff[2:], me._headslots):
                        s.set(me.st, value)
                elif ff[0] in me._slotmap:
                    if len(ff) != 2:
                        raise Exception(
                            "bad number of values for slot `%s'" % ff[0])
                    ## Otherwise the value would leak into the next group.
                    if me.st.name is None:
                        raise Exception("slot `%s' outside of any %s" %
                                        (ff[0], me.keyword))
                    me._slotmap[ff[0]].set(me.st, ff[1])
                else:
                    raise Exception("unknown keyword `%s'" % ff[0])
        me._flush()
        write_banner("Main table")
        stdout.write("\nconst %s %s[] = {\n" % (me.entry_t, me.tabname))
        for a, n in me._names:
            if n not in me._defs:
                raise Exception(
                    "alias `%s' refers to unknown group `%s'" % (a, n))
            stdout.write('  { "%s", &c_%s },\n' % (a, fix_name(n)))
        stdout.write("  { 0, 0 }\n};\n\n")
        write_banner("That's all, folks")

class BaseSlot (object):
    """A named field in a group's C initializer.

    NAME is the word introducing the slot in the input; if HEADLINE is true,
    the slot may also be given on the group's headline.  OMITP, if given, is
    called as OMITP(ST) and returns true if the slot may be left unset.
    ALLOWP, if given, is called as ALLOWP(ST, VALUE) and returns false if
    setting the slot is forbidden.  Values live in `ST.d', keyed by slot
    name, so that these predicates can consult other slots' values.
    """
    def __init__(me, name, headline = False, omitp = None, allowp = None):
        me.name = name
        me.headline = headline
        me._omitp = omitp
        me._allowp = allowp

    def set(me, st, value):
        if me._allowp and not me._allowp(st, value):
            raise Exception("slot `%s' not allowed here" % me.name)
        st.d[me.name] = value

    def setup(me, st):
        if me.name not in st.d and (not me._omitp or not me._omitp(st)):
            raise Exception("missing slot `%s'" % me.name)

class EnumSlot (BaseSlot):
    """A slot whose value is one of a fixed set of words.

    The value is written as PREFIX_VALUE, with VALUE upper-cased; an unset
    slot is written as `0'.
    """
    def __init__(me, name, prefix, values, **kw):
        super(EnumSlot, me).__init__(name, **kw)
        me._values = set(values)
        me._prefix = prefix

    def set(me, st, value):
        if value not in me._values:
            raise Exception("invalid %s value `%s'" % (me.name, value))
        super(EnumSlot, me).set(st, value)

    def write(me, st):
        try: value = st.d[me.name]
        except KeyError: stdout.write('0')
        else: stdout.write('%s_%s' % (me._prefix, value.upper()))

class MPSlot (BaseSlot):
    """A slot holding an integer, written as a constant `mp' initializer.

    Values are parsed with `int(VALUE, 0)', so `0x' hex is accepted.  Each
    distinct value gets one limb array, shared by every group using it; an
    omitted value is written as `NO_MP'.
    """
    def setup(me, st):
        super(MPSlot, me).setup(st)
        v = st.d.get(me.name)
        if v not in st.mpmap:
            tag = 'v%d' % st.nextmp
            write_limbs(tag, v)
            st.mpmap[v] = mp_body(tag, v)
            st.nextmp += 1

    def set(me, st, value):
        ## `int' yields a long integer where necessary in Python 2 as well.
        super(MPSlot, me).set(st, int(value, 0))

    def write(me, st):
        stdout.write(st.mpmap[st.d.get(me.name)])

class BinaryGroupTable (GroupTable):
    mode = 'bintab'
    filename = 'bintab.c'
    header = 'bintab.h'
    data_t = 'bindata'
    entry_t = 'binentry'
    tabname = 'bintab'
    slots = [MPSlot('p'), MPSlot('q'), MPSlot('g')]

class EllipticCurveTable (GroupTable):
    mode = 'ectab'
    filename = 'ectab.c'
    header = 'ectab.h'
    keyword = 'curve'
    data_t = 'ecdata'
    entry_t = 'ecentry'
    tabname = 'ectab'
    slots = [EnumSlot('type', 'FTAG',
                      ['prime', 'niceprime', 'binpoly', 'binnorm'],
                      headline = True),
             MPSlot('p'),
             ## `beta' is required for, and only permitted on, `binnorm'
             ## curves.
             MPSlot('beta',
                    allowp = lambda st, _: st.d.get('type') == 'binnorm',
                    omitp = lambda st: st.d.get('type') != 'binnorm'),
             MPSlot('a'), MPSlot('b'), MPSlot('r'), MPSlot('h'),
             MPSlot('gx'), MPSlot('gy')]

class PrimeGroupTable (GroupTable):
    mode = 'ptab'
    filename = 'ptab.c'
    header = 'ptab.h'
    data_t = 'pdata'
    entry_t = 'pentry'
    tabname = 'ptab'
    slots = [MPSlot('p'), MPSlot('q'), MPSlot('g')]

###--------------------------------------------------------------------------
### Main program.

if __name__ == '__main__':
    op = OP.OptionParser(
        description = 'Generate multiprecision integer representations',
        usage = 'usage: %prog [-t TYPEINFO] MODE [ARGS ...]',
        version = 'Catacomb, version @VERSION@')
    for shortopt, longopt, kw in [
            ('-t', '--typeinfo', dict(
                action = 'store', metavar = 'PATH', dest = 'typeinfo',
                help = 'alternative typeinfo file'))]:
        op.add_option(shortopt, longopt, **kw)
    op.set_defaults(typeinfo = './typeinfo.py')
    opts, args = op.parse_args()

    if len(args) < 1: op.error('missing MODE')
    mode = args[0]

    ## Look the mode up before loading the type information, so that a
    ## mistyped mode is reported as such.  (`op.error' exits.)
    try: modefunc = MODEMAP[mode]
    except KeyError: op.error("unknown mode `%s'" % mode)

    ## The output functions consult this global for the chosen types.
    TC = TypeChoice(opts.typeinfo)
    modefunc(*args[1:])

###----- That's all, folks --------------------------------------------------