backend.py: Introduce protocol for alternative locking schemes.
[chopwood] / backend.py
index 1725d7d..fda32e0 100644 (file)
@@ -25,6 +25,9 @@
 
 from __future__ import with_statement
 
+from auto import HOME
+import errno as E
+import itertools as I
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
@@ -36,7 +39,63 @@ import util as U
 CONF.DEFAULTS.update(
 
   ## A directory in which we can create lockfiles.
-  LOCKDIR = OS.path.join(ENV['HOME'], 'var', 'lock', 'chpwd'))
+  LOCKDIR = OS.path.join(HOME, 'lock'))
+
+###--------------------------------------------------------------------------
+### Utilities.
+
+def fill_in_fields(fno_user, fno_passwd, fno_map, user, passwd, args):
+  """
+  Return a vector of filled-in fields.
+
+  The FNO_... arguments give field numbers: FNO_USER and FNO_PASSWD give the
+  positions for the username and password fields, respectively; and FNO_MAP
+  is a sequence of (NAME, POS) pairs.  The USER and PASSWD arguments give the
+  actual user name and password values; ARGS are the remaining arguments,
+  maybe in the form `NAME=VALUE'.
+  """
+
+  ## Prepare the result vector, and set up some data structures.
+  n = 2 + len(fno_map)
+  fmap = {}
+  rmap = map(int, xrange(n))
+  ok = True
+  if fno_user >= n or fno_passwd >= n: ok = False
+  for k, i in fno_map:
+    fmap[k] = i
+    rmap[i] = "`%s'" % k
+    if i >= n: ok = False
+  if not ok:
+    raise U.ExpectedError, \
+        (500, "Fields specified aren't contiguous")
+
+  ## Prepare the new record's fields.
+  f = [None]*n
+  f[fno_user] = user
+  f[fno_passwd] = passwd
+
+  for a in args:
+    if '=' in a:
+      k, v = a.split('=', 1)
+      try: i = fmap[k]
+      except KeyError: raise U.ExpectedError, (400, "Unknown field `%s'" % k)
+    else:
+      for i in xrange(n):
+        if f[i] is None: break
+      else:
+        raise U.ExpectedError, (500, "All fields already populated")
+      v = a
+    if f[i] is not None:
+      raise U.ExpectedError, (400, "Field %s is already set" % rmap[i])
+    f[i] = v
+
+  ## Check that the vector of fields is properly set up.
+  for i in xrange(n):
+    if f[i] is None:
+      raise U.ExpectedError, (500, "Field %s is unset" % rmap[i])
+
+  ## Done.
+  return f
 
 ###--------------------------------------------------------------------------
 ### Protocol.
@@ -76,6 +135,8 @@ class BasicRecord (object):
     me._be = backend
   def write(me):
     me._be._update(me)
+  def remove(me):
+    me._be._remove(me)
 
 class TrivialRecord (BasicRecord):
   """
@@ -136,7 +197,7 @@ class FlatFileRecord (BasicRecord):
           raise U.ExpectedError, \
                 (500, "New `%s' field contains %s" % (k, what))
       fields[v] = val
-    return me._delim.join(fields)
+    return me._delim.join(fields) + '\n'
 
 class FlatFileBackend (object):
   """
@@ -147,10 +208,13 @@ class FlatFileBackend (object):
   specified by the DELIM constructor argument.
 
   The file is updated by writing a new version alongside, as `FILE.new', and
-  renaming it over the old version.  If a LOCK file is named then an
-  exclusive fcntl(2)-style lock is taken out on `LOCKDIR/LOCK' (creating the
-  file if necessary) during the update operation.  Use of a lockfile is
-  strongly recommended.
+  renaming it over the old version.  If a LOCK is provided then this is done
+  while holding a lock.  By default, an exclusive fcntl(2)-style lock is
+  taken out on `LOCKDIR/LOCK' (creating the file if necessary) during the
+  update operation, but subclasses can override the `dolocked' method to
+  provide alternative locking behaviour; the LOCK parameter is not
+  interpreted by any other methods.  Use of a lockfile is strongly
+  recommended.
 
   The DELIM constructor argument specifies the delimiter character used when
   splitting lines into fields.  The USER and PASSWD arguments give the field
