| # Copyright 2017 The Abseil Authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Decorator and context manager for saving and restoring flag values. |
| |
| There are many ways to save and restore. Always use the most convenient method |
| for a given use case. |
| |
| Here are examples of each method. They all call do_stuff() while FLAGS.someflag |
| is temporarily set to 'foo'. |
| |
| # Use a decorator which can optionally override flags via arguments. |
| @flagsaver.flagsaver(someflag='foo') |
| def some_func(): |
| do_stuff() |
| |
| # Use a decorator which does not override flags itself. |
| @flagsaver.flagsaver |
| def some_func(): |
| FLAGS.someflag = 'foo' |
| do_stuff() |
| |
| # Use a context manager which can optionally override flags via arguments. |
| with flagsaver.flagsaver(someflag='foo'): |
| do_stuff() |
| |
| # Save and restore the flag values yourself. |
| saved_flag_values = flagsaver.save_flag_values() |
| try: |
| FLAGS.someflag = 'foo' |
| do_stuff() |
| finally: |
| flagsaver.restore_flag_values(saved_flag_values) |
| |
| We save and restore a shallow copy of each Flag object's __dict__ attribute. |
| This preserves all attributes of the flag, such as whether or not it was |
| overridden from its default value. |
| |
| WARNING: Currently a flag that is saved and then deleted cannot be restored. An |
| exception will be raised. However if you *add* a flag after saving flag values, |
| and then restore flag values, the added flag will be deleted with no errors. |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| import inspect |
| |
| from absl import flags |
| import six |
| |
# Alias for absl's global FlagValues registry. It is the default `flag_values`
# for save_flag_values()/restore_flag_values() and the registry that
# _FlagOverrider saves, overrides, and restores.
FLAGS = flags.FLAGS
| |
| |
def flagsaver(*args, **kwargs):
  """The main flagsaver interface. See module doc for usage.

  Two calling forms are accepted:

    * ``flagsaver(**overrides)`` returns an object usable both as a decorator
      and as a context manager. ``overrides`` maps flag names to the values
      they hold while it is active.
    * ``flagsaver(func)`` (the bare ``@flagsaver`` decorator form) wraps
      ``func`` so flag values are saved before each call and restored after.

  Raises:
    ValueError: if more than one positional argument is given, or if a
      positional argument is combined with keyword arguments.
    TypeError: if the single positional argument is a class.
  """
  if len(args) > 1:
    raise ValueError(
        "It's invalid to specify more than one positional parameters.")
  if not args:
    # Keyword-only form (possibly with no overrides at all).
    return _FlagOverrider(**kwargs)
  if kwargs:
    raise ValueError(
        "It's invalid to specify both positional and keyword parameters.")
  (target,) = args
  if inspect.isclass(target):
    raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
  return _wrap(target, {})
| |
| |
def save_flag_values(flag_values=FLAGS):
  """Takes a snapshot of every flag registered in ``flag_values``.

  Args:
    flag_values: FlagValues, the registry whose flags are snapshotted. This
      should almost never need to be overridden.

  Returns:
    A dict mapping each registered flag name to a copy of that flag's
    ``__dict__`` (built by ``_copy_flag_dict``), e.g.
    {'key': value_dict, ...}.
  """
  snapshot = {}
  for flag_name in flag_values:
    snapshot[flag_name] = _copy_flag_dict(flag_values[flag_name])
  return snapshot
| |
| |
def restore_flag_values(saved_flag_values, flag_values=FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Flags currently registered in ``flag_values`` but absent from
  ``saved_flag_values`` are deleted. Names present in ``saved_flag_values``
  but no longer registered in ``flag_values`` are skipped.

  Each restored flag gets its own copy of the saved dict rather than the saved
  dict itself. Later mutations of the live flag therefore never write back
  into ``saved_flag_values``, so the same snapshot can be restored repeatedly.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}, as returned by
      save_flag_values().
    flag_values: FlagValues, the FlagValues instance from which the flag will
      be restored. This should almost never need to be overridden.
  """
  # Materialize the names first: the loop body may delete flags.
  for name in list(flag_values):
    saved = saved_flag_values.get(name)
    if saved is None:
      # The flag was added after the snapshot was taken; remove it.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    # Install a copy, not the snapshot itself, so the snapshot is not aliased
    # by the live flag. The validator list is copied too (as _copy_flag_dict
    # does) so validators added to the flag later do not leak into it.
    restored = dict(saved)
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored
| |
| |
| def _wrap(func, overrides): |
| """Creates a wrapper function that saves/restores flag values. |
| |
| Args: |
| func: function object - This will be called between saving flags and |
| restoring flags. |
| overrides: {str: object} - Flag names mapped to their values. These flags |
| will be set after saving the original flag state. |
| |
| Returns: |
| return value from func() |
| """ |
| @functools.wraps(func) |
| def _flagsaver_wrapper(*args, **kwargs): |
| """Wrapper function that saves and restores flags.""" |
| with _FlagOverrider(**overrides): |
| return func(*args, **kwargs) |
| return _flagsaver_wrapper |
| |
| |
| class _FlagOverrider(object): |
| """Overrides flags for the duration of the decorated function call. |
| |
| It also restores all original values of flags after decorated method |
| completes. |
| """ |
| |
| def __init__(self, **overrides): |
| self._overrides = overrides |
| self._saved_flag_values = None |
| |
| def __call__(self, func): |
| if inspect.isclass(func): |
| raise TypeError('flagsaver cannot be applied to a class.') |
| return _wrap(func, self._overrides) |
| |
| def __enter__(self): |
| self._saved_flag_values = save_flag_values(FLAGS) |
| try: |
| for name, value in six.iteritems(self._overrides): |
| setattr(FLAGS, name, value) |
| except: |
| # It may fail because of flag validators. |
| restore_flag_values(self._saved_flag_values, FLAGS) |
| raise |
| |
| def __exit__(self, exc_type, exc_value, traceback): |
| restore_flag_values(self._saved_flag_values, FLAGS) |
| |
| |
| def _copy_flag_dict(flag): |
| """Returns a copy of the flag object's __dict__. |
| |
| It's mostly a shallow copy of the __dict__, except it also does a shallow |
| copy of the validator list. |
| |
| Args: |
| flag: flags.Flag, the flag to copy. |
| |
| Returns: |
| A copy of the flag object's __dict__. |
| """ |
| copy = flag.__dict__.copy() |
| copy['_value'] = flag.value # Ensure correct restore for C++ flags. |
| copy['validators'] = list(flag.validators) |
| return copy |