# Source code for ArgumentParserValidator

# labtools, Copyright (C) 2017 Jerry Fowler and Paul Scheet.
# This program comes with ABSOLUTELY NO WARRANTY. It is licensed under
# GNU GPL Version 3. License and warranty may be viewed in the manual.
'''
A subclass of argparse.ArgumentParser that provides a collection of useful semantic
validations, such as checking that a file named in an argument exists.

The goal is to provide as much information about command-line errors as possible
before terminating, so as to reduce the user's repeated invocation of a command to
get the syntax right.

Typical usage would be:

>>> from labtools.ArgumentParserValidator import ArgumentParserValidator
>>> parser = ArgumentParserValidator()
>>> parser.add_argument('--foo', dest='foo')
>>> parser.add_argument('--numbers', dest='numbers')
>>> ...
>>> args = parser.parse_args('--foo /dev/null --numbers 1,2,3'.split())
>>> if parser.help_desired(program_version, sys.stderr):
>>>    sys.exit(0)
>>> args.is_valid_file('foo', test_writable=True)
>>> args.is_valid_range('numbers', number=3, bottom=0, top=5, exclusive=False)
>>> 0 == len(args.validation_errors())
'''

import os
import sys

import argparse
import logging
from labtools import const
from labtools import misc
from labtools import reflection

# Names of the built-in informational options; each is used both as the
# long option name (after const.DASHES) and as its dest attribute.
VERSION, HELP, LICENSE = 'version', 'help', 'license'

