X-Git-Url: https://git.distorted.org.uk/~mdw/chopwood/blobdiff_plain/cf7c527a2fbf15ab8f9c1064e9f1cda55b8bb2b6..HEAD:/cgi.py diff --git a/cgi.py b/cgi.py index d66bfc6..6ad9e75 100644 --- a/cgi.py +++ b/cgi.py @@ -51,7 +51,11 @@ CONF.DEFAULTS.update( ## A (maybe relative) URL for static content. By default this comes from ## the main script, but we hope that user agents cache it. - STATIC = _script_name + '/static') + STATIC = None) + +@CONF.hook +def set_static(): + if CFG.STATIC is None: CFG.STATIC = CFG.SCRIPT_NAME + '/static' ###-------------------------------------------------------------------------- ### Escaping and encoding. @@ -105,6 +109,7 @@ class HTTPOutput (O.FileOutput): """Constructor: initialize `headerp' flag.""" super(HTTPOutput, me).__init__(*args, **kw) me.headerp = False + me.warnings = [] def write(me, msg): """Output protocol: print a header if we've not written one already.""" @@ -123,6 +128,17 @@ class HTTPOutput (O.FileOutput): for h in O.http_headers(content_type = content_type, **kw): me.writeln(h) me.writeln('') + if METHOD == 'HEAD': + HEADER_DONE() + + def warn(me, msg): + """ + Report a warning message. + + The warning is stashed in a list where it can be retrieved using + `warnings'. + """ + me.warnings.append(msg) def cookie(name, value, **kw): """ @@ -270,7 +286,8 @@ def page(template, header = {}, title = 'Chopwood', **kw): header = dict(header, content_type = 'text/html') OUT.header(**header) format_tmpl(TMPL['wrapper.fhtml'], - title = title, payload = TMPL[template], **kw) + title = title, warnings = OUT.warnings, + payload = TMPL[template], **kw) ###-------------------------------------------------------------------------- ### Error reporting. @@ -311,12 +328,14 @@ def cgi_errors(hook = None): ### CGI input. ## Lots of global variables to be filled in by `cgiparse'. +METHOD = None COOKIE = {} SPECIAL = {} PARAM = [] PARAMDICT = {} PATH = [] SSLP = False +HEADER_DONE = lambda: None ## Regular expressions for splitting apart query and cookie strings. R_QSPLIT = RX.compile('[;&]') @@ -377,34 +396,35 @@ def cgiparse(): True if the client connection is carried over SSL or TLS. """ - global SSLP + global METHOD, SSLP def getenv(var): try: return ENV[var] except KeyError: raise U.ExpectedError, (500, "No `%s' supplied" % var) ## Yes, we want the request method. - method = getenv('REQUEST_METHOD') + METHOD = getenv('REQUEST_METHOD') ## Acquire the query string. - if method == 'GET': - q = getenv('QUERY_STRING') + if METHOD in ['GET', 'HEAD']: + q = ENV.get('QUERY_STRING', '') - elif method == 'POST': + elif METHOD == 'POST': ## We must read the query string from stdin. n = getenv('CONTENT_LENGTH') if not n.isdigit(): raise U.ExpectedError, (500, "Invalid CONTENT_LENGTH") n = int(n, 10) - if getenv('CONTENT_TYPE') != 'application/x-www-form-urlencoded': + ct = getenv('CONTENT_TYPE') + if ct != 'application/x-www-form-urlencoded': raise U.ExpectedError, (500, "Unexpected content type `%s'" % ct) q = SYS.stdin.read(n) if len(q) != n: raise U.ExpectedError, (500, "Failed to read correct length") else: - raise U.ExpectedError, (500, "Unexpected request method `%s'" % method) + raise U.ExpectedError, (500, "Unexpected request method `%s'" % METHOD) ## Populate the `SPECIAL', `PARAM' and `PARAMDICT' tables. seen = set() @@ -414,7 +434,8 @@ def cgiparse(): else: PARAM.append((k, v)) if k in seen: - del PARAMDICT[k] + try: del PARAMDICT[k] + except KeyError: pass else: PARAMDICT[k] = v seen.add(k) @@ -448,6 +469,11 @@ class Subcommand (SC.Subcommand): CGI parameters. """ + def __init__(me, name, contexts, desc, func, + methods = ['GET', 'POST'], *args, **kw): + super(Subcommand, me).__init__(name, contexts, desc, func, *args, **kw) + me.methods = set(methods) + def cgi(me, param, path): """ Invoke the subcommand given a collection of CGI parameters. @@ -464,6 +490,8 @@ class Subcommand (SC.Subcommand): the list of path elements is non-empty. """ + global HEADER_DONE + ## We're going to make a pass over the supplied parameters, and we'll ## check them off against the formal parameters as we go; so we'll need ## to be able to look them up. We'll also keep track of the ones we've @@ -477,6 +505,12 @@ class Subcommand (SC.Subcommand): want = {} kw = {} + ## Check the request method against the permitted list. + meth = METHOD + if meth == 'HEAD': meth = 'GET' + if meth not in me.methods: + raise U.ExpectedError, (500, "Unexpected request method `%s'" % METHOD) + def set_value(k, v): """Set a simple value: we shouldn't see multiple values.""" if k in kw: