Dhruv Sahi
03/26/2024, 11:43 AM

# main.py
import logging
from typing import List, Optional

import polars as pl
from hamilton import base, driver, log_setup

from config import PipelineConfig, get_config
from stages.intermediate import jurisdiction_clean, transaction_clean
from stages.raw.load_raw_data import RawDataStage
class Pipeline:
    """Wires the raw-data loading stage into a Hamilton driver and runs it.

    The DAG is built from the functions defined on ``RawDataStage``; the
    adapter collects results into a plain dict.
    """

    def __init__(self, _cfg: PipelineConfig) -> None:
        self._cfg = _cfg
        self.raw_data_stage = RawDataStage(_cfg)
        # NOTE(review): `base` must be imported from hamilton for the adapter
        # below — the original snippet referenced it without importing it.
        self.dr = driver.Driver(
            {},
            self.raw_data_stage,
            adapter=base.SimplePythonGraphAdapter(base.DictResult),
        )

    def run(self) -> None:
        """Execute the DAG for all raw tables.

        BUG fix: the original ``execute(final=vars=[...])`` was a syntax
        error, and it passed the stage *object* where Hamilton expects the
        names of the DAG nodes to compute.
        """
        result = self.dr.execute(final_vars=["table_1", "table_2"])
        logging.getLogger(__name__).info("pipeline produced: %s", sorted(result))
if __name__ == "__main__":
    # Entry point: build the pipeline from the loaded configuration and run it.
    config = get_config()
    Pipeline(_cfg=config).run()
My load_raw_data.py looks like this:
from typing import List
import polars as pl
from hamilton.function_modifiers import tag
import logging
from config import PipelineConfig
logger = logging.getLogger(__name__)
# Names of the raw input tables. NOTE: RawDataStage.run dispatches on these
# names, so each entry must correspond to a loader method on RawDataStage.
INPUT_TABLES = [
"table_1",
"table_2",
]
class RawDataStage:
    """Loads every raw input table as a polars ``LazyFrame``.

    One loader method exists per entry in ``INPUT_TABLES``; ``run`` dispatches
    to them by name and returns a mapping of table name -> LazyFrame.
    """

    def __init__(self, config: PipelineConfig) -> None:
        self._cfg = config

    @staticmethod
    def load_parquet(paths: List[str]) -> pl.LazyFrame:
        """Lazily scan the given parquet files into a single LazyFrame."""
        return pl.scan_parquet(paths)

    def run(self) -> dict:
        """Load all tables listed in ``INPUT_TABLES``.

        Returns:
            dict mapping each table name to its ``pl.LazyFrame``.
        """
        data = {}
        # Invariant hoisted out of the loop: the paths come from config once.
        file_paths = self._cfg.file_paths
        for table in INPUT_TABLES:
            # BUG fix: the original looked up f"_read_{table}" (e.g.
            # `_read_table_1`), which does not exist — the loader methods are
            # named after the table itself (`table_1`, `table_2`).
            data[table] = getattr(self, table)(paths=file_paths)
        return data

    # BUG fix (per the review below): `paths=List[str]` made the typing object
    # the *default value*; `paths: List[str]` annotates the parameter instead.
    @tag(stage="load", input_type="table_1")
    def table_1(self, paths: List[str]) -> pl.LazyFrame:
        return self.load_parquet(paths=paths)

    @tag(stage="load", input_type="table_2")
    def table_2(self, paths: List[str]) -> pl.LazyFrame:
        return self.load_parquet(paths=paths)
Elijah Ben Izzy
03/26/2024, 12:54 PM
Use `paths: List[str]` rather than `paths=List[str]` — the `=` makes the typing object the parameter's default value instead of annotating it.
As for the reasoning — classes have two purposes:
1. They hold state (e.g. `_cfg` above)
2. They group functions
Hamilton DAGs are stateless, meaning that everything in the function is a parameter-level input. This encourages functions to be clearer to read/track — we know everything it takes in and thus don’t have to worry about both the way it was instantiated and the way it was called.
I’m not sure what the `.run` step does (or where the dynamically referenced `_read_{table}` function is defined), but your pipeline should look something like this (not fully tested):
# my_module
def _load_parquet(paths: List[str]) -> pl.LazyFrame:
    """Private helper: lazily scan the given parquet files."""
    return pl.scan_parquet(paths)
@tag(stage="load", input_type="table_1")
def table_1(table_1_paths: List[str]) -> pl.LazyFrame:
    """Load table_1 from its parquet paths as a LazyFrame.

    BUG fixes: return annotation typo (`p.LazyFrame` -> `pl.LazyFrame`), and
    the body referenced an undefined name `paths` instead of the actual
    parameter `table_1_paths`.
    """
    return _load_parquet(table_1_paths)
@tag(stage="load", input_type="table_2")
def table_2(table_2_paths: List[str]) -> pl.LazyFrame:
    """Load table_2 from its parquet paths as a LazyFrame.

    BUG fixes: return annotation typo (`p.LazyFrame` -> `pl.LazyFrame`), and
    the body referenced an undefined name `paths` instead of the actual
    parameter `table_2_paths`.
    """
    return _load_parquet(table_2_paths)
...
dr = driver.Builder().with_modules(my_module).build()
# BUG fix: the `inputs` keys must match the function parameter names above
# (`table_1_paths`, not `paths_table_1`) — otherwise Hamilton reports
# missing inputs for the requested nodes.
results = dr.execute(
    ["table_1", "table_2"],
    inputs={"table_1_paths": ..., "table_2_paths": ...},
)
Dhruv Sahi
03/26/2024, 1:02 PM

Elijah Ben Izzy
03/26/2024, 1:05 PM

Stefan Krawczyk
03/26/2024, 4:23 PM

Dhruv Sahi
03/27/2024, 11:14 AM