chpwd.css: Make style match `distorted.org.uk' general house style.
[chopwood] / backend.py
index 1967cda..fda32e0 100644 (file)
@@ -25,6 +25,9 @@
 
 from __future__ import with_statement
 
 
 from __future__ import with_statement
 
+from auto import HOME
+import errno as E
+import itertools as I
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
@@ -36,7 +39,63 @@ import util as U
 CONF.DEFAULTS.update(
 
   ## A directory in which we can create lockfiles.
 CONF.DEFAULTS.update(
 
   ## A directory in which we can create lockfiles.
-  LOCKDIR = OS.path.join(ENV['HOME'], 'var', 'lock', 'chpwd'))
+  LOCKDIR = OS.path.join(HOME, 'lock'))
+
+###--------------------------------------------------------------------------
+### Utilities.
+
+def fill_in_fields(fno_user, fno_passwd, fno_map, user, passwd, args):
+  """
+  Return a vector of filled-in fields.
+
+  The FNO_... arguments give field numbers: FNO_USER and FNO_PASSWD give the
+  positions for the username and password fields, respectively; and FNO_MAP
+  is a sequence of (NAME, POS) pairs.  The USER and PASSWD arguments give the
+  actual user name and password values; ARGS are the remaining arguments,
+  maybe in the form `NAME=VALUE'.
+  """
+
+  ## Prepare the result vector, and set up some data structures.
+  n = 2 + len(fno_map)
+  fmap = {}
+  rmap = map(int, xrange(n))
+  ok = True
+  if fno_user >= n or fno_passwd >= n: ok = False
+  for k, i in fno_map:
+    fmap[k] = i
+    rmap[i] = "`%s'" % k
+    if i >= n: ok = False
+  if not ok:
+    raise U.ExpectedError, \
+        (500, "Fields specified aren't contiguous")
+
+  ## Prepare the new record's fields.
+  f = [None]*n
+  f[fno_user] = user
+  f[fno_passwd] = passwd
+
+  for a in args:
+    if '=' in a:
+      k, v = a.split('=', 1)
+      try: i = fmap[k]
+      except KeyError: raise U.ExpectedError, (400, "Unknown field `%s'" % k)
+    else:
+      for i in xrange(n):
+        if f[i] is None: break
+      else:
+        raise U.ExpectedError, (500, "All fields already populated")
+      v = a
+    if f[i] is not None:
+      raise U.ExpectedError, (400, "Field %s is already set" % rmap[i])
+    f[i] = v
+
+  ## Check that the vector of fields is properly set up.
+  for i in xrange(n):
+    if f[i] is None:
+      raise U.ExpectedError, (500, "Field %s is unset" % rmap[i])
+
+  ## Done.
+  return f
 
 ###--------------------------------------------------------------------------
 ### Protocol.
 
 ###--------------------------------------------------------------------------
 ### Protocol.
@@ -76,6 +135,8 @@ class BasicRecord (object):
     me._be = backend
   def write(me):
     me._be._update(me)
     me._be = backend
   def write(me):
     me._be._update(me)
+  def remove(me):
+    me._be._remove(me)
 
 class TrivialRecord (BasicRecord):
   """
 
 class TrivialRecord (BasicRecord):
   """
@@ -147,10 +208,13 @@ class FlatFileBackend (object):
   specified by the DELIM constructor argument.
 
   The file is updated by writing a new version alongside, as `FILE.new', and
   specified by the DELIM constructor argument.
 
   The file is updated by writing a new version alongside, as `FILE.new', and
-  renaming it over the old version.  If a LOCK file is named then an
-  exclusive fcntl(2)-style lock is taken out on `LOCKDIR/LOCK' (creating the
-  file if necessary) during the update operation.  Use of a lockfile is
-  strongly recommended.
+  renaming it over the old version.  If a LOCK is provided then this is done
+  while holding a lock.  By default, an exclusive fcntl(2)-style lock is
+  taken out on `LOCKDIR/LOCK' (creating the file if necessary) during the
+  update operation, but subclasses can override the `dolocked' method to
+  provide alternative locking behaviour; the LOCK parameter is not
+  interpreted by any other methods.  Use of a lockfile is strongly
+  recommended.
 
   The DELIM constructor argument specifies the delimiter character used when
   splitting lines into fields.  The USER and PASSWD arguments give the field
 
   The DELIM constructor argument specifies the delimiter character used when
   splitting lines into fields.  The USER and PASSWD arguments give the field
@@ -181,8 +245,38 @@ class FlatFileBackend (object):
           return rec
     raise UnknownUser, user
 
           return rec
     raise UnknownUser, user
 
-  def _update(me, rec):
-    """Update the record REC in the file."""
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right.
+    """
+
+    f = fill_in_fields(me._fmap['user'], me._fmap['passwd'],
+                       [(k[2:], i)
+                        for k, i in me._fmap.iteritems()
+                        if k.startswith('f_')],
+                       user, passwd, args)
+    r = FlatFileRecord(me._delim.join(f), me._delim, me._fmap, backend = me)
+    me._rewrite('create', r)
+
+  def _rewrite(me, op, rec):
+    """
+    Rewrite the file, according to OP.
+
+    The OP may be one of the following.
+
+    `create'            There must not be a record matching REC; add a new
+                        one.
+
+    `remove'            There must be a record matching REC: remove it.
+
+    `update'            There must be a record matching REC: write REC in its
+                        place.
+    """
 
     ## The main update function.
     def doit():
 
     ## The main update function.
     def doit():