@@ -181,8 +245,38 @@ class FlatFileBackend (object):
           return rec
     raise UnknownUser, user
 
-  def _update(me, rec):
-    """Update the record REC in the file."""
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right.
+    """
+
+    f = fill_in_fields(me._fmap['user'], me._fmap['passwd'],
+                       [(k[2:], i)
+                        for k, i in me._fmap.iteritems()
+                        if k.startswith('f_')],
+                       user, passwd, args)
+    r = FlatFileRecord(me._delim.join(f), me._delim, me._fmap, backend = me)
+    me._rewrite('create', r)
+
+  def _rewrite(me, op, rec):
+    """
+    Rewrite the file, according to OP.
+
+    The OP may be one of the following.
+
+    `create'            There must not be a record matching REC; add a new
+                        one.
+
+    `remove'            There must be a record matching REC: remove it.
+
+    `update'            There must be a record matching REC: write REC in its
+                        place.
+    """
 
     ## The main update function.
     def doit():
@@ -200,15 +294,26 @@ class FlatFileBackend (object):
 
         ## Copy the old file to the new one, changing the user's record if
         ## and when we encounter it.
+        found = False
         with OS.fdopen(fd, 'w') as f_out:
           with open(me._file) as f_in:
             for line in f_in:
               r = me._parse(line)
               if r.user != rec.user:
                 f_out.write(line)
+              elif op == 'create':
+                raise U.ExpectedError, \
+                    (500, "Record for `%s' already exists" % rec.user)
               else:
-                f_out.write(rec._format())
-                f_out.write('\n')
+                found = True
+                if op != 'remove': f_out.write(rec._format())
+          if found:
+            pass
+          elif op == 'create':
+            f_out.write(rec._format())
+          else:
+            raise U.ExpectedError, \
+                (500, "Record for `%s' not found" % rec.user)
 
         ## Update the permissions on the new file.  Don't try to fix the
         ## ownership (we shouldn't be running as root) or the group (the
@@ -227,18 +332,36 @@ class FlatFileBackend (object):
           try: OS.unlink(tmp)
           except: pass
 
-    ## If there's a locekfile, then acquire it around the meat of this
+    ## If there's a lockfile, then acquire it around the meat of this
     ## function; otherwise just do the job.
-    if me._lock is None:
-      doit()
-    else:
-      with U.lockfile(OS.path.join(CFG.LOCKDIR, me._lock), 5):
-        doit()
+    if me._lock is None: doit()
+    else: me.dolocked(me._lock, doit)
+
+  def dolocked(me, lock, func):
+    """
+    Call FUNC with the LOCK held.
+
+    Subclasses can override this method in order to provide alternative
+    locking functionality.
+    """
+    try: OS.mkdir(CFG.LOCKDIR)
+    except OSError, e:
+      if e.errno != E.EEXIST: raise
+    with U.lockfile(OS.path.join(CFG.LOCKDIR, lock), 5):
+      func()
 
   def _parse(me, line):
     """Convenience function for constructing a record."""
     return FlatFileRecord(line, me._delim, me._fmap, backend = me)
 
+  def _update(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('update', rec)
+
+  def _remove(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('remove', rec)
+
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
@@ -295,6 +418,36 @@ class DatabaseBackend (object):
       setattr(rec, 'f_' + f, v)
     return rec
 
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the named USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right, in the order given to the constructor.
+    """
+
+    tags = ['user', 'passwd'] + \
+        ['t_%d' % 0 for i in xrange(len(me._fields))]
+    f = fill_in_fields(0, 1, list(I.izip(me._fields, I.count(2))),
+                       user, passwd, args)
+    me._connect()
+    with me._db:
+      me._db.execute("INSERT INTO %s (%s) VALUES (%s)" %
+                     (me._table,
+                      ', '.join([me._user, me._passwd] + me._fields),
+                      ', '.join(['$%s' % t for t in tags])),
+                     **dict(I.izip(tags, f)))
+
+  def _remove(me, rec):
+    """Remove the record REC from the database."""
+    me._connect()
+    with me._db:
+      me._db.execute("DELETE FROM %s WHERE %s = $user" %
+                     (me._table, me._user),
+                     user = rec.user)
+
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()