Fitting the Model

Use this page after you have prepared X and y for PanelMMM. For input requirements, see Data Preparation.

Basic workflow

fit() is the main entry point for posterior sampling.

import pandas as pd

from abacus.mmm import GeometricAdstock, LogisticSaturation
from abacus.mmm.panel import PanelMMM

dataset = pd.read_csv("data/demo/timeseries/dataset.csv")
dataset["date"] = pd.to_datetime(dataset["date"])

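# Features only: the target column must not stay inside X
X = dataset.drop(columns=["revenue"])
# Target as a named Series that shares X's index
y = dataset["revenue"].rename("revenue")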

mmm = PanelMMM(
    date_column="date",
    target_column="revenue",
    channel_columns=[
        "channel_1",
        "channel_2",
        "channel_3",
        "channel_4",
        "channel_5",
        "channel_6",
    ],
    yearly_seasonality=2,
    adstock=GeometricAdstock(l_max=4),
    saturation=LogisticSaturation(),
)

idata = mmm.fit(
    X,
    y,
    draws=500,
    tune=500,
    chains=2,
    cores=2,
    progressbar=False,
    random_seed=42,
)

fit() returns an arviz.InferenceData object and also stores it on mmm.idata.

What fit() does

When you call fit(X, y), Abacus:

  1. checks that X and y share the same index, when both are pandas objects
  2. builds the PyMC graph automatically if it has not been built already
  3. merges sampler settings from the model’s sampler_config and your call-time kwargs
  4. runs pymc.sample(...)
  5. computes deterministic variables and adds them to the posterior group
  6. stores the training data in the fit_data group of the InferenceData
  7. writes model metadata into idata.attrs

Because of step 5, fitted contribution variables such as channel_contribution, intercept_contribution, and yearly_seasonality_contribution are available in mmm.posterior after fitting, provided the configured model includes them.

Configure the sampler

You can configure PyMC sampling in two places:

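Where                            | Use it for                                                   | Precedence
-------------------------------- | ------------------------------------------------------------ | ----------
sampler_config= in PanelMMM(...) | Stable defaults you want to reuse across fits                | Lower
fit(..., **kwargs)               | Run-specific overrides such as draws, chains, or random_seed | Higher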

Abacus merges them so that explicit fit() kwargs win.

mmm = PanelMMM(
    date_column="date",
    target_column="revenue",
    channel_columns=["channel_1", "channel_2"],
    adstock=GeometricAdstock(l_max=4),
    saturation=LogisticSaturation(),
    sampler_config={
        "draws": 1000,
        "tune": 1000,
        "chains": 4,
        "target_accept": 0.9,
        "progressbar": False,
    },
)

# draws=500 overrides sampler_config; tune, chains, target_accept,
# and progressbar still come from sampler_config
idata = mmm.fit(X, y, draws=500, random_seed=42)
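The precedence rule behaves like a Python dict merge in which later keys win. The plain-Python sketch below illustrates the rule only; it is not Abacus's source code, and the two dicts simply mirror the values in the example above.

```python
# Illustrative sketch of the precedence rule, not Abacus's implementation:
# in {**a, **b}, keys from b replace keys from a.
sampler_config = {
    "draws": 1000,
    "tune": 1000,
    "chains": 4,
    "target_accept": 0.9,
    "progressbar": False,
}
fit_kwargs = {"draws": 500, "random_seed": 42}

effective = {**sampler_config, **fit_kwargs}
print(effective["draws"], effective["target_accept"])  # 500 0.9
```

Keys that appear only in sampler_config (tune, chains, target_accept, progressbar) pass through unchanged, and keys that appear only in fit() (random_seed) are added.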

Common sampler arguments

These are passed through to pymc.sample(...).

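Argument      | What it controls
------------- | ----------------------------------------------------------------------------
draws         | Number of posterior draws kept per chain after tuning
tune          | Number of tuning (warm-up) iterations per chain, used for adaptation and discarded by default
chains        | Number of MCMC chains
cores         | Number of chains PyMC runs in parallel (worker processes)
target_accept | Target acceptance probability for NUTS step-size adaptation; higher values give smaller steps
progressbar   | Whether PyMC shows a progress bar
random_seed   | Seed for reproducible sampling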

If you do not specify progressbar, Abacus defaults it to True unless your sampler_config already sets it.
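The progressbar default can be read as one more, lowest-precedence layer under sampler_config and fit() kwargs. The helper below, resolve_sampler_kwargs, is hypothetical and only sketches that ordering in plain Python; it is not part of Abacus.

```python
def resolve_sampler_kwargs(sampler_config, **fit_kwargs):
    """Hypothetical helper, not an Abacus API.

    Lowest to highest precedence: the built-in progressbar default,
    then sampler_config, then call-time fit() kwargs.
    """
    return {"progressbar": True, **sampler_config, **fit_kwargs}


print(resolve_sampler_kwargs({})["progressbar"])  # True (default)
print(resolve_sampler_kwargs({"progressbar": False})["progressbar"])  # False (sampler_config)
print(resolve_sampler_kwargs({"progressbar": False}, progressbar=True)["progressbar"])  # True (fit kwarg)
```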

When to build first

For a standard workflow, call fit() directly.

Call build_model(X, y) first only when you need to inspect or modify the graph before sampling. For example:

mmm.build_model(X, y)
mmm.add_original_scale_contribution_variable(
    var=["channel_contribution", "y"]
)

idata = mmm.fit(
    X,
    y,
    draws=500,
    tune=500,
    chains=2,
    progressbar=False,
    random_seed=42,
)

Events follow a stricter ordering: call add_events(...) before build_model(...), or before fit(...) if you let fit() build the graph, so that the events are in place when the graph is built.
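Building first preserves your changes because of step 2 of fit(): the graph is built only when none exists yet. The ToyModel class below is a hypothetical stand-in, not PanelMMM; its methods (build_model, add_variable, fit) only echo the workflow, and "fitting" just reports the graph so the build-if-needed guard is visible in plain Python.

```python
class ToyModel:
    """Hypothetical stand-in showing a build-if-needed guard."""

    def __init__(self):
        self.graph = None  # no graph until something builds it

    def build_model(self):
        self.graph = ["y"]  # pretend graph: a list of variable names

    def add_variable(self, name):
        if self.graph is None:
            raise RuntimeError("build the model before adding variables")
        self.graph.append(name)

    def fit(self):
        if self.graph is None:  # build only when no graph exists yet
            self.build_model()
        return list(self.graph)  # stand-in for sampling the current graph


m = ToyModel()
m.build_model()
m.add_variable("extra_deterministic")
print(m.fit())  # ['y', 'extra_deterministic']: fit() reused the built graph

fresh = ToyModel()
print(fresh.fit())  # ['y']: fit() built the graph itself
```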

Inspect fitted results

After fitting, common entry points are:

  • mmm.idata
  • mmm.posterior
  • mmm.model
  • mmm.plot
  • mmm.summary
  • mmm.diagnostics

Example:

posterior = mmm.posterior
channel_mean = posterior["channel_contribution"].mean(dim=["chain", "draw"])
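Averaging over chain and draw collapses the sampling dimensions and leaves one posterior mean per remaining coordinate. The NumPy sketch below uses synthetic random draws in place of real samples and assumes channel_contribution has dims (chain, draw, date, channel); a panel model may carry additional dims.

```python
import numpy as np

# Synthetic stand-in for posterior draws shaped (chain, draw, date, channel).
rng = np.random.default_rng(42)
samples = rng.normal(size=(2, 500, 10, 6))

# Averaging over axes 0 and 1 is the array analogue of
# .mean(dim=["chain", "draw"]): one value per (date, channel) remains.
channel_mean = samples.mean(axis=(0, 1))
print(channel_mean.shape)  # (10, 6)
```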

Common pitfalls

  • Leaving the target column inside X
  • Passing pandas X and y with different indexes
  • Changing the model graph after fitting and expecting existing samples to stay valid
  • Assuming constructor sampler_config overrides explicit fit() kwargs; it does not
  • Adding events after the model has already been built
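The first two pitfalls can be caught before calling fit(). The function find_fit_pitfalls below is a hypothetical pre-fit check written with plain pandas, not part of Abacus, and the tiny dataset is made up for illustration.

```python
import pandas as pd


def find_fit_pitfalls(X, y, target_column):
    """Hypothetical pre-fit check (not an Abacus API) for two pitfalls:
    the target left inside X, and pandas X and y with different indexes."""
    problems = []
    if target_column in X.columns:
        problems.append(f"target column {target_column!r} is still in X")
    if isinstance(y, pd.Series) and not X.index.equals(y.index):
        problems.append("X and y have different indexes")
    return problems


dataset = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-07", periods=3, freq="W"),
        "channel_1": [1.0, 2.0, 3.0],
        "revenue": [10.0, 11.0, 12.0],
    }
)

# Mistake: X still contains revenue, and y carries a different index.
y_bad = pd.Series([10.0, 11.0, 12.0], index=[10, 11, 12], name="revenue")
print(find_fit_pitfalls(dataset, y_bad, "revenue"))

# Fixed: drop the target and keep the shared index.
X = dataset.drop(columns=["revenue"])
y = dataset["revenue"]
print(find_fit_pitfalls(X, y, "revenue"))  # []
```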

Next steps