@@ -200,14 +294,26 @@ class FlatFileBackend (object):
 
         ## Copy the old file to the new one, changing the user's record if
         ## and when we encounter it.
 
         ## Copy the old file to the new one, changing the user's record if
         ## and when we encounter it.
+        found = False
         with OS.fdopen(fd, 'w') as f_out:
           with open(me._file) as f_in:
             for line in f_in:
               r = me._parse(line)
               if r.user != rec.user:
                 f_out.write(line)
         with OS.fdopen(fd, 'w') as f_out:
           with open(me._file) as f_in:
             for line in f_in:
               r = me._parse(line)
               if r.user != rec.user:
                 f_out.write(line)
+              elif op == 'create':
+                raise U.ExpectedError, \
+                    (500, "Record for `%s' already exists" % rec.user)
               else:
               else:
-                f_out.write(rec._format())
+                found = True
+                if op != 'remove': f_out.write(rec._format())
+          if found:
+            pass
+          elif op == 'create':
+            f_out.write(rec._format())
+          else:
+            raise U.ExpectedError, \
+                (500, "Record for `%s' not found" % rec.user)
 
         ## Update the permissions on the new file.  Don't try to fix the
         ## ownership (we shouldn't be running as root) or the group (the
 
         ## Update the permissions on the new file.  Don't try to fix the
         ## ownership (we shouldn't be running as root) or the group (the
@@ -228,16 +334,34 @@ class FlatFileBackend (object):
 
     ## If there's a lockfile, then acquire it around the meat of this
     ## function; otherwise just do the job.
 
     ## If there's a lockfile, then acquire it around the meat of this
     ## function; otherwise just do the job.
-    if me._lock is None:
-      doit()
-    else:
-      with U.lockfile(OS.path.join(CFG.LOCKDIR, me._lock), 5):
-        doit()
+    if me._lock is None: doit()
+    else: me.dolocked(me._lock, doit)
+
+  def dolocked(me, lock, func):
+    """
+    Call FUNC with the LOCK held.
+
+    Subclasses can override this method in order to provide alternative
+    locking functionality.
+    """
+    try: OS.mkdir(CFG.LOCKDIR)
+    except OSError, e:
+      if e.errno != E.EEXIST: raise
+    with U.lockfile(OS.path.join(CFG.LOCKDIR, lock), 5):
+      func()
 
   def _parse(me, line):
     """Convenience function for constructing a record."""
     return FlatFileRecord(line, me._delim, me._fmap, backend = me)
 
 
   def _parse(me, line):
     """Convenience function for constructing a record."""
     return FlatFileRecord(line, me._delim, me._fmap, backend = me)
 
+  def _update(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('update', rec)
+
+  def _remove(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('remove', rec)
+
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
@@ -294,6 +418,36 @@ class DatabaseBackend (object):
       setattr(rec, 'f_' + f, v)
     return rec
 
       setattr(rec, 'f_' + f, v)
     return rec
 
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the named USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right, in the order given to the constructor.
+    """
+
+    tags = ['user', 'passwd'] + \
+        ['t_%d' % 0 for i in xrange(len(me._fields))]
+    f = fill_in_fields(0, 1, list(I.izip(me._fields, I.count(2))),
+                       user, passwd, args)
+    me._connect()
+    with me._db:
+      me._db.execute("INSERT INTO %s (%s) VALUES (%s)" %
+                     (me._table,
+                      ', '.join([me._user, me._passwd] + me._fields),
+                      ', '.join(['$%s' % t for t in tags])),
+                     **dict(I.izip(tags, f)))
+
+  def _remove(me, rec):
+    """Remove the record REC from the database."""
+    me._connect()
+    with me._db:
+      me._db.execute("DELETE FROM %s WHERE %s = $user" %
+                     (me._table, me._user),
+                     user = rec.user)
+
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()