beancount.utils

Generic utility packages and functions.

beancount.utils.bisect_key

A version of bisect that accepts a custom key function, like the sorting ones do.

beancount.utils.bisect_key.bisect_left_with_key(sequence, value, key=None)

Find the leftmost insertion index for the given value in a sorted sequence, i.e. the index of the first element whose key is not less than the value.

Parameters:
  • sequence (Sequence[T]) – A sorted sequence of elements.

  • value (U) – The value to search for.

  • key (Callable[[T], U] | None) – An optional function used to extract the value from the elements of sequence.

Returns:
  • int – The insertion index, between 0 and len(sequence) inclusive. The function always returns an integer, never None.

Source code in beancount/utils/bisect_key.py
def bisect_left_with_key(
    sequence: Sequence[T], value: U, key: Callable[[T], U] | None = None
) -> int:
    """Find the last element before the given value in a sorted list.

    Args:
      sequence: A sorted sequence of elements.
      value: The value to search for.
      key: An optional function used to extract the value from the elements of
        sequence.
    Returns:
      Return the index. May return None.
    """

    # see overloads, if no key function is given, T = U
    keyfunc: Callable[[T], U] = key if key is not None else lambda x: x  # type: ignore[assignment,return-value]

    lo = 0
    hi = len(sequence)

    while lo < hi:
        mid = (lo + hi) // 2
        # Python does not yet have a built-in way to add some "Comparable" bound to U
        if keyfunc(sequence[mid]) < value:  # type: ignore[operator]
            lo = mid + 1
        else:
            hi = mid
    return lo

beancount.utils.bisect_key.bisect_right_with_key(a, x, key, lo=0, hi=None)

Like bisect.bisect_right, but with a key lookup parameter.

Parameters:
  • a – The list to search in.

  • x – The element to search for.

  • key – A function, to extract the value from the list.

  • lo – The smallest index to search.

  • hi – The index one past the last element to search (exclusive upper bound); defaults to len(a).

Returns:
  • As in bisect.bisect_right, an insertion index into list 'a' (not an element of it).

Source code in beancount/utils/bisect_key.py
def bisect_right_with_key(a, x, key, lo=0, hi=None):
    """Like bisect.bisect_right, but comparing through a key function.

    Each probed element is passed through 'key' before being compared with 'x'.

    Args:
      a: The sorted list to search in.
      x: The value to search for; it is compared against the results of 'key'.
      key: A function that extracts the comparison value from an element of 'a'.
      lo: The first index of the range to search.
      hi: The index one past the end of the range to search. When None, the
        length of 'a' is used.
    Returns:
      An integer insertion index: 'x' is not less than the key of any element
      in a[lo:index].
    Raises:
      ValueError: If 'lo' is negative.
    """

    if lo < 0:
        raise ValueError("lo must be non-negative")
    lower = lo
    upper = len(a) if hi is None else hi
    while lower < upper:
        middle = (lower + upper) // 2
        if x < key(a[middle]):
            # The insertion point lies at or before 'middle'.
            upper = middle
            continue
        lower = middle + 1
    return lower

beancount.utils.defdict

An instance of collections.defaultdict whose factory accepts a key.

Note: This really ought to be an enhancement to Python itself. I should bother adding this in eventually.

beancount.utils.defdict.ImmutableDictWithDefault (dict)

An immutable dict which returns a default value for missing keys.

This differs from a defaultdict in that it does not insert a missing default value when one is materialized (from a missing fetch), and furthermore, the set method is made unavailable to prevent mutation beyond construction.

beancount.utils.defdict.ImmutableDictWithDefault.__setitem__(self, key, value) special

Disallow mutating the dict in the usual way.

Source code in beancount/utils/defdict.py
def __setitem__(self, key, value):
    """Reject item assignment; this mapping is not meant to be mutated.

    Args:
      key: The key that the caller attempted to set (unused).
      value: The value that the caller attempted to store (unused).
    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError()

beancount.utils.defdict.ImmutableDictWithDefault.get(self, key, _=None)

Return the value for key by delegating to __getitem__. The second argument is accepted for compatibility with dict.get() but is ignored.

Source code in beancount/utils/defdict.py
def get(self, key, _=None):
    """Return the value for 'key' by delegating to __getitem__.

    Args:
      key: The key to look up.
      _: Accepted for signature compatibility with dict.get(); it is ignored.
    Returns:
      Whatever self.__getitem__(key) returns.
    """
    lookup = self.__getitem__
    return lookup(key)

beancount.utils.encryption

Support for encrypted tests.

beancount.utils.encryption.is_encrypted_file(filename)

Return true if the given filename contains an encrypted file.

Parameters:
  • filename (str | Path) – A path string.

Returns:
  • bool – A boolean, true if the file contains an encrypted file.

Source code in beancount/utils/encryption.py
def is_encrypted_file(filename: str | Path) -> bool:
    """Return true if the given filename contains an encrypted file.

    Args:
      filename: A path string.
    Returns:
      A boolean, true if the file contains an encrypted file.
    """
    _, ext = path.splitext(filename)
    if ext == ".gpg":
        return True
    if ext == ".asc":
        # python will still raise UnicodeDecodeError if file content is not in ascii encoding
        with contextlib.suppress(UnicodeDecodeError):
            with open(filename, encoding="ascii") as encfile:
                head = encfile.read(1024)
                if re.search("--BEGIN PGP MESSAGE--", head):
                    return True
    return False

beancount.utils.encryption.is_gpg_installed()

Return true if GPG 1.4.x or 2.x are installed, which is what we use and support.

Source code in beancount/utils/encryption.py
def is_gpg_installed() -> bool:
    """Return true if GPG 1.4.x or 2.x are installed, which is what we use and support.

    Runs 'gpg --version' and inspects the first line of its output.

    Returns:
      A boolean, true if the command ran successfully and reported a supported
      version; false otherwise, including when the program cannot be launched.
    """
    try:
        process = subprocess.Popen(
            ["gpg", "--version"],
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout_bytes, _ = process.communicate()
        version_text = stdout_bytes.decode("utf8")
        if process.returncode != 0:
            return False
        return re.match(r"gpg \(GnuPG\) (1\.4|2)\.", version_text) is not None
    except OSError:
        # Raised when the executable cannot be started.
        return False

beancount.utils.encryption.read_encrypted_file(filename)

Decrypt and read an encrypted file without temporary storage.

Parameters:
  • filename (str | Path) – A string, the path to the encrypted file.

Returns:
  • str – A string, the contents of the file.

Exceptions:
  • OSError – If we could not properly decrypt the file.

Source code in beancount/utils/encryption.py
def read_encrypted_file(filename: str | Path) -> str:
    """Decrypt and read an encrypted file without temporary storage.

    Args:
      filename: A string, the path to the encrypted file.
    Returns:
      A string, the contents of the file.
    Raises:
      OSError: If we could not properly decrypt the file.
    """
    command = ["gpg", "--batch", "--decrypt", path.realpath(filename)]
    pipe = subprocess.Popen(
        command, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    contents, errors = pipe.communicate()
    if pipe.returncode != 0:
        raise OSError(
            "Could not decrypt file ({}): {}".format(pipe.returncode, errors.decode("utf8"))
        )
    return contents.decode("utf-8")

beancount.utils.file_utils

File utilities.

beancount.utils.file_utils.find_files(fords, ignore_dirs=('.hg', '.svn', '.git'), ignore_files=('.DS_Store',))

Enumerate the files under the given directories, stably.

Invalid file or directory names will be logged to the error log.

Parameters:
  • fords – A list of strings, file or directory names.

  • ignore_dirs – A list of strings, filenames or directories to be ignored.

  • ignore_files – a sequence of strings, filenames to be ignored

Yields: Strings, full filenames from the given roots.

Source code in beancount/utils/file_utils.py
def find_files(fords, ignore_dirs=(".hg", ".svn", ".git"), ignore_files=(".DS_Store",)):
    """Enumerate the files under the given files or directories, in a stable order.

    Directories are walked recursively with their sub-directories and files
    visited in sorted order. Names that do not exist are reported through
    logging.error() and otherwise skipped.

    Args:
      fords: A list of strings, file or directory names. A single string is
        also accepted and treated as a one-element list.
      ignore_dirs: A list of strings, directory names not to descend into.
      ignore_files: a sequence of strings, filenames to be ignored
    Yields:
      Strings, full filenames from the given roots.
    """
    roots = [fords] if isinstance(fords, str) else fords
    assert isinstance(roots, (list, tuple))
    for entry in roots:
        if path.isdir(entry):
            for dirpath, subdirs, names in os.walk(entry):
                # Assign in place so that os.walk() both skips the ignored
                # directories and descends in sorted order.
                subdirs[:] = sorted(name for name in subdirs if name not in ignore_dirs)
                yield from (
                    path.join(dirpath, name)
                    for name in sorted(names)
                    if name not in ignore_files
                )
        elif path.isfile(entry) or path.islink(entry):
            yield entry
        elif not path.exists(entry):
            logging.error("File or directory '{}' does not exist.".format(entry))

beancount.utils.file_utils.guess_file_format(filename, default=None)

Guess the file format from the filename.

Parameters:
  • filename – A string, the name of the file. This can be None.

  • default – default value if no known file extensions match.

Returns:
  • A string, the extension of the format, without a leading period.

Source code in beancount/utils/file_utils.py
def guess_file_format(filename, default=None):
    """Infer a format name from the extension of a filename.

    Args:
      filename: A string, the name of the file. This can be None.
      default: default value if no known file extensions match.
    Returns:
      One of "text", "csv" or "html" when the extension is recognized;
      otherwise 'default' (also when 'filename' is empty or None).
    """
    if not filename:
        return default
    if filename.endswith((".txt", ".text")):
        return "text"
    if filename.endswith(".csv"):
        return "csv"
    if filename.endswith((".html", ".xhtml")):
        return "html"
    return default

beancount.utils.file_utils.path_greedy_split(filename)

Split a path, returning the longest possible extension.

Parameters:
  • filename – A string, the filename to split.

Returns:
  • A pair of basename, extension (which includes the leading period).

Source code in beancount/utils/file_utils.py
def path_greedy_split(filename):
    """Split a path at the first period of its final component.

    Everything from the first period of the basename onwards is treated as the
    extension, so 'archive.tar.gz' yields the extension '.tar.gz'.

    Args:
      filename: A string, the filename to split.
    Returns:
      A pair of (path without the extension, extension). The extension includes
      the leading period, and is None if the basename contains no period.
    """
    directory, name = path.split(filename)
    stem, dot, remainder = name.partition(".")
    extension = dot + remainder if dot else None
    return (path.join(directory, stem), extension)

beancount.utils.file_utils.touch_file(filename, *otherfiles)

Touch a file and wait until its timestamp has been changed.

Parameters:
  • filename – A string path, the name of the file to touch.

  • otherfiles – A list of other files to ensure the timestamp is beyond of.

Source code in beancount/utils/file_utils.py
def touch_file(filename, *otherfiles):
    """Touch a file repeatedly until its modification time has moved forward.

    The file is opened for appending and its times are updated, then the
    function sleeps briefly and checks the result. This repeats until the
    file's mtime is strictly greater than the newest mtime initially observed
    among 'filename' and 'otherfiles'.

    Args:
      filename: A string path, the name of the file to touch.
      otherfiles: A list of other files to ensure the timestamp is beyond of.
    """
    # Filesystems may have coarse timestamp resolution, so a single touch is not
    # guaranteed to produce a newer mtime; hence the polling loop.
    reference_files = (filename,) + otherfiles
    latest_mtime_ns = max(os.stat(name).st_mtime_ns for name in reference_files)
    poll_interval_secs = 0.05
    advanced = False
    while not advanced:
        with open(filename, "a", encoding="utf-8"):
            os.utime(filename)
        time.sleep(poll_interval_secs)
        advanced = os.stat(filename).st_mtime_ns > latest_mtime_ns

beancount.utils.import_utils

Utilities for importing symbols programmatically.

beancount.utils.import_utils.import_symbol(dotted_name)

Import a symbol in an arbitrary module.

Parameters:
  • dotted_name – A dotted path to a symbol.

Returns:
  • The object referenced by the given name.

Exceptions:
  • ImportError – If the module could not be imported.

  • AttributeError – If the symbol could not be found in the module.

Source code in beancount/utils/import_utils.py
def import_symbol(dotted_name):
    """Import a module and fetch one attribute from it, given a dotted path.

    The text after the last period is the symbol name; everything before it is
    the module to import.

    Args:
      dotted_name: A dotted path to a symbol.
    Returns:
      The object referenced by the given name.
    Raises:
      ImportError: If the module could not be imported.
      AttributeError: If the symbol could not be found in the module.
    """
    module_name, _, symbol_name = dotted_name.rpartition(".")
    return getattr(importlib.import_module(module_name), symbol_name)

beancount.utils.invariants

Functions to register auxiliary functions on a class' methods to check for invariants.

This is intended to be used in a test, whereby your test will setup a class to automatically run invariant verification functions before and after each function call, to ensure some extra sanity checks that wouldn't be used in non-tests.

Example: Instrument the Inventory class with the check_inventory_invariants() function.

def setUp(module): instrument_invariants(Inventory, check_inventory_invariants, check_inventory_invariants)

def tearDown(module): uninstrument_invariants(Inventory)

beancount.utils.invariants.instrument_invariants(klass, prefun, postfun)

Instrument the class 'klass' with pre/post invariant checker functions.

Parameters:
  • klass – A class object, whose methods to be instrumented.

  • prefun – A function that checks invariants pre-call.

  • postfun – A function that checks invariants post-call.

Source code in beancount/utils/invariants.py
def instrument_invariants(klass, prefun, postfun):
    """Wrap the public methods of 'klass' with pre/post invariant checkers.

    Only plain functions found directly in the class' own namespace, and whose
    name does not start with an underscore, are wrapped. The originals are
    saved on the class under the '__instrumented' attribute.

    Args:
      klass: A class object, whose methods to be instrumented.
      prefun: A function that checks invariants pre-call.
      postfun: A function that checks invariants post-call.
    """
    # Collect the candidates first, then replace them on the class.
    instrumented = {
        attrname: object_
        for attrname, object_ in klass.__dict__.items()
        if not attrname.startswith("_") and isinstance(object_, types.FunctionType)
    }
    for attrname, original in instrumented.items():
        setattr(klass, attrname, invariant_check(original, prefun, postfun))
    klass.__instrumented = instrumented

beancount.utils.invariants.invariant_check(method, prefun, postfun)

Decorate a method with the pre/post invariant checkers.

Parameters:
  • method – An unbound method to instrument.

  • prefun – A function that checks invariants pre-call.

  • postfun – A function that checks invariants post-call.

Returns:
  • An unbound method, decorated.

Source code in beancount/utils/invariants.py
def invariant_check(method, prefun, postfun):
    """Decorate a method with the pre/post invariant checkers.

    The checkers only run around the outermost call of the decorated method.
    Nested (re-entrant) calls of the same decorated method run unchecked.

    Args:
      method: An unbound method to instrument.
      prefun: A function that checks invariants pre-call. It receives the
        instance as its only argument.
      postfun: A function that checks invariants post-call. It receives the
        instance as its only argument, and is not called if the method raises.
    Returns:
      An unbound method, decorated.
    """
    # One entry per in-flight call of the decorated method; its length is the
    # current nesting depth. It is shared by all instances.
    reentrant = []

    def new_method(self, *args, **kw):
        reentrant.append(None)
        try:
            outermost = len(reentrant) == 1
            if outermost:
                prefun(self)
            result = method(self, *args, **kw)
            if outermost:
                postfun(self)
            return result
        finally:
            # Always unwind the depth marker. Otherwise a single exception
            # would leave it in place and disable the checks on all later calls.
            reentrant.pop()

    return new_method

beancount.utils.invariants.uninstrument_invariants(klass)

Undo the instrumentation for invariants.

Parameters:
  • klass – A class object, whose methods to be uninstrumented.

Source code in beancount/utils/invariants.py
def uninstrument_invariants(klass):
    """Undo the instrumentation for invariants.

    Restores the original methods saved on the class under the
    '__instrumented' attribute and removes that attribute. Calling this on a
    class that is not currently instrumented does nothing.

    Args:
      klass: A class object, whose methods to be uninstrumented.
    """
    # Look only at the class' own namespace, so that instrumentation recorded
    # on a base class is neither applied to nor deleted from this class.
    instrumented = vars(klass).get("__instrumented")
    if instrumented is None:
        return
    for attrname, object_ in instrumented.items():
        setattr(klass, attrname, object_)
    delattr(klass, "__instrumented")

beancount.utils.memo

Memoization utilities.

beancount.utils.memo.memoize_recent_fileobj(function, cache_filename, expiration=None)

Memoize recent calls to the given function which returns a file object.

The results of the cache expire after some time.

Parameters:
  • function – A callable object.

  • cache_filename – A string, the path to the database file to cache to.

  • expiration – The time during which the results will be kept valid. Use 'None' to never expire the cache (this is the default).

Returns:
  • A memoized version of the function.

Source code in beancount/utils/memo.py
def memoize_recent_fileobj(function, cache_filename, expiration=None):
    """Memoize recent calls to a function which returns a file object.

    Results are stored in a 'shelve' database, keyed on a hash of the call
    arguments, along with the time they were obtained. A stored result older
    than 'expiration' is refreshed by calling the function again.

    Args:
      function: A callable object. It must return either a false value or an
        object with a read() method.
      cache_filename: A string, the path to the database file to cache to.
      expiration: The time during which the results will be kept valid. Use
        'None' to never expire the cache (this is the default).
    Returns:
      A memoized version of the function. It returns an io.BytesIO wrapping
      the contents, or None if there are no contents.
    """
    urlcache = shelve.open(cache_filename, "c")
    urlcache.lock = threading.Lock()  # Note: 'shelve' is not thread-safe.

    @functools.wraps(function)
    def memoized(*args, **kw):
        # The cache key is derived from the string form of the positional
        # arguments and of the sorted keyword arguments.
        digest = hashlib.md5()
        for part in (str(args), str(sorted(kw.items()))):
            digest.update(part.encode("utf-8"))
        cache_key = digest.hexdigest()

        time_now = now()
        try:
            with urlcache.lock:
                time_orig, contents = urlcache[cache_key]
        except KeyError:
            stale = True
        else:
            stale = expiration is not None and (time_now - time_orig) > expiration

        if stale:
            fileobj = function(*args, **kw)
            if fileobj:
                contents = fileobj.read()
                with urlcache.lock:
                    urlcache[cache_key] = (time_now, contents)
            else:
                # Nothing was produced; do not cache the miss.
                contents = None

        return io.BytesIO(contents) if contents else None

    return memoized

beancount.utils.memo.now()

Indirection on datetime.datetime.now() for testing.

Source code in beancount/utils/memo.py
def now():
    """Return the current local date and time.

    This is an indirection on datetime.datetime.now(), for testing.

    Returns:
      A datetime.datetime instance.
    """
    current = datetime.datetime.now()
    return current

beancount.utils.misc_utils

Generic utility packages and functions.

beancount.utils.misc_utils.LineFileProxy

A file object that will delegate writing full lines to another logging function. This may be used for writing data to a logging level without having to worry about lines.

beancount.utils.misc_utils.LineFileProxy.__init__(self, line_writer, prefix=None, write_newlines=False) special

Construct a new line delegator file object proxy.

Parameters:
  • line_writer – A callable function, used to write to the delegated output.

  • prefix – An optional string, the prefix to insert before every line.

  • write_newlines – A boolean, true if we should output the newline characters.

Source code in beancount/utils/misc_utils.py
def __init__(self, line_writer, prefix=None, write_newlines=False):
    """Set up a file-like proxy that forwards complete lines to a callable.

    Args:
      line_writer: A callable function, used to write to the delegated output.
      prefix: An optional string, the prefix to insert before every line.
      write_newlines: A boolean, true if we should output the newline characters.
    """
    # Fragments written so far that have not yet been forwarded.
    self.data = []
    self.write_newlines = write_newlines
    self.prefix = prefix
    self.line_writer = line_writer

beancount.utils.misc_utils.LineFileProxy.close(self)

Close the line delegator.

Source code in beancount/utils/misc_utils.py
def close(self):
    """Close the line delegator.

    This only calls flush(); no other state is changed.
    """
    self.flush()
    return None

beancount.utils.misc_utils.LineFileProxy.flush(self)

Flush the data to the line writer.

Source code in beancount/utils/misc_utils.py
def flush(self):
    """Forward the complete buffered lines to the line writer.

    A trailing fragment that is not terminated by a newline is kept in the
    buffer, not written.
    """
    buffered = "".join(self.data)
    if not buffered:
        return
    complete_lines = buffered.splitlines()
    if buffered[-1] != "\n":
        # The last piece is an unfinished line: hold on to it.
        self.data = [complete_lines.pop(-1)]
    else:
        self.data = []
    for text in complete_lines:
        if self.prefix:
            text = self.prefix + text
        if self.write_newlines:
            text += "\n"
        self.line_writer(text)

beancount.utils.misc_utils.LineFileProxy.write(self, data)

Write some string data to the output.

Parameters:
  • data – A string, with or without newlines.

Source code in beancount/utils/misc_utils.py
def write(self, data):
    """Buffer some string data, flushing as soon as it contains a newline.

    Args:
      data: A string, with or without newlines.
    """
    # Test before buffering, so that nothing is stored if the test itself fails.
    has_newline = "\n" in data
    self.data.append(data)
    if has_newline:
        self.flush()

beancount.utils.misc_utils.escape_string(string)

Escape quotes and backslashes in payee and narration.

Parameters:
  • string – Any string.

Returns: The input string, with offending characters replaced.

Source code in beancount/utils/misc_utils.py
def escape_string(string):
    """Backslash-escape the backslashes and double quotes of a string.

    This is intended for payee and narration strings.

    Args:
      string: Any string.
    Returns:
      The input string, with each backslash doubled and each double-quote
      preceded by a backslash.
    """
    # A single translation pass gives the same result as replacing backslashes
    # first and quotes second.
    escapes = str.maketrans({"\\": r"\\", '"': r"\""})
    return string.translate(escapes)

beancount.utils.misc_utils.filter_type(elist, types)

Filter the given list to yield only instances of the given types.

Parameters:
  • elist – A sequence of elements.

  • types – A sequence of types to include in the output list.

Yields: Each element, if it is an instance of 'types'.

Source code in beancount/utils/misc_utils.py
def filter_type(elist, types):
    """Lazily select the elements that are instances of the given types.

    Args:
      elist: A sequence of elements.
      types: A type, or a tuple of types, as accepted by isinstance().
    Yields:
      Each element, if it is an instance of 'types'.
    """
    yield from (element for element in elist if isinstance(element, types))

beancount.utils.misc_utils.get_screen_height()

Return the height of the terminal that runs this program.

Returns:
  • An integer, the number of characters the screen is high. Return 0 if the terminal cannot be initialized.

Source code in beancount/utils/misc_utils.py
def get_screen_height():
    """Return the height of the terminal that runs this program.

    The lookup is delegated to _get_screen_value() for the "lines" attribute,
    with 0 as its second argument (presumably the fallback value).

    Returns:
      An integer, the number of characters the screen is high.
      Return 0 if the terminal cannot be initialized.
    """
    height = _get_screen_value("lines", 0)
    return height

beancount.utils.misc_utils.get_screen_width()

Return the width of the terminal that runs this program.

Returns:
  • An integer, the number of characters the screen is wide. Return 0 if the terminal cannot be initialized.

Source code in beancount/utils/misc_utils.py
def get_screen_width():
    """Return the width of the terminal that runs this program.

    The lookup is delegated to _get_screen_value() for the "cols" attribute,
    with 0 as its second argument (presumably the fallback value).

    Returns:
      An integer, the number of characters the screen is wide.
      Return 0 if the terminal cannot be initialized.
    """
    width = _get_screen_value("cols", 0)
    return width

beancount.utils.misc_utils.groupby(keyfun, elements)

Group the elements as a dict of lists, where the key is computed using the function 'keyfun'.

Parameters:
  • keyfun – A callable, used to obtain the group key from each element.

  • elements – An iterable of the elements to group.

Returns:
  • A dict of key to list of sequences.

Source code in beancount/utils/misc_utils.py
def groupby(keyfun, elements):
    """Bucket elements into lists according to a computed key.

    Unlike itertools.groupby(), the input does not need to be sorted: every
    element with the same key ends up in the same list, in input order.

    Args:
      keyfun: A callable, used to obtain the group key from each element.
      elements: An iterable of the elements to group.
    Returns:
      A defaultdict(list) mapping each key to the list of its elements.
    """
    # Note: a custom aggregation could be supported, but reducing the list
    # values afterwards with a dict comprehension achieves the same.
    buckets = defaultdict(list)
    for item in elements:
        group_key = keyfun(item)
        buckets[group_key].append(item)
    return buckets

beancount.utils.misc_utils.import_curses()

Try to import the 'curses' module. (This is used here in order to override for tests.)

Returns:
  • The curses module, if it was possible to import it.

Exceptions:
  • ImportError – If the module could not be imported.

Source code in beancount/utils/misc_utils.py
def import_curses():
    """Import and return the 'curses' module.

    The import lives in its own function so that tests can override it.

    Returns:
      The curses module, if it was possible to import it.
    Raises:
      ImportError: If the module could not be imported.
    """
    # Note: the terminal size can be obtained without curses on Windows, see
    # https://stackoverflow.com/questions/263890/how-do-i-find-the-width-height-of-a-terminal-window
    # A library such as 'blessings' would also provide this across platforms.

    import curses as curses_module

    return curses_module

beancount.utils.misc_utils.is_sorted(iterable, key=<function <lambda> at 0x7adffd1f3100>, cmp=<function <lambda> at 0x7adffd1f31a0>)

Return true if the sequence is sorted.

Parameters:
  • iterable – An iterable sequence.

  • key – A function to extract the quantity by which to sort.

  • cmp – A function that compares two elements of a sequence.

Returns:
  • A boolean, true if the sequence is sorted.

Source code in beancount/utils/misc_utils.py
def is_sorted(iterable, key=lambda x: x, cmp=lambda x, y: x <= y):
    """Return true if the sequence is sorted.

    Each pair of consecutive keys is checked with 'cmp'. An empty iterable and
    a single-element iterable are both considered sorted.

    Args:
      iterable: An iterable sequence.
      key: A function to extract the quantity by which to sort.
      cmp: A function that compares two consecutive keys and returns a true
        value if they are in an acceptable order.
    Returns:
      A boolean, true if the sequence is sorted.
    """
    iterator = map(key, iterable)
    try:
        prev = next(iterator)
    except StopIteration:
        # Nothing to compare. Without this guard, the StopIteration would
        # escape to the caller.
        return True
    for element in iterator:
        if not cmp(prev, element):
            return False
        prev = element
    return True

beancount.utils.misc_utils.log_time(operation_name, log_timings, indent=0)

A context manager that times the block and logs it to info level.

Parameters:
  • operation_name – A string, a label for the name of the operation.

  • log_timings – A function to write log messages to. If left to None, no timings are written (this becomes a no-op).

  • indent – An integer, the indentation level for the format of the timing line. This is useful if you're logging timing to a hierarchy of operations.

Yields: The start time of the operation.

Source code in beancount/utils/misc_utils.py
@contextlib.contextmanager
def log_time(operation_name, log_timings, indent=0):
    """A context manager that times its block and reports the elapsed time.

    The message is only produced if the block completes without raising.

    Args:
      operation_name: A string, a label for the name of the operation.
      log_timings: A function to write log messages to. If left to None,
        no timings are written (this becomes a no-op).
      indent: An integer, the indentation level for the format of the timing
        line. This is useful if you're logging timing to a hierarchy of
        operations.
    Yields:
      The start time of the operation.
    """
    started = time()
    yield started
    finished = time()
    if not log_timings:
        return
    quoted_name = "'{}'".format(operation_name)
    padding = "      " * indent
    elapsed_ms = (finished - started) * 1000
    log_timings(
        "Operation: {:48} Time: {}{:6.0f} ms".format(quoted_name, padding, elapsed_ms)
    )

beancount.utils.misc_utils.skipiter(iterable, num_skip)

Skip some elements from an iterator.

Parameters:
  • iterable – An iterator.

  • num_skip – The number of elements in the period.

Yields: Elements from the iterable, with num_skip elements skipped. For example, skipiter(range(10), 3) yields [0, 3, 6, 9].

Source code in beancount/utils/misc_utils.py
def skipiter(iterable, num_skip):
    """Yield one element out of every 'num_skip' from an iterator.

    The first element is always produced, then the next num_skip - 1 elements
    are dropped, and so on.

    Args:
      iterable: An iterator.
      num_skip: The number of elements in the period.
    Yields:
      Elements from the iterable, with num_skip elements skipped.
      For example, skipiter(range(10), 3) yields [0, 3, 6, 9].
    """
    assert num_skip > 0
    source = iter(iterable)
    exhausted = object()
    for value in source:
        yield value
        # Drop the rest of the period; stop if the input runs out meanwhile.
        for _ in range(num_skip - 1):
            if next(source, exhausted) is exhausted:
                return

beancount.utils.misc_utils.sorted_uniquify(iterable, keyfunc=None, last=False)

Given a sequence of elements, sort and remove duplicates of the given key. Keep either the first or the last (by key) element of a sequence of key-identical elements. This does not maintain the ordering of the original elements, they are returned sorted (by key) instead.

Parameters:
  • iterable – An iterable sequence.

  • keyfunc – A function that extracts from the elements the sort key to use and uniquify on. If left unspecified, the identity function is used and the uniquification occurs on the elements themselves.

  • last – A boolean, True if we should keep the last item of the same keys. Otherwise keep the first.

Yields: Elements from the iterable.

Source code in beancount/utils/misc_utils.py
def sorted_uniquify(iterable, keyfunc=None, last=False):
    """Sort a sequence of elements by key and drop those with a duplicate key.

    Among a run of elements that share a key, either the first or the last one
    (in sorted order) is kept. This does _not_ maintain the ordering of the
    original elements, they are returned sorted (by key) instead.

    Args:
      iterable: An iterable sequence.
      keyfunc: A function that extracts from the elements the sort key
        to use and uniquify on. If left unspecified, the identity function
        is used and the uniquification occurs on the elements themselves.
      last: A boolean, True if we should keep the last item of the same keys.
        Otherwise keep the first.
    Yields:
      Elements from the iterable.
    """
    if keyfunc is None:

        def keyfunc(element):
            return element

    ordered = sorted(iterable, key=keyfunc)
    if not last:
        # Emit an element whenever its key differs from the previous one.
        previous_key = UNSET
        for candidate in ordered:
            candidate_key = keyfunc(candidate)
            if candidate_key != previous_key:
                yield candidate
                previous_key = candidate_key
        return

    # Keeping the last of each run: hold every element back until the key
    # changes (or the input ends), then emit the one being held.
    held = UNSET
    held_key = UNSET
    for candidate in ordered:
        candidate_key = keyfunc(candidate)
        if candidate_key != held_key and held is not UNSET:
            yield held
        held = candidate
        held_key = candidate_key
    if held is not UNSET:
        yield held

beancount.utils.misc_utils.uniquify(iterable, keyfunc=None, last=False)

Given a sequence of elements, remove duplicates of the given key. Keep either the first or the last element of a sequence of key-identical elements. Order is maintained as much as possible. This does maintain the ordering of the original elements, they are returned in the same order as the original elements.

Parameters:
  • iterable – An iterable sequence.

  • keyfunc – A function that extracts from the elements the sort key to use and uniquify on. If left unspecified, the identity function is used and the uniquification occurs on the elements themselves.

  • last – A boolean, True if we should keep the last item of the same keys. Otherwise keep the first.

Yields: Elements from the iterable.

Source code in beancount/utils/misc_utils.py
def uniquify(iterable, keyfunc=None, last=False):
    """Remove duplicates of the given key from a sequence of elements.

    Two elements are duplicates when 'keyfunc' maps them to equal keys. For
    each group of key-identical elements either the first or the last one is
    kept, and the surviving elements are yielded in the order in which they
    occur in the input (no sorting takes place).

    Args:
      iterable: An iterable of elements. With last=True, an input that does
        not support reversed() (e.g. a generator) is copied to a list first.
      keyfunc: A function that extracts from the elements the key to uniquify
        on. Keys are stored in a set, so they must be hashable. If left
        unspecified, the identity function is used and the uniquification
        occurs on the elements themselves.
      last: A boolean, True if we should keep the last item of the same keys.
        Otherwise keep the first.
    Yields:
      Elements from the iterable.
    """
    if keyfunc is None:
        keyfunc = lambda x: x
    seen = set()
    if last:
        # Walk the input backwards so that the first occurrence we meet is the
        # last one in input order. reversed() only works on sequences; fall
        # back to a list copy for plain iterators instead of failing.
        try:
            backwards = reversed(iterable)
        except TypeError:
            backwards = reversed(list(iterable))

        kept_reversed = []
        for obj in backwards:
            key = keyfunc(obj)
            if key not in seen:
                seen.add(key)
                kept_reversed.append(obj)
        # Restore the original relative ordering of the survivors.
        yield from reversed(kept_reversed)
    else:
        for obj in iterable:
            key = keyfunc(obj)
            if key not in seen:
                seen.add(key)
                yield obj

beancount.utils.pager

Code to write output to a pager.

This module contains an object that accumulates lines up to a minimum and then decides whether to flush them to the original output directly if under the threshold (no pager) or creates a pager and flushes the lines to it if above the threshold and then forwards all future lines to it. The purpose of this object is to pipe output to a pager only if the number of lines to be printed exceeds a minimum number of lines.

The contextmanager is intended to be used to pipe output to a pager and wait on the pager to complete before continuing. Simply write to the file object and upon exit we close the file object. This also silences broken pipe errors triggered by the user exiting the sub-process, and recovers from a failing pager command by just using stdout.

beancount.utils.pager.ConditionalPager

A proxy file for a pager that only creates a pager after a minimum number of lines has been printed to it.

beancount.utils.pager.ConditionalPager.__enter__(self) special

Initialize the context manager and return this instance as it.

Source code in beancount/utils/pager.py
def __enter__(self):
    """Enter the context: reset the pager state and return this object.

    When a 'minlines' threshold is set, no pager is created yet; output is
    buffered until write() sees the threshold exceeded. Without a threshold
    the pager is created right away.

    Returns:
      This instance, to be used as the file-like target of the 'with' block.
    """
    if not self.minlines:
        # No threshold: start the pager immediately.
        self.file, self.pipe = create_pager(self.command, self.default_file)
    else:
        # Deferred: both get set once enough lines have been accumulated.
        self.file = self.pipe = None

    # Buffer of data written before the threshold, and its newline count.
    self.accumulated_data, self.accumulated_lines = [], 0

    return self

beancount.utils.pager.ConditionalPager.__exit__(self, type, value, unused_traceback) special

Context manager exit. This flushes the output to our output file.

Parameters:
  • type – Optional exception type, as per context managers.

  • value – Optional exception value, as per context managers.

  • unused_traceback – Optional trace.

Source code in beancount/utils/pager.py
def __exit__(self, type, value, unused_traceback):
    """Context manager exit. This flushes the output to our output file.

    If a pager was created, its input is flushed and closed and we wait for
    the subprocess to finish. If the threshold was never reached, the
    accumulated data is written to the default file instead.

    Args:
      type: Optional exception type, as per context managers.
      value: Optional exception value, as per context managers.
      unused_traceback: Optional trace.
    Returns:
      True if a BrokenPipeError was absorbed (either the one raised in the
      'with' block, or one raised while flushing/closing here). False
      otherwise, which lets any other in-flight exception propagate.
    """
    try:
        if self.file:
            # Flush the output file.
            self.file.flush()
        else:
            # Oops... we never reached the threshold. Flush the accumulated
            # output to the file.
            self.flush_accumulated(self.default_file)

        # Wait for the subprocess (if we have one).
        if self.pipe:
            self.file.close()
            self.pipe.wait()

    # Absorb broken pipes that may occur on flush or close above.
    except BrokenPipeError:
        return True

    # Absorb a broken pipe raised by the block itself; for anything else
    # return False so the 'with' machinery re-raises the original exception.
    # (A bare 'raise' here is not inside an 'except' clause and fails with
    # RuntimeError when no exception is currently being handled.)
    return isinstance(value, BrokenPipeError)

beancount.utils.pager.ConditionalPager.__init__(self, command, minlines=None) special

Create a conditional pager.

Parameters:
  • command – A string, the shell command to run as a pager.

  • minlines – If set, the number of lines under which you should not bother starting a pager. This avoids kicking off a pager if the screen is high enough to render the contents. If the value is unset, always starts a pager (which is fine behavior too).

Source code in beancount/utils/pager.py
def __init__(self, command, minlines=None):
    """Create a conditional pager.

    Args:
      command: A string, the shell command to run as a pager.
      minlines: If set, the number of lines under which you should not bother
        starting a pager. This avoids kicking off a pager if the screen is
        high enough to render the contents. If the value is unset, always
        starts a pager (which is fine behavior too).
    """
    self.command = command
    self.minlines = minlines

    # Fallback output: a UTF-8 writer over stdout's binary buffer when stdout
    # exposes one, otherwise stdout itself.
    if hasattr(sys.stdout, "buffer"):
        self.default_file = codecs.getwriter("utf-8")(sys.stdout.buffer)
    else:
        self.default_file = sys.stdout

beancount.utils.pager.ConditionalPager.flush_accumulated(self, file)

Flush the existing lines to the newly created pager. This also disables the accumulator.

Parameters:
  • file – A file object to flush the accumulated data to.

Source code in beancount/utils/pager.py
def flush_accumulated(self, file):
    """Write out everything buffered so far and disable the accumulator.

    Args:
      file: A file object to flush the accumulated data to.
    """
    # Nothing is written (and 'file' is not touched) when the buffer is
    # empty or has already been disabled.
    for chunk in self.accumulated_data or ():
        file.write(chunk)

    # Setting both to None marks the accumulator as disabled.
    self.accumulated_data = self.accumulated_lines = None

beancount.utils.pager.ConditionalPager.write(self, data)

Write the data out. Overridden from the file object interface.

Parameters:
  • data – A string, data to write to the output.

Source code in beancount/utils/pager.py
def write(self, data):
    """Write the data out. Overridden from the file object interface.

    Until a pager exists the data is buffered and its newlines counted; once
    the count exceeds 'minlines' a pager is created and the buffer is
    flushed to it. After that, data is forwarded directly.

    Args:
      data: A string, data to write to the output.
    """
    if self.file is not None:
        # A pager (or fallback file) already exists: pass the data through.
        self.file.write(data)
        return

    # Still below the threshold: accumulate the new lines.
    self.accumulated_lines += data.count("\n")
    self.accumulated_data.append(data)

    # If we've exceeded the threshold, create the pager and hand it the
    # buffered output.
    if self.accumulated_lines > self.minlines:
        self.file, self.pipe = create_pager(self.command, self.default_file)
        self.flush_accumulated(self.file)

beancount.utils.pager.create_pager(command, file)

Try to create and return a pager subprocess.

Parameters:
  • command – A string, the shell command to run as a pager.

  • file – The file object for the pager to write to. This is also used as a default if we failed to create the pager subprocess.

Returns:
  • A pair of (file, pipe), a file object and an optional subprocess.Popen instance to wait on. The pipe instance may be set to None if we failed to create a subprocess.

Source code in beancount/utils/pager.py
def create_pager(command, file):
    """Try to create and return a pager subprocess.

    Args:
      command: A string, the shell command to run as a pager. If None, the
        PAGER environment variable is consulted; an empty command falls back
        to DEFAULT_PAGER.
      file: The file object for the pager to write to. This is also used as a
        default if we failed to create the pager subprocess.
    Returns:
      A pair of (file, pipe), a file object and an optional subprocess.Popen
      instance to wait on. On success the file is a UTF-8 text wrapper around
      the subprocess' stdin. If the subprocess could not be created, the
      given file is returned unchanged and the pipe is None.
    """
    if command is None:
        command = os.environ.get("PAGER", DEFAULT_PAGER)
    command = command or DEFAULT_PAGER

    # In case of using 'less', make sure the charset is set properly. In
    # theory you could override this by setting PAGER to
    # "LESSCHARSET=utf-8 less" but this shouldn't affect other programs and is
    # unlikely to cause problems, so we set it here to make default behavior
    # work for most people (we always write UTF-8).
    pager_env = dict(os.environ, LESSCHARSET="utf-8")

    try:
        pipe = subprocess.Popen(
            command, shell=True, stdin=subprocess.PIPE, stdout=file, env=pager_env
        )
    except OSError as exc:
        # Could not start the pager: report it and keep writing to 'file'.
        logging.error("Invalid pager: {}".format(exc))
        return file, None

    return io.TextIOWrapper(pipe.stdin, "utf-8"), pipe

beancount.utils.pager.flush_only(fileobj)

A contextmanager around a file object that does not close the file.

This is used to return a context manager on a file object but not close it. We flush it instead. This is useful in order to provide an alternative to a pager class as above.

Parameters:
  • fileobj – A file object, to remain open after running the context manager.

Yields: A context manager that yields this object.

Source code in beancount/utils/pager.py
@contextlib.contextmanager
def flush_only(fileobj):
    """Context manager that flushes, but never closes, the given file object.

    This gives a plain file object the same 'with' interface as the pager
    class above, so it can be used as a drop-in alternative to it.

    Args:
      fileobj: A file object, to remain open after running the context manager.
    Yields:
      The file object itself.
    """
    try:
        yield fileobj
    finally:
        # Runs whether or not the block raised.
        fileobj.flush()

beancount.utils.table

Table rendering.

beancount.utils.table.Table (tuple)

Table(columns, header, body)

beancount.utils.table.Table.__getnewargs__(self) special

Return self as a plain tuple. Used by copy and pickle.

Source code in beancount/utils/table.py
def __getnewargs__(self):
    'Return self as a plain tuple.  Used by copy and pickle.'
    # Convert to a plain tuple so the arguments can be passed back to
    # __new__ when the instance is reconstructed.
    plain = _tuple(self)
    return plain

beancount.utils.table.Table.__new__(_cls, columns, header, body) special staticmethod

Create new instance of Table(columns, header, body)

beancount.utils.table.Table.__repr__(self) special

Return a nicely formatted representation string

Source code in beancount/utils/table.py
def __repr__(self):
    'Return a nicely formatted representation string'
    # The class name is looked up dynamically so subclasses show their own
    # name, followed by the fields interpolated into repr_fmt.
    class_name = self.__class__.__name__
    return class_name + repr_fmt % self

beancount.utils.table.attribute_to_title(fieldname)

Convert programming id into readable field name.

Parameters:
  • fieldname – A string, a programming ids, such as 'book_value'.

Returns:
  • A readable string, such as 'Book Value.'

Source code in beancount/utils/table.py
def attribute_to_title(fieldname):
    """Turn a programming identifier into a human-readable title.

    Args:
      fieldname: A string, a programming id, such as 'book_value'.
    Returns:
      A readable string, such as 'Book Value'.
    """
    # Underscores become spaces, then each word is title-cased.
    words = fieldname.split("_")
    return " ".join(words).title()

beancount.utils.table.compute_table_widths(rows)

Compute the max character widths of a list of rows.

Parameters:
  • rows – A list of rows, which are sequences of strings.

Returns:
  • A list of integers, the maximum widths required to render the columns of this table.

Exceptions:
  • IndexError – If the rows are of different lengths.

Source code in beancount/utils/table.py
def compute_table_widths(rows):
    """Compute the max character widths of a list of rows.

    Cells that are not strings are measured by the length of their str()
    rendering; this applies to every row, including the first.

    Args:
      rows: An iterable of rows, which are sequences of cells (normally
        strings).
    Returns:
      A list of integers, the maximum widths required to render the columns of
      this table. An empty list if there are no rows.
    Raises:
      IndexError: If the rows are of different lengths.
    """

    def cell_width(cell):
        """Return the rendered width of a single cell."""
        return len(cell if isinstance(cell, str) else str(cell))

    row_iter = iter(rows)
    try:
        first_row = next(row_iter)
    except StopIteration:
        # No rows at all: there are no columns to measure.
        return []

    column_widths = [cell_width(cell) for cell in first_row]
    num_columns = len(column_widths)

    for row in row_iter:
        cells = list(row)
        # Check the length explicitly rather than relying on the last loop
        # index, which is stale (or undefined) for an empty row.
        if len(cells) != num_columns:
            raise IndexError("Invalid number of rows")
        column_widths = [
            max(width, cell_width(cell)) for width, cell in zip(column_widths, cells)
        ]
    return column_widths

beancount.utils.table.create_table(rows, field_spec=None)

Convert a list of tuples to a table report object.

Parameters:
  • rows – A list of tuples.

  • field_spec – A list of strings, or a list of (FIELDNAME-OR-INDEX, HEADER, FORMATTER-FUNCTION) triplets, that selects the subset of fields to be rendered as well as their ordering. If this is a dict, the values are functions to call on the fields to render them. If a function is set to None, we will just call str() on the field.

Returns:
  • A Table instance.

Source code in beancount/utils/table.py
def create_table(rows, field_spec=None):
    """Convert a list of tuples to a table report object.

    Args:
      rows: A list of tuples. When 'field_spec' is None the rows must be
        namedtuple instances (their '_fields' are used as columns) and the
        list must be non-empty, since the columns are taken from rows[0].
      field_spec: A list whose items are either a column selector (an
        attribute name string or an integer index), or a tuple of up to three
        elements (FIELDNAME-OR-INDEX, HEADER, FORMATTER-FUNCTION). It selects
        the subset of fields to be rendered as well as their ordering.
        A missing or None header is derived from the selector: attribute names
        are converted to titles and integer indexes become 'Field N'. If the
        formatter is None, we just call str() on the field.
    Returns:
      A Table instance. None values in the rows are rendered as empty strings.
    Raises:
      ValueError: If a column selector is neither a string nor an integer.
    """

    def default_header(name):
        """Derive a header for a column selected by attribute name or index."""
        if isinstance(name, str):
            return attribute_to_title(name)
        if isinstance(name, int):
            return "Field {}".format(name)
        raise ValueError("Invalid type for column name")

    # Normalize field_spec to a list of (name, header, formatter) triplets.
    if field_spec is None:
        namedtuple_class = type(rows[0])
        field_spec = [(field, None, None) for field in namedtuple_class._fields]

    elif isinstance(field_spec, (list, tuple)):
        new_field_spec = []
        for field in field_spec:
            if isinstance(field, tuple):
                assert len(field) <= 3, field
                if len(field) == 1:
                    new_field_spec.append((field[0], None, None))
                elif len(field) == 2:
                    name, header = field
                    new_field_spec.append((name, header, None))
                elif len(field) == 3:
                    new_field_spec.append(field)
            else:
                new_field_spec.append((field, default_header(field), None))

        field_spec = new_field_spec

    # Ensure a nicely formatted header. Integer selectors get a 'Field N'
    # header here as well; passing them to attribute_to_title() would fail
    # because it expects a string.
    field_spec = [
        (name, default_header(name) if header_ is None else header_, formatter)
        for (name, header_, formatter) in field_spec
    ]

    assert isinstance(field_spec, list), field_spec
    assert all(len(x) == 3 for x in field_spec), field_spec

    # Compute the column names.
    columns = [name for (name, _, __) in field_spec]

    # Compute the table header.
    header = [header_column for (_, header_column, __) in field_spec]

    # Compute the table body.
    body = []
    for row in rows:
        body_row = []
        for name, _, formatter in field_spec:
            # Strings select by attribute, integers by position.
            if isinstance(name, str):
                value = getattr(row, name)
            elif isinstance(name, int):
                value = row[name]
            else:
                raise ValueError("Invalid type for column name")
            if value is not None:
                if formatter is not None:
                    value = formatter(value)
                else:
                    value = str(value)
            else:
                value = ""
            body_row.append(value)
        body.append(body_row)

    return Table(columns, header, body)

beancount.utils.table.render_table(table_, output, output_format, css_id=None, css_class=None)

Render the given table to the output file object in the requested format.

The table gets written out to the 'output' file.

Parameters:
  • table_ – An instance of Table.

  • output – A file object you can write to.

  • output_format – A string, the format to write the table to, either 'csv', 'txt' or 'html'.

  • css_id – A string, an optional CSS id for the table object (only used for HTML).

  • css_class – A string, an optional CSS class for the table object (only used for HTML).

Source code in beancount/utils/table.py
def render_table(table_, output, output_format, css_id=None, css_class=None):
    """Write the given table to 'output' in the requested format.

    Args:
      table_: An instance of Table.
      output: A file object you can write to.
      output_format: A string, the format to write the table to: 'txt' (or
        'text'), 'csv', 'html' for a full page, or 'htmldiv' for just the
        enclosing div without the html/body wrapper.
      css_id: A string, an optional CSS id for the table object (only used for HTML).
      css_class: A string, an optional CSS class for the table object (only used for HTML).
    Raises:
      NotImplementedError: If the output format is not one of the above.
    """
    if output_format in ("txt", "text"):
        output.write(table_to_text(table_, "  ", formats={"*": ">", "account": "<"}))
        return

    if output_format == "csv":
        table_to_csv(table_, file=output)
        return

    if output_format not in ("htmldiv", "html"):
        raise NotImplementedError("Unsupported format: {}".format(output_format))

    # Only the 'html' flavour gets the surrounding page elements.
    full_page = output_format == "html"
    if full_page:
        output.write("<html>\n")
        output.write("<body>\n")

    output.write('<div id="{}">\n'.format(css_id) if css_id else "<div>\n")
    table_to_html(table_, file=output, classes=[css_class] if css_class else None)
    output.write("</div>\n")

    if full_page:
        output.write("</body>\n")
        output.write("</html>\n")

beancount.utils.table.table_to_csv(table, file=None, **kwargs)

Render a Table to a CSV file.

Parameters:
  • table – An instance of a Table.

  • file – A file object to write to. If no object is provided, this function returns a string.

  • **kwargs – Optional arguments forwarded to csv.writer().

Returns:
  • A string, the rendered table, or None, if a file object is provided to write to.

Source code in beancount/utils/table.py
def table_to_csv(table, file=None, **kwargs):
    """Render a Table to a CSV file.

    Args:
      table: An instance of a Table.
      file: A file object to write to. If no object is provided (None), this
        function returns a string.
      **kwargs: Optional arguments forwarded to csv.writer().
    Returns:
      A string, the rendered table, or None, if a file object is provided
      to write to.
    """
    # Compare against None rather than testing truthiness: a file-like object
    # may legitimately evaluate as false (e.g. an empty container with a
    # write() method) and must still be written to.
    return_text = file is None
    output_file = io.StringIO() if return_text else file

    writer = csv.writer(output_file, **kwargs)
    # The header row is optional.
    if table.header:
        writer.writerow(table.header)
    writer.writerows(table.body)

    if return_text:
        return output_file.getvalue()
    return None

beancount.utils.table.table_to_html(table, classes=None, file=None)

Render a Table to HTML.

Parameters:
  • table – An instance of a Table.

  • classes – A list of string, CSS classes to set on the table.

  • file – A file object to write to. If no object is provided, this function returns a string.

Returns:
  • A string, the rendered table, or None, if a file object is provided to write to.

Source code in beancount/utils/table.py
def table_to_html(table, classes=None, file=None):
    """Render a Table as an HTML <table> element.

    Header and cell values are inserted as-is, without HTML escaping.

    Args:
      table: An instance of a Table.
      classes: A list of string, CSS classes to set on the table.
      file: A file object to write to. If no object is provided, this
        function returns a string.
    Returns:
      A string, the rendered table, or None, if a file object is provided
      to write to.
    """
    # Write to an in-memory buffer when no file was given.
    return_text = file is None
    sink = io.StringIO() if return_text else file
    emit = sink.write

    emit('<table class="{}">\n'.format(" ".join(classes or [])))

    # Header section (optional).
    if table.header:
        emit("  <thead>\n")
        emit("    <tr>\n")
        for title in table.header:
            emit("      <th>{}</th>\n".format(title))
        emit("    </tr>\n")
        emit("  </thead>\n")

    # Body section: one <tr> per row, one <td> per cell.
    emit("  <tbody>\n")
    for cells in table.body:
        emit("    <tr>\n")
        for value in cells:
            emit("      <td>{}</td>\n".format(value))
        emit("    </tr>\n")
    emit("  </tbody>\n")

    emit("</table>\n")

    if return_text:
        return sink.getvalue()

beancount.utils.table.table_to_text(table, column_interspace=' ', formats=None)

Render a Table to ASCII text.

Parameters:
  • table – An instance of a Table.

  • column_interspace – A string to render between the columns as spacer.

  • formats – An optional dict of column name to a format character that gets inserted in a format string specified, like this (where '<char>' is): {:<char><width>}. A key of '*' will provide a default value, like this, for example: (... formats={'*': '>'}).

Returns:
  • A string, the rendered text table.

Source code in beancount/utils/table.py
def table_to_text(table, column_interspace=" ", formats=None):
    """Render a Table as fixed-width text.

    The output is the header line (if any), a dashed separator, the body
    rows, and a closing dashed separator.

    Args:
      table: An instance of a Table.
      column_interspace: A string to render between the columns as spacer.
      formats: An optional dict of column name to a format character that gets
        inserted in a format string specified, like this (where '<char>' is):
        {:<char><width>}. A key of '*' will provide a default value, like
        this, for example: (... formats={'*': '>'}).
    Returns:
      A string, the rendered text table.
    """
    column_widths = compute_table_widths(itertools.chain([table.header], table.body))

    # Build one '{:<char><width>}' placeholder per column. Columns with a
    # false-y name, or without any matching format, just get '{:<width>}'.
    default_format = formats.get("*", None) if formats else None
    column_formats = []
    for column, width in zip(table.columns, column_widths):
        format_ = formats.get(column, default_format) if (column and formats) else None
        if format_:
            column_formats.append("{{:{}{:d}}}".format(format_, width))
        else:
            column_formats.append("{{:{:d}}}".format(width))

    line_format = column_interspace.join(column_formats) + "\n"
    separator = line_format.format(*[("-" * width) for width in column_widths])

    # Assemble the lines: optional header, then the body between separators.
    pieces = []
    if table.header:
        pieces.append(line_format.format(*table.header))
    pieces.append(separator)
    for row in table.body:
        pieces.append(line_format.format(*row))
    pieces.append(separator)

    return "".join(pieces)

beancount.utils.test_utils

Support utilities for testing scripts.

beancount.utils.test_utils.ClickTestCase (TestCase)

Base class for command-line program test cases.

beancount.utils.test_utils.RCall (tuple)

RCall(args, kwargs, return_value)

beancount.utils.test_utils.RCall.__getnewargs__(self) special

Return self as a plain tuple. Used by copy and pickle.

Source code in beancount/utils/test_utils.py
def __getnewargs__(self):
    'Return self as a plain tuple.  Used by copy and pickle.'
    # Convert to a plain tuple so the arguments can be passed back to
    # __new__ when the instance is reconstructed.
    plain = _tuple(self)
    return plain

beancount.utils.test_utils.RCall.__new__(_cls, args, kwargs, return_value) special staticmethod

Create new instance of RCall(args, kwargs, return_value)

beancount.utils.test_utils.RCall.__repr__(self) special

Return a nicely formatted representation string

Source code in beancount/utils/test_utils.py
def __repr__(self):
    'Return a nicely formatted representation string'
    # The class name is looked up dynamically so subclasses show their own
    # name, followed by the fields interpolated into repr_fmt.
    class_name = self.__class__.__name__
    return class_name + repr_fmt % self

beancount.utils.test_utils.TestCase (TestCase)

beancount.utils.test_utils.TestCase.assertLines(self, text1, text2, message=None)

Compare the lines of text1 and text2, ignoring whitespace.

Parameters:
  • text1 – A string, the expected text.

  • text2 – A string, the actual text.

  • message – An optional string message in case the assertion fails.

Exceptions:
  • AssertionError – If the assertion fails.

Source code in beancount/utils/test_utils.py
def assertLines(self, text1, text2, message=None):
    """Assert that two texts have the same lines, modulo whitespace.

    Each text is stripped and dedented, every line is stripped, and any run
    of four or more blanks is collapsed to exactly four spaces before the
    two lists of lines are compared.

    Args:
      text1: A string, the expected text.
      text2: A string, the actual text.
      message: An optional string message in case the assertion fails.
    Raises:
      AssertionError: If the assertion fails.
    """

    def normalized_lines(text):
        """Return the whitespace-normalized lines of 'text'."""
        dedented = textwrap.dedent(text.strip())
        # Compress all space longer than 4 spaces to exactly 4.
        # This affords us to be even looser.
        return [
            re.sub("    [ \t]*", "    ", line.strip())
            for line in dedented.splitlines()
        ]

    expected_lines = normalized_lines(text1)
    actual_lines = normalized_lines(text2)
    self.assertEqual(expected_lines, actual_lines, message)

beancount.utils.test_utils.TestCase.assertOutput(self, expected_text)

Expect text printed to stdout.

Parameters:
  • expected_text – A string, the text that should have been printed to stdout.

Exceptions:
  • AssertionError – If the text differs.

Source code in beancount/utils/test_utils.py
@contextlib.contextmanager
def assertOutput(self, expected_text):
    """Context manager asserting on the text printed to stdout in its block.

    The comparison is made with assertLines(), after dedenting the expected
    text, and only if the block completes without raising.

    Args:
      expected_text: A string, the text that should have been printed to stdout.
    Yields:
      The object that accumulates the captured output.
    Raises:
      AssertionError: If the text differs.
    """
    with capture() as captured:
        yield captured

    expected = textwrap.dedent(expected_text)
    actual = captured.getvalue()
    self.assertLines(expected, actual)

beancount.utils.test_utils.TmpFilesTestBase (TestCase)

A test utility base class that creates and cleans up a directory hierarchy. This convenience is useful for testing functions that work on files, such as the documents tests, or the accounts walk.

beancount.utils.test_utils.TmpFilesTestBase.create_file_hierarchy(test_files, subdir='root') staticmethod

A test utility that creates a hierarchy of files.

Parameters:
  • test_files – A list of strings, relative filenames to a temporary root directory. If the filename ends with a '/', we create a directory; otherwise, we create a regular file.

  • subdir – A string, the subdirectory name under the temporary directory location, to create the hierarchy under.

Returns:
  • A pair of strings, the temporary directory, and the subdirectory under that which hosts the root of the tree.

Source code in beancount/utils/test_utils.py
@staticmethod
def create_file_hierarchy(test_files, subdir="root"):
    """A test utility that creates a hierarchy of files.

    Args:
      test_files: A list of strings, relative filenames to a temporary root
        directory. If the filename ends with a '/', we create a directory;
        otherwise, we create a regular file.
      subdir: A string, the subdirectory name under the temporary directory
        location, to create the hierarchy under.
    Returns:
      A pair of strings, the temporary directory, and the subdirectory under
        that which hosts the root of the tree.
    """
    tempdir = tempfile.mkdtemp(prefix="beancount-test-tmpdir.")
    root = path.join(tempdir, subdir)
    for filename in test_files:
        abs_filename = path.join(tempdir, filename)
        if filename.endswith("/"):
            os.makedirs(abs_filename)
        else:
            parent_dir = path.dirname(abs_filename)
            if not path.exists(parent_dir):
                os.makedirs(parent_dir)
            with open(abs_filename, "w", encoding="utf-8"):
                pass
    return tempdir, root

beancount.utils.test_utils.TmpFilesTestBase.setUp(self)

Hook method for setting up the test fixture before exercising it.

Source code in beancount/utils/test_utils.py
def setUp(self):
    """Create the file hierarchy listed in TEST_DOCUMENTS for this test."""
    hierarchy = self.create_file_hierarchy(self.TEST_DOCUMENTS)
    # Remember both locations so that tearDown() can remove the tree.
    self.tempdir, self.root = hierarchy

beancount.utils.test_utils.TmpFilesTestBase.tearDown(self)

Hook method for deconstructing the test fixture after testing it.

Source code in beancount/utils/test_utils.py
def tearDown(self):
    """Remove the temporary directory tree recorded in 'self.tempdir'."""
    scratch_dir = self.tempdir
    # Errors during removal are ignored so cleanup never fails a test.
    shutil.rmtree(scratch_dir, ignore_errors=True)

beancount.utils.test_utils.capture(*attributes)

A context manager that captures what's printed to stdout.

Parameters:
  • *attributes – A tuple of strings, the name of the sys attributes to override with StringIO instances.

Yields: A StringIO string accumulator.

Source code in beancount/utils/test_utils.py
def capture(*attributes):
    """A context manager that captures what's printed to stdout.

    The named 'sys' attributes are handed to patch() together with
    io.StringIO as the replacement factory. Without arguments, 'stdout' is
    patched.

    Args:
      *attributes: A tuple of strings, the name of the sys attributes to override
        with StringIO instances.
    Yields:
      A StringIO string accumulator.
    """
    count = len(attributes)
    if count == 0:
        target = "stdout"
    elif count == 1:
        # A single name is passed on as a plain string, not a 1-tuple.
        target = attributes[0]
    else:
        target = attributes
    return patch(sys, target, io.StringIO)

beancount.utils.test_utils.create_temporary_files(root, contents_map)

Create a number of temporary files under 'root'.

This routine is used to initialize the contents of multiple files under a temporary directory.

Parameters:
  • root – A string, the name of the directory under which to create the files.

  • contents_map – A dict of relative filenames to their contents. The content strings will be automatically dedented for convenience. In addition, the string '{root}' in the contents will be automatically replaced by the root directory name.

Source code in beancount/utils/test_utils.py
def create_temporary_files(root, contents_map):
    """Create a number of files with given contents under 'root'.

    This routine is used to initialize the contents of multiple files under a
    temporary directory. The root and any intermediate directories are
    created as needed.

    Args:
      root: A string, the name of the directory under which to create the files.
      contents_map: A dict of relative filenames to their contents. The content
        strings will be automatically dedented for convenience. In addition,
        the placeholder '{root}' in the contents is replaced by the root
        directory name, with its backslashes doubled.
    """
    os.makedirs(root, exist_ok=True)
    for relative_filename, contents in contents_map.items():
        assert not path.isabs(relative_filename)
        filename = path.join(root, relative_filename)
        os.makedirs(path.dirname(filename), exist_ok=True)

        # Substitute the placeholder first, then strip common indentation.
        escaped_root = root.replace("\\", r"\\")
        expanded = contents.replace("{root}", escaped_root)
        clean_contents = textwrap.dedent(expanded)

        with open(filename, "w", encoding="utf-8") as outfile:
            outfile.write(clean_contents)

beancount.utils.test_utils.docfile(function, contents=None, prefix='', suffix='.beancount', encoding='utf-8')

A decorator that writes the function's docstring to a temporary file and calls the decorated function with the temporary filename. This is useful for writing tests.

Parameters:
  • function – A function to decorate.

  • contents (str | None) – file content, defaults to function.__doc__

  • prefix (str) – prefix of filename

  • suffix (str) – suffix of filename

  • encoding (str) – encoding of file content

Returns:
  • The decorated function.

Source code in beancount/utils/test_utils.py
def docfile(
    function,
    contents: str | None = None,
    prefix: str = "",
    suffix: str = ".beancount",
    encoding: str = "utf-8",
):
    """A decorator that writes the function's docstring to a temporary file
    and calls the decorated function with the temporary filename.  This is
    useful for writing tests.

    Args:
      function: A function to decorate. It is called as
        function(self, filename), where filename is a string.
      contents: File content. When None (the default), function.__doc__ is
        used. An explicitly empty string produces an empty file.
      prefix: Prefix of the temporary filename.
      suffix: Suffix of the temporary filename.
      encoding: Encoding used to write the file content.
    Returns:
      The decorated function, taking only 'self'. Its docstring is cleared.
    """

    @functools.wraps(function)
    def new_function(self):
        # Fall back to the docstring only when no contents were given at all.
        # A plain truthiness test would silently replace an intentionally
        # empty input with the docstring.
        text = function.__doc__ if contents is None else contents
        with temp_file(suffix=suffix, prefix=prefix) as file:
            file.write_text(textwrap.dedent(text), encoding=encoding)
            return function(self, str(file))

    # functools.wraps copied the docstring over; here it is file content
    # rather than documentation, so drop it from the wrapper.
    new_function.__doc__ = None
    return new_function

beancount.utils.test_utils.docfile_extra(**kwargs)

A decorator identical to @docfile, but with keyword arguments bound in advance. The keyword arguments are forwarded to docfile, so they must be ones it accepts: contents, prefix, suffix and encoding.

Returns:
  • docfile

Source code in beancount/utils/test_utils.py
def docfile_extra(**kwargs):
    """Build a variant of @docfile with some keyword arguments pre-bound.

    Args:
      **kwargs: Keyword arguments forwarded to docfile; they must be ones it
        accepts (contents, prefix, suffix, encoding).
    Returns:
      A functools.partial of docfile, usable as a decorator.
    """
    bound_docfile = functools.partial(docfile, **kwargs)
    return bound_docfile

beancount.utils.test_utils.environ(varname, newvalue)

A context manager which pushes varname's value and restores it later.

Parameters:
  • varname – A string, the environ variable name.

  • newvalue – A string, the desired value.

Source code in beancount/utils/test_utils.py
@contextlib.contextmanager
def environ(varname, newvalue):
    """A context manager which pushes varname's value and restores it later.

    The previous state is restored even if the body raises: the variable is
    set back to its old value, or removed if it was not set before.

    Args:
      varname: A string, the environ variable name.
      newvalue: A string, the desired value.
    Yields:
      Nothing.
    """
    oldvalue = os.environ.get(varname)
    os.environ[varname] = newvalue
    try:
        yield
    finally:
        if oldvalue is not None:
            os.environ[varname] = oldvalue
        else:
            # pop() rather than del: the body may already have removed it.
            os.environ.pop(varname, None)

beancount.utils.test_utils.find_python_lib()

Return the path to the root of the Python libraries.

Returns:
  • A string, the root directory.

Source code in beancount/utils/test_utils.py
def find_python_lib():
    """Return the path to the root of the Python libraries.

    The root is taken to be three directory levels above this source file.

    Returns:
      A string, the root directory.
    """
    location = __file__
    for _ in range(3):
        location = path.dirname(location)
    return location

beancount.utils.test_utils.find_repository_root(filename=None)

Return the path to the repository root.

Parameters:
  • filename – A string, the name of a file within the repository.

Returns:
  • A string, the root directory.

Source code in beancount/utils/test_utils.py
def find_repository_root(filename=None):
    """Return the path to the repository root.

    The root is the closest ancestor of 'filename' (or 'filename' itself)
    that contains a 'pyproject.toml' entry, except under Bazel, where the
    '.runfiles/beancount' prefix of the path is returned directly.

    Args:
      filename: A string, the name of a file within the repository. Defaults
        to this source file.
    Returns:
      A string, the root directory.
    Raises:
      ValueError: If no ancestor contains 'pyproject.toml'.
    """
    candidate = __file__ if filename is None else filename

    # Support root directory under Bazel.
    bazel_match = re.match(r"(.*\.runfiles/beancount)/", candidate)
    if bazel_match:
        return bazel_match.group(1)

    while True:
        if path.exists(path.join(candidate, "pyproject.toml")):
            return candidate
        parent = path.dirname(candidate)
        # dirname() is a fixed point at the top of the path: nothing left
        # to climb.
        if parent == candidate:
            raise ValueError("Failed to find the root directory.")
        candidate = parent

beancount.utils.test_utils.make_failing_importer(*removed_module_names)

Make an importer that raise an ImportError for some modules.

Use it like this:

@mock.patch('builtins.__import__', make_failing_importer('setuptools'))
def test_...

Parameters:
  • *removed_module_names – The names of the modules whose import should raise an exception.

Returns:
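  • A replacement for builtins.__import__ that raises ImportError for the given module names and delegates to the real import otherwise.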

Source code in beancount/utils/test_utils.py
def make_failing_importer(*removed_module_names):
    """Make an importer that raises an ImportError for some modules.

    Use it like this:

      @mock.patch('builtins.__import__', make_failing_importer('setuptools'))
      def test_...

    Args:
      *removed_module_names: The names of the modules whose import should
        raise an exception.
    Returns:
      A function with the calling convention of builtins.__import__. It
      raises ImportError for the listed names and delegates to
      builtins.__import__ for every other name.
    """

    def failing_import(name, *args, **kwargs):
        if name not in removed_module_names:
            return builtins.__import__(name, *args, **kwargs)
        raise ImportError("Could not import {}".format(name))

    return failing_import

beancount.utils.test_utils.nottest(func)

Make the given function not testable.

Source code in beancount/utils/test_utils.py
def nottest(func):
    """Make the given function not testable.

    Sets the '__test__' attribute of 'func' to False and returns the same
    object, so it can be used as a decorator.
    """
    setattr(func, "__test__", False)
    return func

beancount.utils.test_utils.patch(obj, attributes, replacement_type)

A context manager that temporarily patches an object's attributes.

All attributes in 'attributes' are saved and replaced by new instances of type 'replacement_type'.

Parameters:
  • obj – The object to patch up.

  • attributes – A string or a sequence of strings, the names of attributes to replace.

  • replacement_type – A callable to build replacement objects.

Yields: A single instance of 'replacement_type' if 'attributes' is a string, otherwise a list of instances, one per attribute.

Source code in beancount/utils/test_utils.py
@contextlib.contextmanager
def patch(obj, attributes, replacement_type):
    """A context manager that temporarily patches an object's attributes.

    All attributes in 'attributes' are saved and replaced by new instances
    of type 'replacement_type'. The saved values are put back on exit, even
    if the body raises or if patching fails partway through.

    Args:
      obj: The object to patch up.
      attributes: A string or a sequence of strings, the names of attributes to replace.
      replacement_type: A callable to build replacement objects.
    Yields:
      A single instance of 'replacement_type' if 'attributes' is a string,
      otherwise a list of instances, one per attribute, in the same order.
    """
    single = isinstance(attributes, str)
    if single:
        attributes = [attributes]

    saved = []
    replacements = []
    try:
        for attribute in attributes:
            replacement = replacement_type()
            original = getattr(obj, attribute)
            setattr(obj, attribute, replacement)
            # Record only what was actually replaced, so that a failure
            # partway through still restores the attributes patched so far.
            saved.append((attribute, original))
            replacements.append(replacement)

        yield replacements[0] if single else replacements
    finally:
        # Undo in reverse order, so that a name listed more than once ends
        # up with its very first (true original) value.
        for attribute, original in reversed(saved):
            setattr(obj, attribute, original)

beancount.utils.test_utils.record(fun)

Decorates the function to intercept and record all calls and return values.

Parameters:
  • fun – A callable to be decorated.

Returns:
  • A wrapper function with a .calls attribute, a list of RCall instances.

Source code in beancount/utils/test_utils.py
def record(fun):
    """Wrap a callable so that every call and its return value is recorded.

    Args:
      fun: A callable to be decorated.
    Returns:
      A wrapper function with a .calls attribute, a list of RCall instances
      built from (args, kw, return_value), one per completed call. A call
      that raises is not recorded.
    """

    def wrapped(*args, **kw):
        outcome = fun(*args, **kw)
        # Looked up on the wrapper at call time, so a caller may reset
        # recording by assigning a fresh list to .calls.
        wrapped.calls.append(RCall(args, kw, outcome))
        return outcome

    functools.update_wrapper(wrapped, fun)
    wrapped.calls = []
    return wrapped

beancount.utils.test_utils.search_words(words, line)

Search for a sequence of words in a line.

Parameters:
  • words – A list of strings, the words to look for, or a space-separated string.

  • line – A string, the line to search into.

Returns:
  • A MatchObject, or None.

Source code in beancount/utils/test_utils.py
def search_words(words, line):
    """Search a line for a sequence of whole words appearing in order.

    Args:
      words: A list of strings, the words to look for, or a space-separated string.
        The words are inserted into a regular expression as is (not escaped).
      line: A string, the line to search into.
    Returns:
      A MatchObject, or None.
    """
    tokens = words.split() if isinstance(words, str) else words
    # Each word must match at word boundaries; anything may sit in between.
    pattern = ".*".join(r"\b{}\b".format(token) for token in tokens)
    return re.search(pattern, line)

beancount.utils.test_utils.skipIfRaises(*exc_types)

A context manager (or decorator) that skips a test if an exception is raised.

Yields: Nothing, for you to execute the function code.

Exceptions:
  • SkipTest – if the test raised the expected exception.

Source code in beancount/utils/test_utils.py
@contextlib.contextmanager
def skipIfRaises(*exc_types):
    """A context manager (or decorator) that skips a test if an exception is raised.

    Args:
      *exc_types: The exception classes that should turn into a skipped test.
        Any other exception propagates unchanged.
    Yields:
      Nothing, for you to execute the function code.
    Raises:
      SkipTest: if the test raised the expected exception. The caught
        exception is passed as the SkipTest argument.
    """
    try:
        yield
    except exc_types as caught:
        raise unittest.SkipTest(caught)

beancount.utils.test_utils.subprocess_env()

Return a dict to use as environment for running subprocesses.

Returns:
  • A dict with the 'PATH' and 'PYTHONPATH' entries to use as the subprocess environment.

Source code in beancount/utils/test_utils.py
def subprocess_env():
    """Return a dict to use as environment for running subprocesses.

    Returns:
      A dict with exactly two entries: 'PATH', made of the directory of the
      running Python executable, the repository's 'bin' directory and the
      current PATH; and 'PYTHONPATH', the root of the Python libraries.
    """
    # Ensure we have locations to invoke our Python executable and our
    # runnable binaries in the test environment to run subprocesses.
    #
    # os.pathsep rather than a literal ":" keeps this correct on Windows,
    # where the separator is ";" and ":" appears inside drive-letter paths.
    binpath = os.pathsep.join(
        [
            path.dirname(sys.executable),
            path.join(find_repository_root(__file__), "bin"),
            os.environ.get("PATH", "").strip(os.pathsep),
        ]
    ).strip(os.pathsep)
    return {"PATH": binpath, "PYTHONPATH": find_python_lib()}

beancount.utils.test_utils.temp_file(prefix='', suffix='.txt')

A context manager that returns a file path inside a temporary directory and deletes this directory unconditionally once done.

This utils exists because NamedTemporaryFile can't be re-opened on win32.

Yields: A Path, the location of a file (not created) inside the temporary directory.

Source code in beancount/utils/test_utils.py
@contextlib.contextmanager
def temp_file(prefix: str = "", suffix: str = ".txt") -> Generator[Path, None, None]:
    """A context manager that returns a file path inside a temporary directory
    and deletes this directory unconditionally once done.

    This utility exists because `NamedTemporaryFile` can't be re-opened on win32.

    Args:
      prefix: A string placed at the start of the file name.
      suffix: A string placed at the end of the file name.
    Yields:
      A Path to a file named prefix + "-temp_file-" + suffix inside a freshly
      created temporary directory. The file itself is not created.
    """
    basename = prefix + "-temp_file-" + suffix
    with tempfile.TemporaryDirectory(prefix="beancount-test-tmpdir.") as directory:
        yield Path(directory) / basename

beancount.utils.test_utils.tempdir(delete=True, **kw)

A context manager that creates a temporary directory and deletes its contents unconditionally once done.

Parameters:
  • delete – A boolean, true if we want to delete the directory after running.

  • **kw – Keyword arguments for mkdtemp.

Yields: A string, the name of the temporary directory created.

Source code in beancount/utils/test_utils.py
@contextlib.contextmanager
def tempdir(delete=True, **kw):
    """A context manager that creates a temporary directory and, by default,
    removes it together with its contents once done.

    Args:
      delete: A boolean, true if we want to delete the directory after running.
        Errors during deletion are ignored.
      **kw: Keyword arguments for mkdtemp (the prefix is already set here).
    Yields:
      A string, the name of the temporary directory created.
    """
    directory = tempfile.mkdtemp(prefix="beancount-test-tmpdir.", **kw)
    try:
        yield directory
    finally:
        if delete:
            shutil.rmtree(directory, ignore_errors=True)