
TIL: __post_init__ for typing Dataclasses


Snippets of learning.


TL;DR

Don’t use:

@dataclass
class Params:
    """Config object from ingested yaml."""
    file: str

Do use:

from dataclasses import dataclass, field, InitVar
from pathlib import Path

@dataclass
class Params:
    file_str: InitVar[str]
    file_path: Path = field(init=False)

    def __post_init__(self, file_str: str):
        self.file_path = Path(file_str)

Motivation

Setting up experiment configs Kedro-style means ingesting a bunch of yaml files, e.g. one that points to an input data file that needs to be loaded:

# example.yaml
inputs:
  file: path/to/mnist.gz

It’s pretty typical to load this yaml file as an untyped dictionary:

import yaml

with open("path/to/example.yaml", "r") as f:
    data = yaml.safe_load(f)

Passing data around the code is straightforward enough, e.g. load_file(data['inputs']['file']), but it gets messy quickly as the nesting deepens, and worse still once the code needs type hints.

Bring in a Dataclass

from dataclasses import dataclass

@dataclass
class Params:
    """Config object from ingested yaml."""
    file: str

And load everything up:

params = Params(data['inputs']['file'])
# ...and use it somewhere else
dataset = load_data(params.file)

This is unsatisfying, however, because of (1) the brittle dictionary access, (2) the verbose construction, and (3) the lack of control over types. Consider (3) most of all: what we really want is a Path to the input file:

from pathlib import Path

@dataclass
class Params:
    file: Path

params = Params(data['inputs']['file'])

>>> params
Params(file='path/to/mnist.gz')
>>> params.file
'path/to/mnist.gz'
>>> type(params.file)
<class 'str'>

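No bueno. Type annotations on a dataclass are just hints: @dataclass neither converts nor validates values at runtime, so the str passes straight through.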

Using __post_init__ on a Python Dataclass

from dataclasses import dataclass, field, InitVar
from pathlib import Path

@dataclass
class Params:
    file_str: InitVar[str]               # reason for InitVar below
    file_path: Path = field(init=False)  # reason for field(...) below

    def __post_init__(self, file_str: str):
        self.file_path = Path(file_str)

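Defining __post_init__ lets us take the yaml-derived string path/to/mnist.gz and convert it into a Path. We never call __post_init__ ourselves: the __init__ generated by @dataclass calls it automatically after assigning the regular fields, passing along any InitVar values.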

  • InitVar marks file_str as an init-only pseudo-field: the generated __init__ accepts it and forwards it to __post_init__, but it is never stored on the instance. It doesn't appear in fields() or the repr, and it's inaccessible once __post_init__ returns (unless you save it yourself).
  • field(init=False) declares file_path as a regular field annotated as Path, but leaves it out of the generated __init__ signature, so __post_init__ is responsible for setting it. Without init=False, file_path becomes a required constructor argument, and Params(data['inputs']['file']) fails with: TypeError: Params.__init__() missing 1 required positional argument: 'file_path'.
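To check both bullets concretely, here is a minimal sketch. It repeats the Params class from above and adds BrokenParams, a hypothetical variant that omits field(init=False). The exact wording of the TypeError differs slightly between Python versions, so the sketch only captures and prints it.

```python
from dataclasses import InitVar, dataclass, field, fields
from pathlib import Path


@dataclass
class Params:
    file_str: InitVar[str]               # init-only: passed to __init__, then __post_init__
    file_path: Path = field(init=False)  # stored field, but not an __init__ parameter

    def __post_init__(self, file_str: str) -> None:
        self.file_path = Path(file_str)


@dataclass
class BrokenParams:
    """Hypothetical variant without field(init=False)."""

    file_str: InitVar[str]
    file_path: Path  # now a required __init__ argument

    def __post_init__(self, file_str: str) -> None:
        self.file_path = Path(file_str)


params = Params("path/to/mnist.gz")
print([f.name for f in fields(params)])  # ['file_path']
print(hasattr(params, "file_str"))       # False

error_message = ""
try:
    BrokenParams("path/to/mnist.gz")  # file_path is missing from the call
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # mentions the missing 'file_path' argument
```

The InitVar shows up neither in fields() nor as an attribute, while the init=False field is a real, typed field that only __post_init__ fills in.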

Now you have a nicely typed field:

>>> @dataclass
... class Params:
...     file_str: InitVar[str]
...     file_path: Path = field(init=False)
...     def __post_init__(self, file_str: str):
...         self.file_path = Path(file_str)
...
>>> params = Params(data['inputs']['file'])
>>> params
Params(file_path=PosixPath('path/to/mnist.gz'))

Fin.


Extra: Writing @classmethod from_dict

That data['inputs']['file'] remains an eyesore. Class construction can be handled from an incoming dict in a much sleeker way:

@dataclass
class Params:
    file_str: InitVar[str]               # see the InitVar note above
    file_path: Path = field(init=False)  # see the field(init=False) note above

    def __post_init__(self, file_str: str):
        self.file_path = Path(file_str)

    @classmethod  # explained below
    def from_dict(cls, data: dict) -> "Params":  # quoted: Params isn't defined yet at this point
        return cls(
            file_str=data['inputs']['file']
        )

And calling it is much nicer:

params = Params.from_dict(data)

  • @classmethod makes from_dict callable on the class itself, as Params.from_dict(data), before any Params object exists. The method receives the class as cls, so cls(...) builds a Params (or whichever subclass it was called on).
  • There’s nothing special about the name from_dict; it’s just reasonably descriptive.
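A self-contained sketch of the from_dict version, assuming the yaml has already been loaded into a plain dict shaped like example.yaml (the literal below stands in for yaml.safe_load). DebugParams is a hypothetical subclass, included only to show that cls follows whichever class from_dict is called on.

```python
from dataclasses import InitVar, dataclass, field
from pathlib import Path


@dataclass
class Params:
    file_str: InitVar[str]
    file_path: Path = field(init=False)

    def __post_init__(self, file_str: str) -> None:
        # Runs on every construction, including via from_dict.
        self.file_path = Path(file_str)

    @classmethod
    def from_dict(cls, data: dict) -> "Params":
        # cls is whichever class from_dict was called on.
        return cls(file_str=data["inputs"]["file"])


class DebugParams(Params):
    """Hypothetical subclass; inherits __init__, __post_init__ and from_dict."""


# Stand-in for yaml.safe_load(...) applied to example.yaml.
data = {"inputs": {"file": "path/to/mnist.gz"}}

params = Params.from_dict(data)
print(params.file_path.name)                       # mnist.gz
print(type(DebugParams.from_dict(data)).__name__)  # DebugParams
```

Because from_dict builds the object through cls(...), the conversion in __post_init__ still runs; from_dict only has to know where the value lives in the dict.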

The Over-Explaining It Part

Since from_dict can already transform its input however it likes, __post_init__ might not seem necessary at all. All of this could just as well be written as:

@dataclass
class Params:
    file_path: Path

    @classmethod
    def from_dict(cls, data: dict) -> "Params":
        return cls(file_path=Path(data["inputs"]["file"]))

There are a few reasons to not do this:

  1. It doesn’t show off post-init in this toy problem
  2. It doesn’t keep the logic of transforming or validating fields inside the dataclass itself, independent of where the data comes from (consider constructing Params without using from_dict)
  3. It doesn’t scale as well to multiple derived fields or alternate constructors (e.g. parsing a list of paths, normalizing extensions, …)
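A rough sketch of reasons 2 and 3, assuming Python 3.9+ for the list[...] and set[...] annotations. AltParams is the from_dict-only formulation above under another name; MultiParams, its files key, and lower-casing suffixes as the "normalization" are hypothetical illustrations, not part of any real config.

```python
from dataclasses import InitVar, dataclass, field
from pathlib import Path


@dataclass
class AltParams:
    """The from_dict-only version: conversion lives in the constructor helper."""

    file_path: Path

    @classmethod
    def from_dict(cls, data: dict) -> "AltParams":
        return cls(file_path=Path(data["inputs"]["file"]))


@dataclass
class MultiParams:
    """Hypothetical config whose derived fields are all computed in __post_init__."""

    file_strs: InitVar[list[str]]
    file_paths: list[Path] = field(init=False)
    suffixes: set[str] = field(init=False)

    def __post_init__(self, file_strs: list[str]) -> None:
        # Runs for every constructor, so direct construction is converted too.
        self.file_paths = [Path(s) for s in file_strs]
        # A second field derived from the first: lower-cased extensions.
        self.suffixes = {p.suffix.lower() for p in self.file_paths}

    @classmethod
    def from_dict(cls, data: dict) -> "MultiParams":
        # Assumed (hypothetical) yaml layout: inputs: {files: [...]}
        return cls(file_strs=data["inputs"]["files"])


# Reason 2: bypassing from_dict silently keeps the raw str.
alt = AltParams(file_path="path/to/mnist.gz")
print(type(alt.file_path).__name__)  # str

# Reason 3: both construction routes yield the same typed, derived fields.
direct = MultiParams(["data/train.GZ", "data/test.gz"])
via_dict = MultiParams.from_dict({"inputs": {"files": ["data/train.GZ", "data/test.gz"]}})
print(direct == via_dict)  # True
print(direct.suffixes)     # {'.gz'}
```

Keeping the derivation in __post_init__ means every construction path, direct or via from_dict, ends up with the same typed fields.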

Closing

This presumes extensive use of @dataclass and a desire for typed configs. The logical next step, pydantic, is for another post.