[docs]class ArgumentParserValidator(argparse.ArgumentParser): '''Add a little semantic checking on top of the parser. All the routines append to an error list and also return the outcome of the given test, so that conditional validations can be performed with reduced dependency on order of the tests. ''' def __init__(self, *args, **kwargs): if 'add_help' not in kwargs: kwargs['add_help'] = False if 'formatter_class' not in kwargs: kwargs['formatter_class']=argparse.RawTextHelpFormatter argparse.ArgumentParser.__init__(self, *args, **kwargs) self.arguments = _Arguments() self.add_argument(const.DASHES+LICENSE, dest=LICENSE, default=False, action='store_true', help='''Display license and quit''') self.add_argument(const.DASHES+VERSION, dest=VERSION, default=False, action='store_true', help='''Display version and quit''') self.add_argument(const.DASHES+HELP, dest=HELP, default=0, action='count', help='''Display usage and quit (--help --help for more comprehensive help)''')
[docs] def parse_args(self, input=None, stderr=sys.stderr, namespace=None): ''' Run the standard argparse.ArgumentParser and return an object containing the parsed arguments plus a collection of methods to perform further semantic validation of the arguments. Some of the semantic validations are duplicates of existing ones, but (I claim) this encapsulates them in an easy-to-use way. Hack around problem in argparse: argparse.ArgumentParser explicitly looks for sys.argv[0] if args is none and then parse_args() looks for sys.argv[1:] What should really happen is a patch suggested to argparse that curries sys out of the code. ''' if input and os.path.basename(input[0]) == self.prog: input = input[1:] try: result = self.parse_known_args(input, namespace) except AttributeError as ae: result = None # Whatever happens here is already caught in _errors. if result is None: emergency = reflection.Namespace() emergency.help = None emergency.license = None emergency.version = None self.arguments.add_dict(emergency) if not self.arguments.validation_errors(): self.arguments.add_error('Parsing %s failed, probably a parameter requires an argument' % (const.SPACE.join(input))) pass # Superclass already documented error else: args, argv = result if argv: self.error('Unexpected %s %s' % (misc.plural(len(argv), 'argument'), const.SPACE.join(argv)), stderr) self.arguments.add_dict(args) return self.arguments
[docs] def help_desired(self, version, stream=sys.stderr): ''' Standard test to see if license, version, or help is requested, print messages to the stream if so, and return whether or not requested. Sets versionstring ''' self.arguments.versionstring = version if self.arguments.license or self.arguments.version or self.arguments.help: print(self.arguments.versionstring, file=stream) if self.arguments.license: print(const.EMPTY, file=stream) print('Licensed to kill', file=stream) if self.arguments.help: print(const.EMPTY, file=stream) self.print_help(file=stream) if self.arguments.help > 1: print(const.EMPTY, file=stream) stream.write('No special help yet\n') return True else: return False
[docs] def error(self, message, stderr=sys.stderr): """ Adds the *message* to the error list """ self.arguments.add_error(message)
@staticmethod
[docs] def valid_range(name, value, number=1, exclusive=True, bottom=0.0, top=1.0, is_float=True): """ Return whether value is within the given closed interval [*bottom*, *top*] If *bottom* or *top* is None, perform a one-sided comparison. If *exclusive*, test the open interval [*bottom*, *top*) (exclude upper bound). value to be equal to the *top* bound. The test is implemented as a call to eval(). *is_float* allows the correct formatting of floating point numbers. Also raise an error if *value* does not contain *number* comma-separated values (1 by default). """ comparator = '<' if exclusive else '<=' template = '' argtuple = tuple() if bottom is None and top is None: raise UserWarning('Programmer mistake: bottom and top cannot both be None') if bottom is not None: template += '{0} %s ' argtuple += bottom, comparator template += 'p' if top is not None: template += ' %s {0}' argtuple += comparator, top template = template.format('%03.1f' if is_float else '%d') test = template % argtuple # handle strings or float/ints if str == type(value): vals = value.split(const.COMMA) else: vals = [value] try: if len(vals) != number: raise UserWarning('range error triggers errors.extend below') for val in vals: p = float(val) if is_float else int(val) if not eval(test): raise UserWarning('range error triggers errors.extend below') except: if number == 1: message = 'a number' else: message = '%d comma-separated numbers' % number return("%s was '%s' (must be %s, %s)" % (name, value, message, test)) return None
class _Arguments(argparse.Namespace):
    '''
    Return a namespace wrapped with a collection of special methods for
    semantic validation.

    The is_valid_* methods append a message to an internal error list on
    failure and return a bool, so several checks can run and every problem
    can be reported together (see validation_errors() and print_errors()).
    '''

    def __init__(self):
        super().__init__()
        # Accumulated validation/parsing error messages.
        self._errors = []

    def add_dict(self, namespace):
        '''Copy every attribute of *namespace* onto this object, replacing
        any existing attribute of the same name.'''
        self.__dict__.update(vars(namespace))

    def arglist(self, subset=None):
        '''
        Return a sorted list of 'name=repr(value)' strings for the defined
        arguments, restricted to the names in *subset* when it is non-empty.

        Bookkeeping attributes are excluded: the error list, the
        license/version/help flags, and the versionstring stored by
        ArgumentParserValidator.help_desired().
        '''
        ignore = ('_errors', 'license', 'version', 'help', 'versionstring')
        return sorted('%s=%r' % (name, value)
                      for name, value in vars(self).items()
                      if (not subset or name in subset) and name not in ignore)

    def is_valid_exclusive(self, *args):
        '''Return True if exactly one of the named, mutually exclusive
        arguments has a truthy value; otherwise record an error and return
        False.

        Raises ValueError if any name is not an attribute of this namespace.
        '''
        for arg in args:
            if arg not in self.__dict__:
                raise ValueError('%s is not a legal argument of the parser' % (arg))
        found = [arg for arg in args if self.__dict__[arg]]
        if len(found) == 1:
            return True
        # List the names plainly (formatting [args] printed "[('a', 'b')]").
        names = ', '.join(args)
        if found:
            self._errors.append('May choose only one of %s.' % (names))
        else:
            self._errors.append('Must choose one of %s.' % (names))
        return False

    def is_valid_dir(self, parameter, test_writable=False):
        '''Return whether attribute *parameter* is set (not missing or None)
        and names an existing, readable directory. If *test_writable* is
        True, the directory must also be writable. Records the first failed
        check as an error.
        '''
        path = self.__dict__.get(parameter)
        if path is None:
            self._errors.append("Directory parameter '%s' must be specified." % (parameter))
            return False
        if not os.path.isdir(path):
            self._errors.append('%s is not a directory.' % (path))
            return False
        if test_writable and not os.access(path, os.W_OK):
            self._errors.append('%s is not writable.' % (path))
            return False
        if not os.access(path, os.R_OK):
            self._errors.append('%s is not readable.' % (path))
            return False
        return True

    def is_valid_file(self, parameter, test_writable=False, accept_directory=False):
        '''Return whether attribute *parameter* is set (not missing or None)
        and its value passes is_valid_filepath().

        test_writable=True implies accept a non-existent file in a writable
        directory. accept_directory=True implies check to see if it's a valid
        directory.
        '''
        path = self.__dict__.get(parameter)
        if path is None:
            self._errors.append("File parameter '%s' must be specified." % (parameter))
            return False
        return self.is_valid_filepath(path, test_writable, accept_directory)

    def is_valid_filepath(self, file, test_writable=False, accept_directory=False):
        '''Check pathname *file* with misc.is_valid_filepath(); record the
        error message it returns, if any, and return its ok flag.'''
        ok, error = misc.is_valid_filepath(file, test_writable, accept_directory)
        if error:
            self._errors.append(error)
        return ok

    def is_valid_filelist(self, parameter, required_length=0):
        '''Return whether attribute *parameter* is set (not missing or None),
        every pathname in it passes is_valid_filepath(), and it holds at
        least *required_length* entries. Every entry is checked so that all
        bad paths are reported, not just the first.
        '''
        files = self.__dict__.get(parameter)
        if files is None:
            # Previously the error was recorded and then a KeyError raised.
            self._errors.append('file list parameter %s not specified' % (parameter))
            return False
        if isinstance(files, str):
            # A lone pathname would otherwise be iterated character by character.
            files = [files]
        results = [self.is_valid_filepath(path) for path in files]
        ok = all(results)
        if len(results) < required_length:
            self._errors.append('At least %d files are required.' % (required_length))
            ok = False
        return ok

    def is_valid_range(self, parameter, number=1, exclusive=True, bottom=0.0, top=1.0, is_float=True):
        '''Return True if *parameter* is in the given range from *bottom* to
        *top*, else False. See ArgumentParserValidator.valid_range() for
        details. Append an error message to *validation_errors()* if a
        violation is detected; a missing parameter is checked as None (and
        so reported) instead of raising KeyError.
        '''
        value = self.__dict__.get(parameter)
        error = ArgumentParserValidator.valid_range(parameter, value, number,
                                                    exclusive, bottom, top, is_float)
        if error:
            self._errors.append(error)
            return False
        return True

    def add_error(self, error):
        '''Accrue the error to the error list. If *error* is a list or tuple,
        extend the existing list with its items rather than appending it as
        one object; any other non-string is stored as str(error).
        '''
        if isinstance(error, str):
            error = [error]
        elif isinstance(error, (list, tuple)):
            error = list(error)
        else:
            error = [str(error)]
        self._errors.extend(error)

    def validation_errors(self):
        '''Return the list of accumulated validation errors (often used as a
        boolean test for errors). This is the live list, not a copy.
        '''
        return self._errors

    def print_errors(self, stream=sys.stderr):
        '''Print the accumulated errors, one per line, to *stream*
        (sys.stderr by default), then flush it.
        '''
        for error in self.validation_errors():
            print(error, file=stream)
        stream.flush()