# labtools, Copyright (C) 2017 Jerry Fowler and Paul Scheet.
# This program comes with ABSOLUTELY NO WARRANTY. It is licensed under
# GNU GPL Version 3. License and warranty may be viewed in the manual.
'''
A subclass of argparse.ArgumentParser that provides a collection of useful semantic
validations, such as checking that a file named in an argument exists.
The goal is to provide as much information about command-line errors as possible
before terminating, so as to reduce the user's repeated invocation of a command to
get the syntax right.
Typical usage would be:
>>> from labtools.ArgumentParserValidator import ArgumentParserValidator
>>> parser = ArgumentParserValidator()
>>> parser.add_argument('--foo', dest='foo')
>>> parser.add_argument('--numbers', dest='numbers')
>>> ...
>>> args = parser.parse_args('--foo /dev/null --numbers 1,2,3'.split())
>>> if parser.help_desired(program_version, sys.stderr):
...     sys.exit(0)
>>> args.is_valid_file('foo', test_writable=True)
>>> args.is_valid_range('numbers', number=3, bottom=0, top=5, exclusive=False)
>>> 0 == len(args.validation_errors())
'''
import os
import sys
import argparse
import logging
from labtools import const
from labtools import misc
from labtools import reflection
# Option names (also used as the argparse dest attributes) of the built-in
# --version, --help and --license options.
VERSION, HELP, LICENSE = 'version', 'help', 'license'
class ArgumentParserValidator(argparse.ArgumentParser):
    '''Add a little semantic checking on top of the parser.

    error() is overridden so that parse problems are recorded in an error list
    instead of terminating the program.  The object returned by parse_args()
    carries that list plus validation routines that append to it.  All the
    routines append to the error list and also return the outcome of the given
    test, so that conditional validations can be performed with reduced
    dependency on the order of the tests.
    '''

    def __init__(self, *args, **kwargs):
        '''
        Create the parser.  Unless the caller passes them explicitly, add_help
        defaults to False and formatter_class to argparse.RawTextHelpFormatter.
        The --license, --version and --help options are always registered
        (--help uses action='count', so it may be repeated).
        '''
        kwargs.setdefault('add_help', False)
        kwargs.setdefault('formatter_class', argparse.RawTextHelpFormatter)
        # NOTE(review): add_help=True would make argparse register its own
        # --help as well, which likely conflicts with the one added below.
        super().__init__(*args, **kwargs)
        # Namespace filled by parse_args(); it also accumulates error messages.
        self.arguments = _Arguments()
        self.add_argument(const.DASHES + LICENSE, dest=LICENSE, default=False,
                          action='store_true',
                          help='Display license and quit')
        self.add_argument(const.DASHES + VERSION, dest=VERSION, default=False,
                          action='store_true',
                          help='Display version and quit')
        self.add_argument(const.DASHES + HELP, dest=HELP, default=0, action='count',
                          help='Display usage and quit (--help --help for more comprehensive help)')

    def parse_args(self, input=None, stderr=sys.stderr, namespace=None):
        '''
        Run the standard argparse.ArgumentParser and return an object containing the
        parsed arguments plus a collection of methods to perform further semantic
        validation of the arguments.
        Some of the semantic validations are duplicates of existing ones, but (I claim)
        this encapsulates them in an easy-to-use way.

        *input* is the argument list; None means argparse reads sys.argv[1:].
        If the first element's basename equals self.prog it is dropped, so a full
        argv may be passed.  Leftover (unrecognized) arguments are recorded as an
        error rather than raising.

        Hack around problem in argparse:
        argparse.ArgumentParser explicitly looks for sys.argv[0] if args is none
        and then parse_args() looks for sys.argv[1:]
        What should really happen is a patch suggested to argparse that curries
        sys out of the code.
        '''
        if input and os.path.basename(input[0]) == self.prog:
            input = input[1:]
        try:
            result = self.parse_known_args(input, namespace)
        except AttributeError:
            # error() returns instead of exiting, so argparse may carry on in an
            # inconsistent state; whatever went wrong is already in the error list.
            result = None
        if result is None:
            # Give help_desired() the attributes it reads even though parsing failed.
            emergency = reflection.Namespace()
            emergency.help = None
            emergency.license = None
            emergency.version = None
            self.arguments.add_dict(emergency)
            if not self.arguments.validation_errors():
                # argparse parsed sys.argv[1:] when input was None; joining None
                # itself would raise TypeError.
                attempted = sys.argv[1:] if input is None else input
                self.arguments.add_error(
                    'Parsing %s failed, probably a parameter requires an argument' %
                    (const.SPACE.join(attempted)))
        else:
            # NOTE(review): newer argparse versions may return the untouched
            # argument list here after an error, which is then reported as
            # "Unexpected ..." below -- confirm.
            args, argv = result
            if argv:
                self.error('Unexpected %s %s' % (misc.plural(len(argv), 'argument'),
                                                 const.SPACE.join(argv)), stderr)
            self.arguments.add_dict(args)
        return self.arguments

    def help_desired(self, version, stream=sys.stderr):
        '''
        Store *version* as arguments.versionstring.  If --license, --version or
        --help was given, print the version string to *stream*, followed by the
        license text (--license) and/or the usage (--help; repeated --help adds a
        further note).

        Returns True if any of those options was requested, otherwise False.
        '''
        args = self.arguments
        args.versionstring = version
        if not (args.license or args.version or args.help):
            return False
        print(args.versionstring, file=stream)
        if args.license:
            print(const.EMPTY, file=stream)
            # NOTE(review): placeholder text; the file header says GNU GPL v3.
            print('Licensed to kill', file=stream)
        if args.help:
            print(const.EMPTY, file=stream)
            self.print_help(file=stream)
            if args.help > 1:
                print(const.EMPTY, file=stream)
                stream.write('No special help yet\n')
        return True

    def error(self, message, stderr=sys.stderr):
        """
        Add *message* to the error list instead of exiting.

        Overrides argparse.ArgumentParser.error(), so parse errors are collected
        rather than terminating the program.  *stderr* is accepted for call
        compatibility but is not used.
        """
        self.arguments.add_error(message)

    @staticmethod
    def valid_range(name, value, number=1, exclusive=True, bottom=0.0, top=1.0, is_float=True):
        """
        Check that *value* holds exactly *number* numbers, each between *bottom*
        and *top*.

        *value* may be a comma-separated string, a list or tuple, or a single
        number.  Each element is converted with float() (int() when *is_float* is
        False) and compared with the bounds: if *exclusive* (the default) both
        bounds are excluded (bottom < p < top), otherwise both are included
        (bottom <= p <= top).  If *bottom* or *top* is None, perform a one-sided
        comparison.

        Returns None if the value is acceptable, otherwise an error message naming
        *name*, the offending *value*, and the expected condition.

        Raises UserWarning if *bottom* and *top* are both None (programming error).
        """
        if bottom is None and top is None:
            raise UserWarning('Programmer mistake: bottom and top cannot both be None')
        symbol = '<' if exclusive else '<='

        def below(lower, upper):
            # Ordering test selected by *exclusive*.
            return lower < upper if exclusive else lower <= upper

        def show(bound):
            # Full-precision rendering of a bound; a fixed '%03.1f' format would
            # print (and, when evaluated, test against) 0.05 as 0.1.
            if not is_float and float(bound).is_integer():
                return '%d' % bound
            return repr(float(bound)) if is_float else repr(bound)

        def in_range(item):
            try:
                p = float(item) if is_float else int(item)
            except (TypeError, ValueError, OverflowError):
                return False
            if bottom is not None and not below(bottom, p):
                return False
            return top is None or below(p, top)

        pieces = []
        if bottom is not None:
            pieces.append('%s %s' % (show(bottom), symbol))
        pieces.append('p')
        if top is not None:
            pieces.append('%s %s' % (symbol, show(top)))
        test = ' '.join(pieces)

        # handle strings, sequences (e.g. from nargs) or single floats/ints
        if isinstance(value, str):
            vals = value.split(const.COMMA)
        elif isinstance(value, (list, tuple)):
            vals = list(value)
        else:
            vals = [value]

        if len(vals) == number and all(in_range(item) for item in vals):
            return None
        if number == 1:
            message = 'a number'
        else:
            message = '%d comma-separated numbers' % number
        return "%s was '%s' (must be %s, %s)" % (name, value, message, test)
class _Arguments(argparse.Namespace):
    '''
    A namespace of parsed arguments wrapped with special methods for semantic
    validation.

    Each is_valid_* method appends a message to an internal error list when its
    check fails and also returns the outcome, so several checks can be run and
    all problems reported together via validation_errors() / print_errors().
    '''

    # Bookkeeping attributes and built-in options that arglist() leaves out;
    # 'versionstring' is stored by ArgumentParserValidator.help_desired().
    _NOT_ARGUMENTS = ('_errors', 'license', 'version', 'help', 'versionstring')

    def __init__(self):
        super().__init__()
        self._errors = []

    def add_dict(self, namespace):
        '''Copy every attribute of *namespace* onto this object, overwriting existing ones.'''
        self.__dict__.update(vars(namespace))

    def arglist(self, subset=None):
        '''
        Return a sorted list of 'name=value' strings (value shown with %r) for the
        defined arguments, skipping bookkeeping attributes and the built-in
        --help/--license/--version options.  If *subset* is non-empty, only names
        in it are included.
        '''
        return sorted('%s=%r' % (name, value)
                      for name, value in vars(self).items()
                      if (not subset or name in subset) and name not in self._NOT_ARGUMENTS)

    def is_valid_exclusive(self, *args):
        '''Return True if exactly one of a list of exclusive arguments is provided (truthy).

        Raises ValueError if any name in *args* is not an attribute of the namespace.
        '''
        found = []
        for arg in args:
            if arg not in self.__dict__:
                raise ValueError('%s is not a legal argument of the parser' % (arg,))
            if self.__dict__[arg]:
                found.append(arg)
        if len(found) == 1:
            return True
        if found:
            self._errors.append('May choose only one of %s.' % (list(args),))
        else:
            self._errors.append('Must choose one of %s.' % (list(args),))
        return False

    def is_valid_dir(self, parameter, test_writable=False):
        '''Return whether *parameter* is set (not missing or None) and names an
        existing, readable directory.  If *test_writable* is True, the directory
        must also be writable.  The first failed check is recorded as an error.
        '''
        path = self.__dict__.get(parameter)
        if path is None:
            self._errors.append('''Directory parameter '%s' must be specified.''' % (parameter))
            return False
        if not os.path.isdir(path):
            self._errors.append('%s is not a directory.' % (path))
            return False
        if test_writable and not os.access(path, os.W_OK):
            self._errors.append('%s is not writable.' % (path))
            return False
        if not os.access(path, os.R_OK):
            self._errors.append('%s is not readable.' % (path))
            return False
        return True

    def is_valid_file(self, parameter, test_writable=False, accept_directory=False):
        '''Return whether *parameter* is set (not missing or None) and the path it
        names passes is_valid_filepath().
        test_writable=True implies accept a non-existent file in a writable directory.
        accept_directory=True implies check to see if it's a valid directory.
        (Both options are passed through to misc.is_valid_filepath.)
        '''
        path = self.__dict__.get(parameter)
        if path is None:
            self._errors.append('''File parameter '%s' must be specified.''' % (parameter))
            return False
        return self.is_valid_filepath(path, test_writable, accept_directory)

    def is_valid_filepath(self, file, test_writable=False, accept_directory=False):
        '''Validate a single path with misc.is_valid_filepath, which returns an
        (ok, error message) pair; record the message if there is one and return ok.
        '''
        ok, error = misc.is_valid_filepath(file, test_writable, accept_directory)
        if error:
            self._errors.append(error)
        return ok

    def is_valid_filelist(self, parameter, required_length=0):
        '''Return whether *parameter* is set and every path in its list passes
        is_valid_filepath(), and the list has at least *required_length* entries.

        Every entry is checked (no short-circuit) so that all bad paths are
        reported.  A single string value is treated as a one-element list.
        '''
        files = self.__dict__.get(parameter)
        if files is None:
            self._errors.append('file list parameter %s not specified' % (parameter))
            return False
        # Iterating a plain string would validate individual characters.
        files = [files] if isinstance(files, str) else list(files)
        ok = True
        for path in files:
            if not self.is_valid_filepath(path):
                ok = False
        if len(files) < required_length:
            self._errors.append('At least %d files are required.' % (required_length))
            ok = False
        return ok

    def is_valid_range(self, parameter, number=1, exclusive=True, bottom=0.0, top=1.0, is_float=True):
        '''Return True if *parameter* is in the given range from *bottom* to *top*, else False.
        See ArgumentParserValidator.valid_range() for details.  A missing parameter
        is treated like a None value (reported as an error, not a KeyError).
        Append an error message to *validation_errors()* if a violation is detected.
        '''
        value = self.__dict__.get(parameter)
        error = ArgumentParserValidator.valid_range(parameter, value, number, exclusive,
                                                    bottom, top, is_float)
        if error:
            self._errors.append(error)
            return False
        return True

    def add_error(self, error):
        '''Accrue the error to the error list. If *error* is a list, extend the existing
        list rather than appending the list as an object; other non-strings are
        converted with str().
        '''
        if isinstance(error, str):
            self._errors.append(error)
        elif isinstance(error, list):
            self._errors.extend(error)
        else:
            self._errors.append(str(error))

    def validation_errors(self):
        '''Return the list of accumulated validation errors (often used as a boolean test for errors)
        '''
        return self._errors

    def print_errors(self, stream=sys.stderr):
        '''Print the accumulated errors, one per line, to *stream* (sys.stderr by
        default), then flush it.
        '''
        for error in self.validation_errors():
            print(error, file=stream)
        stream.flush()