Code for Replication in Python - csdid

Tables and figures generated from the ACA Medicaid expansion dataset. Author: Gabriel Saco

Introduction

Introduction

This appendix reproduces the tables and figures for the JEL-DiD replication using Python and the csdid package. Each section follows the same workflow as the R/Stata reference, with outputs rendered below.

Setup

Setup: Install Package and Helpers

Install the latest csdid package and define helper utilities used below.

# Install latest csdid from GitHub
# !pip install -U git+https://github.com/d2cml-ai/csdid

def weighted_mean(values: pd.Series, weights: pd.Series) -> float:
    """Return the weighted average of *values* using *weights*."""
    v = np.asarray(values, dtype=float)
    w = np.asarray(weights, dtype=float)
    return float(np.average(v, weights=w))


def weighted_var(values: pd.Series, weights: pd.Series) -> float:
    """Weighted variance of *values*, with a frequency-weight Bessel-style
    correction (sum_w / (sum_w - 1)) applied when the weights sum exceeds 1."""
    v = np.asarray(values, dtype=float)
    w = np.asarray(weights, dtype=float)
    mu = np.average(v, weights=w)
    variance = np.average((v - mu) ** 2, weights=w)
    total_w = w.sum()
    if total_w > 1:
        variance = variance * total_w / (total_w - 1)
    return float(variance)


def fit_ols(formula: str, data: pd.DataFrame, weights=None, cluster: str | None = None):
    """Fit OLS (or WLS when *weights* is given) from a patsy formula.

    When *cluster* names a column, the returned results carry cluster-robust
    standard errors grouped by that column.
    """
    if weights is None:
        model = smf.ols(formula, data=data)
    else:
        model = smf.wls(formula, data=data, weights=weights)
    result = model.fit()
    if cluster is None:
        return result
    # NOTE(review): assumes the formula drops no rows relative to `data`,
    # otherwise the groups vector length may mismatch — confirm upstream dropna.
    return result.get_robustcov_results(cov_type="cluster", groups=data[cluster])


def fit_logit(formula: str, data: pd.DataFrame, weights=None, cluster: str | None = None):
    """Fit a binomial GLM (logit link by default) from a patsy formula.

    *weights* are passed as frequency weights; when *cluster* names a column,
    the fit uses cluster-robust covariance grouped by that column.
    """
    model = smf.glm(formula, data=data, family=sm.families.Binomial(), freq_weights=weights)
    if cluster is None:
        return model.fit()
    return model.fit(cov_type="cluster", cov_kwds={"groups": data[cluster]})


def get_param(result, name: str):
    """Return (coefficient, std. error) for regressor *name*.

    Returns (None, None) for either element that is not present in the
    fitted model's exogenous names.
    """
    exog = list(result.model.exog_names)
    coef_by_name = dict(zip(exog, np.asarray(result.params).tolist()))
    se_by_name = dict(zip(exog, np.asarray(result.bse).tolist()))
    return coef_by_name.get(name), se_by_name.get(name)

  
Table 1

Table 1: Medicaid Expansion Adoption

Roll-out of Medicaid expansion under the ACA across states.

# Table 1: Medicaid expansion adoption by cohort — states, share of states,
# share of counties, and share of 2013 adult population.
# Relies on globals: df, TABLE_DIR.
table1_data = df.copy()
# DC is excluded so state shares are taken out of the 50 states.
table1_data = table1_data[table1_data["state"] != "DC"].copy()

# Cohort label: "Non-Expansion" when no expansion year, "Pre-2014" for the four
# early expanders, otherwise the expansion year rendered as a string.
table1_data["adopt"] = np.where(
    table1_data["yaca"].isna(),
    "Non-Expansion",
    np.where(
        table1_data["state"].isin(["DE", "MA", "NY", "VT"]),
        "Pre-2014",
        table1_data["yaca"].astype("Int64").astype(str),
    ),
)

# One row per state with its cohort label.
adopts = table1_data[["state", "adopt"]].drop_duplicates().sort_values("state")

# Comma-separated state list per cohort.
states = (
    adopts.groupby("adopt")["state"]
    .apply(lambda s: ", ".join(s))
    .reset_index(name="states")
)
# NOTE(review): positional assignment assumes both groupbys sort by "adopt"
# identically (pandas default sort=True), keeping rows aligned — confirm.
states["state_share"] = adopts.groupby("adopt")["state"].size().values / 50

# County and adult-population shares are measured in the 2013 cross-section.
base_2013 = table1_data[table1_data["year"] == 2013]
total_counties = len(base_2013)
total_pop = base_2013["population_20_64"].sum()

counties_pop = (
    base_2013.groupby("adopt")
    .agg(county_share=("county_code", "size"), pop_share=("population_20_64", "sum"))
    .reset_index()
)
counties_pop["county_share"] = counties_pop["county_share"] / total_counties
counties_pop["pop_share"] = counties_pop["pop_share"] / total_pop

table1 = states.merge(counties_pop, on="adopt", how="left")
# Presentation order: pre-2014 first, then adoption years, never-expanders last;
# any unexpected label sinks to the bottom.
order = ["Pre-2014", "2014", "2015", "2016", "2019", "2020", "2021", "2023", "Non-Expansion"]
table1["order"] = table1["adopt"].apply(lambda x: order.index(x) if x in order else len(order))
table1 = table1.sort_values("order").drop(columns=["order"]).reset_index(drop=True)

# Two-decimal display formatting for the share columns.
table1["state_share"] = table1["state_share"].map(lambda x: f"{x:.2f}")
table1["county_share"] = table1["county_share"].map(lambda x: f"{x:.2f}")
table1["pop_share"] = table1["pop_share"].map(lambda x: f"{x:.2f}")

table1 = table1.rename(
    columns={
        "adopt": "Expansion Year",
        "states": "States",
        "state_share": "Share of States",
        "county_share": "Share of Counties",
        "pop_share": "Share of Adults (2013)",
    }
)

table1.to_csv(TABLE_DIR / "table1_adoptions.csv", index=False)
Expansion Year States Share of States Share of Counties Share of Adults (2013)
Pre-2014 DE, MA, NY, VT 0.08 0.03 0.09
2014 AR, AZ, CA, CO, CT, HI, IA, IL, KY, MD, MI, MN, ND, NH, NJ, NM, NV, OH, OR, RI, WA, WV 0.44 0.36 0.45
2015 AK, IN, PA 0.06 0.06 0.06
2016 LA, MT 0.04 0.04 0.02
2019 ME, VA 0.04 0.05 0.03
2020 ID, NE, UT 0.06 0.04 0.02
2021 MO, OK 0.04 0.06 0.03
2023 NC, SD 0.04 0.05 0.03
Non-Expansion AL, FL, GA, KS, MS, SC, TN, TX, WI, WY 0.20 0.31 0.26
Table 2

Table 2: Simple 2x2 DiD

Simple 2x2 DiD using 2013 vs 2014, expansion vs non-expansion.

# Table 2: simple 2x2 DiD means (2013 vs 2014, expansion vs comparison),
# unweighted and population-weighted. Relies on globals: df, COVS, TABLE_DIR,
# weighted_mean.
mydata = df.copy()
# Drop DC and the four pre-2014 expanders.
mydata = mydata[~mydata["state"].isin(["DC", "DE", "MA", "NY", "VT"])]
# Keep the 2014 cohort plus the comparison pool: never-expanders and
# post-2019 adopters.
mydata = mydata[(mydata["yaca"] == 2014) | (mydata["yaca"].isna()) | (mydata["yaca"] > 2019)]

# Covariate construction: demographic shares in percent, unemployment in
# percent, median income in thousands of dollars.
mydata = mydata.assign(
    perc_white=mydata["population_20_64_white"] / mydata["population_20_64"] * 100,
    perc_hispanic=mydata["population_20_64_hispanic"] / mydata["population_20_64"] * 100,
    perc_female=mydata["population_20_64_female"] / mydata["population_20_64"] * 100,
    unemp_rate=mydata["unemp_rate"] * 100,
    median_income=mydata["median_income"] / 1000,
)

keep_cols = [
    "state",
    "county",
    "county_code",
    "year",
    "population_20_64",
    "yaca",
    "crude_rate_20_64",
] + COVS

mydata = mydata[keep_cols].copy()
# Complete cases only for covariates, outcome, and population.
mydata = mydata.dropna(subset=COVS + ["crude_rate_20_64", "population_20_64"]).copy()

# Keep counties observed in both 2013 and 2014 ...
mydata = mydata.groupby("county_code", group_keys=False).filter(
    lambda g: g["year"].isin([2013, 2014]).sum() == 2
)
# ... and with a fully balanced panel of 11 years.
mydata = mydata.groupby("county_code", group_keys=False).filter(lambda g: len(g) == 11).reset_index(drop=True)

# Treat = 2014 expansion cohort. (The isna() guard is redundant here since
# NaN == 2014 is already False, but it is harmless.)
mydata["Treat"] = np.where((mydata["yaca"] == 2014) & (~mydata["yaca"].isna()), 1, 0)
mydata["Post"] = np.where(mydata["year"] >= 2014, 1, 0)

# Time-invariant weights: each county's 2013 adult population.
weights = (
    mydata[mydata["year"] == 2013][["county_code", "population_20_64"]]
    .drop_duplicates()
    .rename(columns={"population_20_64": "set_wt"})
)
mydata = mydata.merge(weights, on="county_code", how="left")

# Two-period sample used for the 2x2 comparisons.
short_data = mydata[mydata["year"].isin([2013, 2014])].copy()
short_data["Post"] = np.where(short_data["year"] == 2014, 1, 0)

# Unweighted group means by treatment status and year.
T_pre = short_data.loc[(short_data["Treat"] == 1) & (short_data["year"] == 2013), "crude_rate_20_64"].mean()
C_pre = short_data.loc[(short_data["Treat"] == 0) & (short_data["year"] == 2013), "crude_rate_20_64"].mean()
T_post = short_data.loc[(short_data["Treat"] == 1) & (short_data["year"] == 2014), "crude_rate_20_64"].mean()
C_post = short_data.loc[(short_data["Treat"] == 0) & (short_data["year"] == 2014), "crude_rate_20_64"].mean()

# Population-weighted group means (weights = 2013 adult population).
T_pre_w = weighted_mean(
    short_data.loc[(short_data["Treat"] == 1) & (short_data["year"] == 2013), "crude_rate_20_64"],
    short_data.loc[(short_data["Treat"] == 1) & (short_data["year"] == 2013), "set_wt"],
)
C_pre_w = weighted_mean(
    short_data.loc[(short_data["Treat"] == 0) & (short_data["year"] == 2013), "crude_rate_20_64"],
    short_data.loc[(short_data["Treat"] == 0) & (short_data["year"] == 2013), "set_wt"],
)
T_post_w = weighted_mean(
    short_data.loc[(short_data["Treat"] == 1) & (short_data["year"] == 2014), "crude_rate_20_64"],
    short_data.loc[(short_data["Treat"] == 1) & (short_data["year"] == 2014), "set_wt"],
)
C_post_w = weighted_mean(
    short_data.loc[(short_data["Treat"] == 0) & (short_data["year"] == 2014), "crude_rate_20_64"],
    short_data.loc[(short_data["Treat"] == 0) & (short_data["year"] == 2014), "set_wt"],
)

# Assemble the 2x2 table: rows = 2013 level, 2014 level, trend/DiD; first
# three value columns unweighted, last three weighted.
table2 = pd.DataFrame(
    [
        [
            "2013",
            f"{T_pre:.1f}",
            f"{C_pre:.1f}",
            f"{(T_pre - C_pre):.1f}",
            f"{T_pre_w:.1f}",
            f"{C_pre_w:.1f}",
            f"{(T_pre_w - C_pre_w):.1f}",
        ],
        [
            "2014",
            f"{T_post:.1f}",
            f"{C_post:.1f}",
            f"{(T_post - C_post):.1f}",
            f"{T_post_w:.1f}",
            f"{C_post_w:.1f}",
            f"{(T_post_w - C_post_w):.1f}",
        ],
        [
            "Trend/DiD",
            f"{(T_post - T_pre):.1f}",
            f"{(C_post - C_pre):.1f}",
            f"{((T_post - T_pre) - (C_post - C_pre)):.1f}",
            f"{(T_post_w - T_pre_w):.1f}",
            f"{(C_post_w - C_pre_w):.1f}",
            f"{((T_post_w - T_pre_w) - (C_post_w - C_pre_w)):.1f}",
        ],
    ],
    columns=["", "Expansion", "No Expansion", "Gap/DiD", "Expansion", "No Expansion", "Gap/DiD"],
)

table2.to_csv(TABLE_DIR / "table2_simple_did.csv", index=False)
Expansion No Expansion Gap/DiD Expansion No Expansion Gap/DiD
2013 419.2 474.0 -54.8 322.7 376.4 -53.7
2014 428.5 483.1 -54.7 326.5 382.7 -56.2
Trend/DiD 9.3 9.1 0.1 3.7 6.3 -2.6
Table 3

Table 3: Regression DiD

Regression-based 2x2 DiD estimates with and without fixed effects.

# Table 3: regression-based 2x2 DiD. Six specifications: (1) pooled OLS
# interaction, (2) two-way FE, (3) first-difference, and (4)-(6) their
# population-weighted analogues. Relies on globals: short_data, fit_ols,
# get_param, format_estimate_se, TABLE_DIR.
short_outcome = short_data.pivot(index="county_code", columns="year", values="crude_rate_20_64")
short_outcome = short_outcome.rename_axis(None, axis=1)
# Long difference of the outcome: 2014 minus 2013, per county.
short_outcome_diff = (short_outcome[2014] - short_outcome[2013]).rename("diff")

# One row per county with its state, weight, and treatment status.
county_base = (
    short_data.groupby("county_code")
    .agg(state=("state", "first"), set_wt=("set_wt", "mean"), Treat=("Treat", "mean"))
    .reset_index()
)
short_data2 = county_base.merge(short_outcome_diff, left_on="county_code", right_index=True, how="left")
short_data2["Post"] = 1

# Unweighted: interaction OLS, TWFE, and first-difference regressions.
mod1 = fit_ols("crude_rate_20_64 ~ Treat * Post", short_data, cluster="county_code")
mod2 = fit_ols("crude_rate_20_64 ~ Treat:Post + C(county_code) + C(year)", short_data, cluster="county_code")
mod3 = fit_ols("diff ~ Treat", short_data2, cluster="county_code")

# Same three specifications, weighted by 2013 adult population.
mod4 = fit_ols("crude_rate_20_64 ~ Treat * Post", short_data, weights=short_data["set_wt"], cluster="county_code")
mod5 = fit_ols(
    "crude_rate_20_64 ~ Treat:Post + C(county_code) + C(year)",
    short_data,
    weights=short_data["set_wt"],
    cluster="county_code",
)
mod6 = fit_ols("diff ~ Treat", short_data2, weights=short_data2["set_wt"], cluster="county_code")

models = [mod1, mod2, mod3, mod4, mod5, mod6]
model_labels = ["(1)", "(2)", "(3)", "(4)", "(5)", "(6)"]

# Extract each coefficient across models; fall back to "const" for the
# intercept and to "Treat" for the DiD term in first-difference models,
# where the Treat coefficient IS the DiD estimate.
rows = []
for coef, label in [
    ("Intercept", "Constant"),
    ("Treat", "Medicaid Expansion"),
    ("Post", "Post"),
    ("Treat:Post", "Medicaid Expansion x Post"),
]:
    row = [label]
    for m in models:
        val, se = get_param(m, coef)
        if coef == "Intercept" and val is None:
            val, se = get_param(m, "const")
        if coef == "Treat:Post" and val is None:
            val, se = get_param(m, "Treat")
        row.append(format_estimate_se(val, se, 1))
    rows.append(row)

# Footer rows flagging which specifications include fixed effects.
fe_rows = [
    ["County fixed effects", "No", "Yes", "No", "No", "Yes", "No"],
    ["Year fixed effects", "No", "Yes", "No", "No", "Yes", "No"],
]

table3 = pd.DataFrame(rows, columns=["Variable"] + model_labels)
table3 = pd.concat([table3, pd.DataFrame(fe_rows, columns=table3.columns)], ignore_index=True)

table3.to_csv(TABLE_DIR / "table3_regression_did.csv", index=False)
Variable (1) (2) (3) (4) (5) (6)
Constant 474.0
(4.3)
498.8
(1.8)
9.1
(2.6)
376.4
(7.6)
500.3
(0.8)
6.3
(1.1)
Medicaid Expansion -54.8
(6.3)
0.1
(3.7)
-53.7
(11.5)
-2.6
(1.5)
Post 9.1
(2.6)
6.3
(1.1)
Medicaid Expansion x Post 0.1
(3.7)
0.1
(5.3)
0.1
(3.7)
-2.6
(1.5)
-2.6
(2.1)
-2.6
(1.5)
County fixed effects No Yes No No Yes No
Year fixed effects No Yes No No Yes No
Table 4

Table 4: Covariate Balance

Balance checks for covariates in levels and 2014-2013 differences.

# Table 4: covariate balance. Top panel: 2013 levels; bottom panel:
# 2014-2013 changes. Each panel reports unweighted and population-weighted
# group means plus normalized differences. Relies on globals: short_data,
# COVS, weighted_mean, weighted_var, TABLE_DIR.
cov_2013 = short_data[short_data["year"] == 2013][["county_code", "Treat", "set_wt"] + COVS].copy()

unw_means = cov_2013.groupby("Treat")[COVS].mean()
unw_vars = cov_2013.groupby("Treat")[COVS].var()

# Unweighted levels: normalized difference = (mean1 - mean0) / sqrt((v1 + v0)/2).
unw_rows = []
for cov in COVS:
    mean0 = float(unw_means.loc[0, cov])
    mean1 = float(unw_means.loc[1, cov])
    var0 = float(unw_vars.loc[0, cov])
    var1 = float(unw_vars.loc[1, cov])
    norm_diff = (mean1 - mean0) / np.sqrt((var1 + var0) / 2)
    unw_rows.append([cov, mean0, mean1, norm_diff])

# Weighted levels, using the 2013-population weights.
wtd_rows = []
for cov in COVS:
    g0 = cov_2013[cov_2013["Treat"] == 0]
    g1 = cov_2013[cov_2013["Treat"] == 1]
    mean0 = weighted_mean(g0[cov], g0["set_wt"])
    mean1 = weighted_mean(g1[cov], g1["set_wt"])
    var0 = weighted_var(g0[cov], g0["set_wt"])
    var1 = weighted_var(g1[cov], g1["set_wt"])
    norm_diff = (mean1 - mean0) / np.sqrt((var1 + var0) / 2)
    wtd_rows.append([mean0, mean1, norm_diff])

# Combine unweighted and weighted level statistics side by side, row by row.
top_panel = pd.DataFrame(unw_rows, columns=["variable", "Non-Adopt", "Adopt", "Norm. Diff."])
for idx, (mean0, mean1, norm_diff) in enumerate(wtd_rows):
    top_panel.loc[idx, "Non-Adopt (wtd)"] = mean0
    top_panel.loc[idx, "Adopt (wtd)"] = mean1
    top_panel.loc[idx, "Norm. Diff. (wtd)"] = norm_diff

# Reshape to one column per (covariate, year) and take the 2014-2013 change.
cov_wide = short_data.set_index(["county_code", "year"])[COVS].unstack()
cov_wide_2013 = cov_wide.xs(2013, level=1, axis=1)
cov_wide_2014 = cov_wide.xs(2014, level=1, axis=1)

cov_diff = (cov_wide_2014 - cov_wide_2013).reset_index()
# Re-attach treatment status and weights from the 2013 cross-section.
cov_diff = cov_diff.merge(
    short_data[short_data["year"] == 2013][["county_code", "Treat", "set_wt"]],
    on="county_code",
    how="left",
)

unw_means_diff = cov_diff.groupby("Treat")[COVS].mean()
unw_vars_diff = cov_diff.groupby("Treat")[COVS].var()

# Unweighted changes: same normalized-difference construction as above.
unw_rows_diff = []
for cov in COVS:
    mean0 = float(unw_means_diff.loc[0, cov])
    mean1 = float(unw_means_diff.loc[1, cov])
    var0 = float(unw_vars_diff.loc[0, cov])
    var1 = float(unw_vars_diff.loc[1, cov])
    norm_diff = (mean1 - mean0) / np.sqrt((var1 + var0) / 2)
    unw_rows_diff.append([cov, mean0, mean1, norm_diff])

# Weighted changes.
wtd_rows_diff = []
for cov in COVS:
    g0 = cov_diff[cov_diff["Treat"] == 0]
    g1 = cov_diff[cov_diff["Treat"] == 1]
    mean0 = weighted_mean(g0[cov], g0["set_wt"])
    mean1 = weighted_mean(g1[cov], g1["set_wt"])
    var0 = weighted_var(g0[cov], g0["set_wt"])
    var1 = weighted_var(g1[cov], g1["set_wt"])
    norm_diff = (mean1 - mean0) / np.sqrt((var1 + var0) / 2)
    wtd_rows_diff.append([mean0, mean1, norm_diff])

bottom_panel = pd.DataFrame(unw_rows_diff, columns=["variable", "Non-Adopt", "Adopt", "Norm. Diff."])
for idx, (mean0, mean1, norm_diff) in enumerate(wtd_rows_diff):
    bottom_panel.loc[idx, "Non-Adopt (wtd)"] = mean0
    bottom_panel.loc[idx, "Adopt (wtd)"] = mean1
    bottom_panel.loc[idx, "Norm. Diff. (wtd)"] = norm_diff

# Stack levels panel on top of changes panel.
table4 = pd.concat([top_panel, bottom_panel], ignore_index=True)

# Human-readable covariate labels.
var_map = {
    "perc_female": "% Female",
    "perc_white": "% White",
    "perc_hispanic": "% Hispanic",
    "unemp_rate": "Unemployment Rate",
    "poverty_rate": "Poverty Rate",
    "median_income": "Median Income",
}

table4["variable"] = table4["variable"].map(var_map)
# Two-decimal display formatting for all numeric columns.
for col in table4.columns[1:]:
    table4[col] = table4[col].map(lambda x: f"{x:.2f}")

table4 = table4.rename(columns={"variable": "Variable"})

table4.to_csv(TABLE_DIR / "table4_covariate_balance.csv", index=False)
Variable Non-Adopt Adopt Norm. Diff. Non-Adopt (wtd) Adopt (wtd) Norm. Diff. (wtd)
% Female 49.43 49.33 -0.03 50.48 50.07 -0.24
% White 81.64 90.48 0.59 77.91 79.54 0.11
% Hispanic 9.64 8.23 -0.10 17.01 18.86 0.11
Unemployment Rate 7.61 8.01 0.16 7.00 8.01 0.50
Poverty Rate 19.28 16.53 -0.42 17.24 15.29 -0.37
Median Income 43.04 47.97 0.43 49.31 57.86 0.69
% Female -0.02 -0.02 -0.00 0.02 0.01 -0.09
% White -0.21 -0.21 0.01 -0.32 -0.33 -0.04
% Hispanic 0.20 0.21 0.04 0.25 0.33 0.30
Unemployment Rate -1.16 -1.30 -0.21 -1.08 -1.36 -0.55
Poverty Rate -0.55 -0.28 0.14 -0.41 -0.35 0.05
Median Income 0.98 1.11 0.06 1.10 1.74 0.33
Table 5

Table 5: Regression 2x2 DiD with Covariates

Long-difference regressions with covariates in levels or changes.

# Table 5: long-difference DiD regressions with covariates, either in 2013
# levels or in 2014-2013 changes; unweighted and weighted. Relies on globals:
# cov_2013, cov_diff, short_outcome_diff, COVS, fit_ols, get_param,
# format_estimate_se, TABLE_DIR.
reg_data_2013 = cov_2013.merge(short_outcome_diff.reset_index(), on="county_code", how="left")
reg_data_2013 = reg_data_2013.rename(columns={"diff": "long_y"})

reg_data_change = cov_diff.merge(short_outcome_diff.reset_index(), on="county_code", how="left")
reg_data_change = reg_data_change.rename(columns={"diff": "long_y"})

# Unweighted: no covariates, 2013-level covariates, covariate changes.
mod5_1 = fit_ols("long_y ~ Treat", reg_data_2013, cluster="county_code")
mod5_2 = fit_ols("long_y ~ Treat + " + " + ".join(COVS), reg_data_2013, cluster="county_code")
mod5_3 = fit_ols("long_y ~ Treat + " + " + ".join(COVS), reg_data_change, cluster="county_code")

# Same three specifications, weighted by 2013 adult population.
mod5_4 = fit_ols("long_y ~ Treat", reg_data_2013, weights=reg_data_2013["set_wt"], cluster="county_code")
mod5_5 = fit_ols(
    "long_y ~ Treat + " + " + ".join(COVS),
    reg_data_2013,
    weights=reg_data_2013["set_wt"],
    cluster="county_code",
)
mod5_6 = fit_ols(
    "long_y ~ Treat + " + " + ".join(COVS),
    reg_data_change,
    weights=reg_data_change["set_wt"],
    cluster="county_code",
)

models5 = [mod5_1, mod5_2, mod5_3, mod5_4, mod5_5, mod5_6]
labels5 = ["No Covs", "X 2013", "X Change", "No Covs", "X 2013", "X Change"]

# Single table row: the Treat coefficient (and SE) from each model.
row5 = ["Medicaid Expansion"]
for m in models5:
    val, se = get_param(m, "Treat")
    row5.append(format_estimate_se(val, se, 2))

table5 = pd.DataFrame([row5], columns=["Variable"] + labels5)

table5.to_csv(TABLE_DIR / "table5_did_covariates.csv", index=False)
Variable No Covs X 2013 X Change No Covs X 2013 X Change
Medicaid Expansion 0.12
(3.75)
-2.35
(4.29)
-0.49
(3.83)
-2.56
(1.49)
-2.56
(1.78)
-1.37
(1.62)
Table 6

Table 6: Outcome Regression and Propensity Score Models

Outcome regression and propensity score models for doubly robust DiD.

# Table 6: components of the doubly-robust DiD — outcome regression fit on
# untreated counties (Treat == 0) and a logit propensity-score model on the
# 2013 cross-section, each unweighted and weighted. Relies on globals:
# reg_data_2013, short_data, COVS, fit_ols, fit_logit, get_param,
# format_estimate_se, TABLE_DIR.
outcome_unw = fit_ols("long_y ~ " + " + ".join(COVS), reg_data_2013[reg_data_2013["Treat"] == 0], cluster="county_code")
ps_unw = fit_logit("Treat ~ " + " + ".join(COVS), short_data[short_data["year"] == 2013], cluster="county_code")

# Population-weighted analogues of the two models above.
outcome_w = fit_ols(
    "long_y ~ " + " + ".join(COVS),
    reg_data_2013[reg_data_2013["Treat"] == 0],
    weights=reg_data_2013[reg_data_2013["Treat"] == 0]["set_wt"],
    cluster="county_code",
)
ps_w = fit_logit(
    "Treat ~ " + " + ".join(COVS),
    short_data[short_data["year"] == 2013],
    weights=short_data[short_data["year"] == 2013]["set_wt"],
    cluster="county_code",
)

# Build one row per coefficient across the four models; fall back to "const"
# when the intercept is not named "Intercept" in the fitted model.
rows6 = []
for name, label in [
    ("Intercept", "Constant"),
    ("perc_female", "% Female"),
    ("perc_white", "% White"),
    ("perc_hispanic", "% Hispanic"),
    ("unemp_rate", "Unemployment Rate"),
    ("poverty_rate", "Poverty Rate"),
    ("median_income", "Median Income"),
]:
    val1, se1 = get_param(outcome_unw, name)
    val2, se2 = get_param(ps_unw, name)
    val3, se3 = get_param(outcome_w, name)
    val4, se4 = get_param(ps_w, name)
    if name == "Intercept" and val1 is None:
        val1, se1 = get_param(outcome_unw, "const")
    if name == "Intercept" and val2 is None:
        val2, se2 = get_param(ps_unw, "const")
    if name == "Intercept" and val3 is None:
        val3, se3 = get_param(outcome_w, "const")
    if name == "Intercept" and val4 is None:
        val4, se4 = get_param(ps_w, "const")
    rows6.append(
        [
            label,
            format_estimate_se(val1, se1, 2),
            format_estimate_se(val2, se2, 2),
            format_estimate_se(val3, se3, 2),
            format_estimate_se(val4, se4, 2),
        ]
    )

table6 = pd.DataFrame(
    rows6,
    columns=[
        "Variable",
        "Unweighted Regression",
        "Unweighted Propensity",
        "Weighted Regression",
        "Weighted Propensity",
    ],
)

table6.to_csv(TABLE_DIR / "table6_pscore.csv", index=False)
Variable Unweighted Regression Unweighted Propensity Weighted Regression Weighted Propensity
Constant -20.91
(74.56)
-10.00
(1.34)
-4.62
(41.88)
-8.17
(4.39)
% Female 0.04
(0.86)
-0.04
(0.02)
-0.09
(0.65)
-0.19
(0.06)
% White 0.15
(0.25)
0.06
(0.01)
0.20
(0.10)
0.04
(0.01)
% Hispanic -0.08
(0.19)
-0.02
(0.00)
-0.08
(0.06)
-0.02
(0.01)
Unemployment Rate 1.14
(1.99)
0.32
(0.03)
0.88
(0.95)
0.68
(0.08)
Poverty Rate 0.21
(1.13)
0.03
(0.02)
-0.13
(0.46)
0.11
(0.05)
Median Income 0.09
(0.46)
0.08
(0.01)
-0.05
(0.19)
0.15
(0.02)
Table 7

Table 7: Callaway and Sant'Anna (2021) DiD

Outcome regression, IPW, and doubly robust estimates from csdid.

# Table 7: Callaway & Sant'Anna (2021) DiD via csdid, comparing outcome
# regression ("reg"), IPW, and doubly-robust ("dr") estimators, unweighted
# then population-weighted. Relies on globals: short_data, COVS, ATTgt,
# BITERS, BOOTSTRAP, TABLE_DIR.
data_cs = short_data.copy()
# Cohort variable for csdid: 2014 for the expansion cohort, 0 for never treated.
# (The isna() guard is redundant since NaN == 2014 is already False.)
data_cs["treat_year"] = np.where((data_cs["yaca"] == 2014) & (~data_cs["yaca"].isna()), 2014, 0)
data_cs["county_code"] = data_cs["county_code"].astype(int)

cs_methods = ["reg", "ipw", "dr"]

# Unweighted estimates: one ATT per estimation method.
cs_results_unw = []
for method in cs_methods:
    # csdid: ATTgt estimates ATT(g,t) by cohort and year.
    att = ATTgt(
        yname="crude_rate_20_64",
        tname="year",
        idname="county_code",
        gname="treat_year",
        data=data_cs,
        control_group="nevertreated",
        xformla="~ " + " + ".join(COVS),
        panel=True,
        allow_unbalanced_panel=True,
        weights_name=None,
        cband=True,
        biters=BITERS,
    )
    # csdid: fit ATTgt using the specified estimator.
    att.fit(est_method=method, base_period="universal", bstrap=BOOTSTRAP)
    # csdid: aggte aggregates ATT(g,t) into cohort-level effects.
    att.aggte(typec="group", na_rm=True, biters=BITERS, cband=True)
    agg = att.atte
    agg_df = pd.DataFrame(
        {
            "group": agg["egt"],
            "att": agg["att_egt"],
            "se": np.asarray(agg["se_egt"]).flatten(),
        }
    )
    # Keep only the 2014 cohort's aggregated ATT and SE.
    row = agg_df.loc[agg_df["group"] == 2014].iloc[0]
    cs_results_unw.append((float(row["att"]), float(row["se"])))

# Population-weighted estimates: identical loop but with weights_name="set_wt".
cs_results_w = []
for method in cs_methods:
    # csdid: ATTgt estimates ATT(g,t) with population weights.
    att = ATTgt(
        yname="crude_rate_20_64",
        tname="year",
        idname="county_code",
        gname="treat_year",
        data=data_cs,
        control_group="nevertreated",
        xformla="~ " + " + ".join(COVS),
        panel=True,
        allow_unbalanced_panel=True,
        weights_name="set_wt",
        cband=True,
        biters=BITERS,
    )
    # csdid: fit ATTgt using the specified estimator.
    att.fit(est_method=method, base_period="universal", bstrap=BOOTSTRAP)
    # csdid: aggte aggregates ATT(g,t) into cohort-level effects.
    att.aggte(typec="group", na_rm=True, biters=BITERS, cband=True)
    agg = att.atte
    agg_df = pd.DataFrame(
        {
            "group": agg["egt"],
            "att": agg["att_egt"],
            "se": np.asarray(agg["se_egt"]).flatten(),
        }
    )
    row = agg_df.loc[agg_df["group"] == 2014].iloc[0]
    cs_results_w.append((float(row["att"]), float(row["se"])))

# Two display rows: point estimates, then standard errors in parentheses.
row_att = ["Medicaid Expansion"]
row_se = [""]

for att_val, se_val in cs_results_unw:
    row_att.append(f"{att_val:.2f}")
    row_se.append(f"({se_val:.2f})")
for att_val, se_val in cs_results_w:
    row_att.append(f"{att_val:.2f}")
    row_se.append(f"({se_val:.2f})")

table7 = pd.DataFrame(
    [row_att, row_se],
    columns=["", "Regression", "IPW", "Doubly Robust", "Regression", "IPW", "Doubly Robust"],
)

table7.to_csv(TABLE_DIR / "table7_csdid.csv", index=False)
Regression IPW Doubly Robust Regression IPW Doubly Robust
Medicaid Expansion -1.62 -0.86 -1.23 -3.46 -3.84 -3.76
(4.69) (4.68) (4.93) (2.42) (3.32) (3.20)
Figure 1

Figure 1: Distribution of Propensity Scores

Histogram of propensity scores for expansion vs non-expansion counties.

# Figure 1: overlap check — histograms of fitted propensity scores by
# treatment status, unweighted and weighted panels. Relies on globals:
# short_data, ps_unw, ps_w, plt, save_fig.
plot_base = short_data[short_data["year"] == 2013].copy()
# Predicted P(Treat=1 | X) from the logit models fit for Table 6.
plot_base["pscore_unw"] = ps_unw.predict(plot_base)
plot_base["pscore_w"] = ps_w.predict(plot_base)

fig1, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
colors = {1: "#D7263D", 0: "#1B998B"}

for ax, col, title in zip(axes, ["pscore_unw", "pscore_w"], ["Unweighted", "Weighted"]):
    for treat_val, label in [(1, "Expansion Counties"), (0, "Non-Expansion Counties")]:
        subset = plot_base[plot_base["Treat"] == treat_val]
        # The weighted panel also weights the histogram by population.
        weights_hist = subset["set_wt"] if col == "pscore_w" else None
        ax.hist(
            subset[col],
            bins=30,
            density=True,
            weights=weights_hist,
            histtype="step",
            linewidth=1.4,
            color=colors[treat_val],
            label=label,
        )
    ax.set_title(title)
    ax.set_xlabel("Propensity Score")
    ax.grid(False)

axes[0].set_ylabel("Density")
# Shared legend centered below the two panels.
axes[0].legend(loc="lower center", bbox_to_anchor=(1.05, -0.35), ncol=2, frameon=False)
fig1.tight_layout()

fig1_path = save_fig(fig1, "figure1_pscore.png")
Figure 1
Figure 2

Figure 2: County Mortality Trends by Expansion Decision

Weighted mortality trends for 2014 expansion vs non-expansion counties.

# Figure 2: population-weighted mortality trends by treatment status over the
# full panel. Relies on globals: mydata, weighted_mean, plt, save_fig.
trend_rows = []
for (treat, year), group in mydata.groupby(["Treat", "year"]):
    trend_rows.append(
        {
            "group": "Expansion Counties" if treat == 1 else "Non-Expansion Counties",
            "year": int(year),
            # Weighted mean mortality per group-year cell.
            "mortality": weighted_mean(group["crude_rate_20_64"], group["set_wt"]),
        }
    )

trend = pd.DataFrame(trend_rows)

fig2, ax2 = plt.subplots(figsize=(10, 4.5))
colors2 = {"Expansion Counties": "#D7263D", "Non-Expansion Counties": "#1B998B"}
for group, gdf in trend.groupby("group"):
    ax2.plot(gdf["year"], gdf["mortality"], marker="o", color=colors2[group], label=group)
# Vertical dashed line at the 2014 expansion year.
ax2.axvline(2014, linestyle="--", color="#2E2E2E", linewidth=1)
ax2.set_xlabel("")
ax2.set_ylabel("Mortality (20-64)\nPer 100,000")
ax2.set_xticks(range(2009, 2020))
ax2.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35), ncol=2, frameon=False)
ax2.grid(False)
fig2.tight_layout()

fig2_path = save_fig(fig2, "figure2_trends.png")
Figure 2
Figure 3

Figure 3: 2xT Event Study

Event study without covariates using csdid (regression adjustment).

# Figure 3: event-study plot (no covariates) using csdid with regression
# adjustment and population weights. Relies on globals: mydata, ATTgt,
# BITERS, BOOTSTRAP, norm, plt, save_fig.
event_data = mydata.copy()
# Cohort variable: 2014 cohort vs never treated (0). The isna() guard is
# redundant since NaN == 2014 is already False.
event_data["treat_year"] = np.where((event_data["yaca"] == 2014) & (~event_data["yaca"].isna()), 2014, 0)
event_data["county_code"] = event_data["county_code"].astype(int)

# csdid: ATTgt estimates ATT(g,t) for event-study dynamics.
event_att = ATTgt(
    yname="crude_rate_20_64",
    tname="year",
    idname="county_code",
    gname="treat_year",
    data=event_data,
    control_group="nevertreated",
    xformla=None,
    panel=True,
    allow_unbalanced_panel=True,
    weights_name="set_wt",
    cband=True,
    biters=BITERS,
)

# csdid: fit ATTgt with outcome regression.
event_att.fit(est_method="reg", base_period="universal", bstrap=BOOTSTRAP)
# Aggregation can fail in the bootstrap step; fall back first to pointwise
# intervals (cband=False), then to analytic SEs (bstrap=False).
try:
    # csdid: aggte aggregates ATT(g,t) into event-time effects.
    event_att.aggte(typec="dynamic", biters=BITERS, cband=True)
    agg_event = event_att.atte
except ValueError:
    print("Warning: aggte bootstrap failed; retrying with cband=False.")
    try:
        event_att.aggte(typec="dynamic", biters=BITERS, cband=False)
        agg_event = event_att.atte
    except ValueError:
        print("Warning: aggte bootstrap failed; retrying with bstrap=False.")
        event_att.aggte(typec="dynamic", biters=BITERS, cband=False, bstrap=False)
        agg_event = event_att.atte

# Critical value for the confidence band; default to the normal 97.5% quantile
# when csdid supplies none.
crit_val = agg_event["crit_val_egt"] if agg_event["crit_val_egt"] is not None else norm.ppf(0.975)

plot_df = pd.DataFrame(
    {
        "event_time": agg_event["egt"],
        "estimate": agg_event["att_egt"],
        "se": np.asarray(agg_event["se_egt"]).flatten(),
        "crit": crit_val,
    }
)
plot_df["conf_low"] = plot_df["estimate"] - plot_df["crit"] * plot_df["se"]
plot_df["conf_high"] = plot_df["estimate"] + plot_df["crit"] * plot_df["se"]

# Overall post-period effect averaged over event times 0..5, with the same
# bootstrap fallback ladder as above.
try:
    # csdid: aggte with min_e/max_e for post-period aggregation.
    event_att.aggte(typec="dynamic", min_e=0, max_e=5, biters=BITERS, cband=True)
    agg_post = event_att.atte
except ValueError:
    print("Warning: aggte bootstrap failed; retrying with cband=False.")
    try:
        event_att.aggte(typec="dynamic", min_e=0, max_e=5, biters=BITERS, cband=False)
        agg_post = event_att.atte
    except ValueError:
        print("Warning: aggte bootstrap failed; retrying with bstrap=False.")
        event_att.aggte(typec="dynamic", min_e=0, max_e=5, biters=BITERS, cband=False, bstrap=False)
        agg_post = event_att.atte
att_post = float(agg_post["overall_att"])
se_post = float(agg_post["overall_se"])

# Plot: CI whiskers, point estimates, zero line, and reference line at e = -1.
fig3, ax3 = plt.subplots(figsize=(10, 4.5))
ax3.vlines(plot_df["event_time"], plot_df["conf_low"], plot_df["conf_high"], color="#8A1C1C")
ax3.scatter(plot_df["event_time"], plot_df["estimate"], color="#111111", s=30)
ax3.axhline(0, linestyle="--", color="#2E2E2E", linewidth=1)
ax3.axvline(-1, linestyle="--", color="#2E2E2E", linewidth=1)
ax3.set_xlabel("Event Time")
ax3.set_ylabel("Treatment Effect\nMortality Per 100,000")
ax3.set_xticks(range(-5, 6))
# Annotation box with the post-period average effect.
note_text = f"Estimate e in {{0,5}} = {att_post:.2f}\nStd. Error = {se_post:.2f}"
ax3.text(
    0.98,
    0.95,
    note_text,
    transform=ax3.transAxes,
    ha="right",
    va="top",
    fontsize=10,
    bbox=dict(facecolor="white", alpha=0.8, edgecolor="none", pad=3.0),
)
ax3.grid(False)
fig3.tight_layout()

fig3_path = save_fig(fig3, "figure3_event_study.png")
Figure 3
Figure 4

Figure 4: 2xT Event Study with Covariates

Event studies with covariates using RA, IPW, and DR estimators.

# Figure 4: event studies with covariates — one panel per estimator
# (regression adjustment, IPW, doubly robust). Relies on globals: event_data,
# COVS, ATTgt, BITERS, BOOTSTRAP, norm, plt, save_fig.
methods = [("reg", "Regression"), ("ipw", "IPW"), ("dr", "Doubly Robust")]

fig4, axes4 = plt.subplots(1, 3, figsize=(12, 4.5), sharey=True)

for ax4, (method, label) in zip(axes4, methods):
    # csdid: ATTgt estimates ATT(g,t) with covariates.
    cov_att = ATTgt(
        yname="crude_rate_20_64",
        tname="year",
        idname="county_code",
        gname="treat_year",
        data=event_data,
        control_group="nevertreated",
        xformla="~ " + " + ".join(COVS),
        panel=True,
        allow_unbalanced_panel=True,
        weights_name="set_wt",
        cband=True,
        biters=BITERS,
    )
    # csdid: fit ATTgt with covariates and selected estimator.
    cov_att.fit(est_method=method, base_period="universal", bstrap=BOOTSTRAP)
    # Same bootstrap fallback ladder as Figure 3: cband → pointwise → analytic.
    try:
        # csdid: aggte aggregates to event-time effects.
        cov_att.aggte(typec="dynamic", biters=BITERS, cband=True)
        agg_cov = cov_att.atte
    except ValueError:
        print("Warning: aggte bootstrap failed; retrying with cband=False.")
        try:
            cov_att.aggte(typec="dynamic", biters=BITERS, cband=False)
            agg_cov = cov_att.atte
        except ValueError:
            print("Warning: aggte bootstrap failed; retrying with bstrap=False.")
            cov_att.aggte(typec="dynamic", biters=BITERS, cband=False, bstrap=False)
            agg_cov = cov_att.atte

    # Default to the normal 97.5% quantile if csdid supplies no critical value.
    crit_val = agg_cov["crit_val_egt"] if agg_cov["crit_val_egt"] is not None else norm.ppf(0.975)

    plot_cov = pd.DataFrame(
        {
            "event_time": agg_cov["egt"],
            "estimate": agg_cov["att_egt"],
            "se": np.asarray(agg_cov["se_egt"]).flatten(),
            "crit": crit_val,
        }
    )
    plot_cov["conf_low"] = plot_cov["estimate"] - plot_cov["crit"] * plot_cov["se"]
    plot_cov["conf_high"] = plot_cov["estimate"] + plot_cov["crit"] * plot_cov["se"]

    # Panel: CI whiskers, points, zero line, and reference line at e = -1.
    ax4.vlines(plot_cov["event_time"], plot_cov["conf_low"], plot_cov["conf_high"], color="#8A1C1C")
    ax4.scatter(plot_cov["event_time"], plot_cov["estimate"], color="#111111", s=24)
    ax4.axhline(0, linestyle="--", color="#2E2E2E", linewidth=1)
    ax4.axvline(-1, linestyle="--", color="#2E2E2E", linewidth=1)
    ax4.set_title(label)
    ax4.set_xticks(range(-5, 6))
    ax4.grid(False)

axes4[0].set_ylabel("Treatment Effect\nMortality Per 100,000")
fig4.tight_layout()

fig4_path = save_fig(fig4, "figure4_event_study_covariates.png")
Figure 4
Figure 5

Figure 5: Mortality Trends by Expansion Timing

Weighted mortality trends by treatment timing group (staggered adoption).

# Build the staggered-adoption analysis sample: drop early-expansion states,
# derive covariates, balance the panel over the sample years, and attach a
# time-invariant 2013 population weight per county.
staggered = df.copy()
early_states = ["DC", "DE", "MA", "NY", "VT"]
staggered = staggered.loc[~staggered["state"].isin(early_states)].copy()

# Demographic shares in percent, unemployment in percent, income in $1,000s.
base_pop = staggered["population_20_64"]
staggered = staggered.assign(
    perc_white=staggered["population_20_64_white"].div(base_pop).mul(100),
    perc_hispanic=staggered["population_20_64_hispanic"].div(base_pop).mul(100),
    perc_female=staggered["population_20_64_female"].div(base_pop).mul(100),
    unemp_rate=staggered["unemp_rate"].mul(100),
    median_income=staggered["median_income"].div(1000),
)

# Keep only identifiers, outcome, treatment timing, and the covariate set.
keep_cols = [
    "state",
    "county",
    "county_code",
    "year",
    "population_20_64",
    "yaca",
    "crude_rate_20_64",
] + COVS
staggered = staggered[keep_cols].copy()

# Require complete outcome/covariate data, observations in both 2013 and 2014,
# and a fully balanced 11-year panel per county.
staggered = staggered.dropna(subset=COVS + ["crude_rate_20_64", "population_20_64"]).copy()
staggered = staggered.groupby("county_code", group_keys=False).filter(
    lambda grp: grp["year"].isin([2013, 2014]).sum() == 2
)
staggered = (
    staggered.groupby("county_code", group_keys=False)
    .filter(lambda grp: len(grp) == 11)
    .reset_index(drop=True)
)

# Treatment flags: counties expanding Medicaid by 2019 are treated;
# treat_year records the expansion year (0 = never treated in-sample).
expanded = staggered["yaca"].notna() & (staggered["yaca"] <= 2019)
staggered["Treat"] = expanded.astype(int)
staggered["treat_year"] = staggered["yaca"].where(expanded, 0)

# Time-invariant weight: county population (ages 20-64) measured in 2013.
weights_2013 = (
    staggered.loc[staggered["year"] == 2013, ["county_code", "population_20_64"]]
    .drop_duplicates()
    .rename(columns={"population_20_64": "set_wt"})
)
staggered = staggered.merge(weights_2013, on="county_code", how="left")

# Human-readable group labels for the trend plot.
cohort_label = staggered["treat_year"].astype(int).astype(str)
staggered["treat_label"] = cohort_label.mask(
    staggered["treat_year"] == 0, "Non-Expansion Counties"
)

# One weighted-average mortality value per (timing group, calendar year) cell.
trend_gxt = pd.DataFrame(
    [
        {
            "group": label,
            "year": int(yr),
            "mortality": weighted_mean(cell["crude_rate_20_64"], cell["set_wt"]),
        }
        for (label, yr), cell in staggered.groupby(["treat_label", "year"])
    ]
)

# Color per expansion cohort; unmapped groups fall back to gray below.
colors5 = {
    "Non-Expansion Counties": "#6F5E76",
    "2014": "#D7263D",
    "2015": "#F4A261",
    "2016": "#2F3D70",
    "2019": "#1B998B",
}

fig5, ax5 = plt.subplots(figsize=(10, 4.5))
# One line per timing group: weighted mortality over calendar years.
for group, gdf in trend_gxt.groupby("group"):
    ax5.plot(gdf["year"], gdf["mortality"], marker="o", color=colors5.get(group, "#444444"), label=group)
ax5.set_xlabel("")
ax5.set_ylabel("Mortality (20-64)\nPer 100,000")
ax5.set_xticks(range(2009, 2020))
# Legend placed below the axes to keep the plot area uncluttered.
ax5.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35), ncol=3, frameon=False)
ax5.grid(False)
fig5.tight_layout()

fig5_path = save_fig(fig5, "figure5_staggered_trends.png")
Figure 5
Figure 6

Figure 6: ATT(g,t) by Calendar Time

Cohort-specific treatment effects over calendar time.

# csdid expects an integer panel-unit id; cast county codes accordingly.
staggered["county_code"] = staggered["county_code"].astype(int)

# csdid: ATTgt estimates ATT(g,t) using not-yet-treated controls.
att_no_x = ATTgt(
    yname="crude_rate_20_64",  # outcome: crude mortality rate, ages 20-64
    tname="year",  # time variable
    idname="county_code",  # panel unit id
    gname="treat_year",  # first-treatment year (0 = never treated, set above)
    data=staggered,
    control_group="notyettreated",  # controls: units not yet treated at t
    xformla=None,  # no covariates in this specification
    panel=True,
    allow_unbalanced_panel=True,
    weights_name="set_wt",  # 2013 county population weights
    cband=True,  # request uniform confidence bands
    biters=BITERS,  # bootstrap iterations
)

# csdid: fit ATTgt with doubly robust estimator.
att_no_x.fit(est_method="dr", base_period="universal", bstrap=BOOTSTRAP)

# Tidy frame of group-time estimates for the calendar- and event-time plots.
att_df = pd.DataFrame(
    {
        "group": att_no_x.results["group"],
        "time": att_no_x.results["year"],
        "att": att_no_x.results["att"],
        "se": att_no_x.results["se"],
    }
)

# 2x2 grid: one panel per expansion cohort, shared axes for comparability.
fig6, axes6 = plt.subplots(2, 2, figsize=(10, 7), sharey=True, sharex=True)
axes6 = axes6.ravel()
colors6 = ["#6F5E76", "#D7263D", "#F4A261", "#2F3D70"]

for ax6, group, color in zip(axes6, sorted(att_df["group"].unique()), colors6):
    gdf = att_df[att_df["group"] == group]
    ax6.plot(gdf["time"], gdf["att"], marker="o", color=color)
    # Pointwise 95% normal confidence intervals around each ATT(g,t).
    ax6.vlines(gdf["time"], gdf["att"] - 1.96 * gdf["se"], gdf["att"] + 1.96 * gdf["se"], color=color)
    # Dashed line at g-1, the cohort's last pre-treatment year.
    ax6.axvline(group - 1, linestyle="--", color="#2E2E2E", linewidth=1)
    ax6.set_title(str(int(group)))
    ax6.grid(False)

axes6[0].set_ylabel("Treatment Effect\nMortality Per 100,000")
axes6[2].set_ylabel("Treatment Effect\nMortality Per 100,000")
# Only the bottom row carries x-axis labels.
for ax6 in axes6[2:]:
    ax6.set_xlabel("Year")

fig6.tight_layout()

fig6_path = save_fig(fig6, "figure6_attgt_calendar.png")
Figure 6
Figure 7

Figure 7: ATT(g,t) in Event Time

ATT(g,t) stacked in relative/event time by cohort.

# Event time: calendar year minus cohort expansion year (e = t - g).
att_df["event_time"] = att_df["time"] - att_df["group"]

fig7, ax7 = plt.subplots(figsize=(10, 4.5))
# Int keys still match float group values via numeric equality in dict lookup.
colors7 = {2014: "#6F5E76", 2015: "#D7263D", 2016: "#F4A261", 2019: "#2F3D70"}

# One line per expansion cohort, aligned on event time.
for group, gdf in att_df.groupby("group"):
    ax7.plot(
        gdf["event_time"],
        gdf["att"],
        marker="o",
        color=colors7.get(group, "#444444"),
        label=str(int(group)),
    )

# Dashed line at e = -1, the reference period of the event study.
ax7.axvline(-1, linestyle="--", color="#2E2E2E", linewidth=1)
ax7.set_xlabel("")
ax7.set_ylabel("Treatment Effect\nMortality Per 100,000")
ax7.set_xticks(range(-10, 6))
ax7.legend(loc="lower center", bbox_to_anchor=(0.5, -0.35), ncol=4, frameon=False)
ax7.grid(False)
fig7.tight_layout()

fig7_path = save_fig(fig7, "figure7_attgt_event.png")
Figure 7
Figure 8

Figure 8: GxT Event Study Without Covariates

Event study using not-yet-treated controls, no covariates.

def _aggte_dynamic_fallback(att_obj, min_e: int, max_e: int):
    """Aggregate ATT(g,t) into dynamic (event-time) effects with graceful fallback.

    Tries the full uniform confidence band first; if the bootstrap raises a
    ValueError, retries with ``cband=False`` and finally with ``bstrap=False``.
    Returns the aggregation result stored on ``att_obj.atte``.
    """
    try:
        # csdid: aggte aggregates ATT(g,t) into event-time effects.
        att_obj.aggte(typec="dynamic", min_e=min_e, max_e=max_e, biters=BITERS, cband=True)
    except ValueError:
        print("Warning: aggte bootstrap failed; retrying with cband=False.")
        try:
            att_obj.aggte(typec="dynamic", min_e=min_e, max_e=max_e, biters=BITERS, cband=False)
        except ValueError:
            print("Warning: aggte bootstrap failed; retrying with bstrap=False.")
            att_obj.aggte(typec="dynamic", min_e=min_e, max_e=max_e, biters=BITERS, cband=False, bstrap=False)
    return att_obj.atte


# Event-time aggregation over e in [-5, 5] (no covariates).
agg_no_x = _aggte_dynamic_fallback(att_no_x, min_e=-5, max_e=5)
# NOTE(review): crit_val_egt is assumed to be None when no uniform band was
# computed (confirm it is not NaN); fall back to the pointwise 95% normal value.
crit_val = agg_no_x["crit_val_egt"] if agg_no_x["crit_val_egt"] is not None else norm.ppf(0.975)

# Tidy event-study frame with symmetric confidence bounds.
plot_no_x = pd.DataFrame(
    {
        "event_time": agg_no_x["egt"],
        "estimate": agg_no_x["att_egt"],
        "se": np.asarray(agg_no_x["se_egt"]).flatten(),
        "crit": crit_val,
    }
)
plot_no_x["conf_low"] = plot_no_x["estimate"] - plot_no_x["crit"] * plot_no_x["se"]
plot_no_x["conf_high"] = plot_no_x["estimate"] + plot_no_x["crit"] * plot_no_x["se"]

# Overall post-period ATT over e in [0, 5], annotated on the figure below.
agg_post = _aggte_dynamic_fallback(att_no_x, min_e=0, max_e=5)
att_post = float(agg_post["overall_att"])
se_post = float(agg_post["overall_se"])

fig8, ax8 = plt.subplots(figsize=(10, 4.5))
# Confidence bands (uniform when available, else pointwise) as vertical bars.
ax8.vlines(plot_no_x["event_time"], plot_no_x["conf_low"], plot_no_x["conf_high"], color="#8A1C1C")
ax8.scatter(plot_no_x["event_time"], plot_no_x["estimate"], color="#111111", s=30)
ax8.axhline(0, linestyle="--", color="#2E2E2E", linewidth=1)
# Dashed line at e = -1, the reference period of the event study.
ax8.axvline(-1, linestyle="--", color="#2E2E2E", linewidth=1)
ax8.set_xlabel("Event Time")
ax8.set_ylabel("Treatment Effect\nMortality Per 100,000")
ax8.set_xticks(range(-5, 6))
# Annotate the overall post-period ATT (e in [0, 5]) in the top-right corner.
note_text = f"Estimate e in {{0,5}} = {att_post:.2f}\nStd. Error = {se_post:.2f}"
ax8.text(
    0.98,
    0.95,
    note_text,
    transform=ax8.transAxes,  # axes-fraction coordinates
    ha="right",
    va="top",
    fontsize=10,
    bbox=dict(facecolor="white", alpha=0.8, edgecolor="none", pad=3.0),
)
ax8.grid(False)
fig8.tight_layout()

fig8_path = save_fig(fig8, "figure8_event_no_covs.png")
Figure 8
Figure 9

Figure 9: GxT Event Study With Covariates

Event study with covariates using the doubly robust estimator.

# csdid: ATTgt estimates ATT(g,t) with covariates and not-yet-treated controls.
att_with_x = ATTgt(
    yname="crude_rate_20_64",  # outcome: crude mortality rate, ages 20-64
    tname="year",  # time variable
    idname="county_code",  # panel unit id
    gname="treat_year",  # first-treatment year (0 = never treated, set above)
    data=staggered,
    control_group="notyettreated",  # controls: units not yet treated at t
    xformla="~ " + " + ".join(COVS),  # one-sided R-style formula over covariates
    panel=True,
    allow_unbalanced_panel=True,
    weights_name="set_wt",  # 2013 county population weights
    cband=True,  # request uniform confidence bands
    biters=BITERS,  # bootstrap iterations
)

# csdid: fit ATTgt with doubly robust estimator and covariates.
att_with_x.fit(est_method="dr", base_period="universal", bstrap=BOOTSTRAP)

def _aggte_dynamic_fallback(att_obj, min_e: int, max_e: int):
    """Aggregate ATT(g,t) into dynamic (event-time) effects with graceful fallback.

    Tries the full uniform confidence band first; if the bootstrap raises a
    ValueError, retries with ``cband=False`` and finally with ``bstrap=False``.
    Returns the aggregation result stored on ``att_obj.atte``.
    """
    try:
        # csdid: aggte aggregates ATT(g,t) into event-time effects.
        att_obj.aggte(typec="dynamic", min_e=min_e, max_e=max_e, biters=BITERS, cband=True)
    except ValueError:
        print("Warning: aggte bootstrap failed; retrying with cband=False.")
        try:
            att_obj.aggte(typec="dynamic", min_e=min_e, max_e=max_e, biters=BITERS, cband=False)
        except ValueError:
            print("Warning: aggte bootstrap failed; retrying with bstrap=False.")
            att_obj.aggte(typec="dynamic", min_e=min_e, max_e=max_e, biters=BITERS, cband=False, bstrap=False)
    return att_obj.atte


# Event-time aggregation over e in [-5, 5] (doubly robust, with covariates).
agg_with_x = _aggte_dynamic_fallback(att_with_x, min_e=-5, max_e=5)
# NOTE(review): crit_val_egt is assumed to be None when no uniform band was
# computed (confirm it is not NaN); fall back to the pointwise 95% normal value.
crit_val = agg_with_x["crit_val_egt"] if agg_with_x["crit_val_egt"] is not None else norm.ppf(0.975)

# Tidy event-study frame with symmetric confidence bounds.
plot_with_x = pd.DataFrame(
    {
        "event_time": agg_with_x["egt"],
        "estimate": agg_with_x["att_egt"],
        "se": np.asarray(agg_with_x["se_egt"]).flatten(),
        "crit": crit_val,
    }
)
plot_with_x["conf_low"] = plot_with_x["estimate"] - plot_with_x["crit"] * plot_with_x["se"]
plot_with_x["conf_high"] = plot_with_x["estimate"] + plot_with_x["crit"] * plot_with_x["se"]

# Overall post-period ATT over e in [0, 5], annotated on the figure below.
agg_post = _aggte_dynamic_fallback(att_with_x, min_e=0, max_e=5)
att_post = float(agg_post["overall_att"])
se_post = float(agg_post["overall_se"])

fig9, ax9 = plt.subplots(figsize=(10, 4.5))
# Confidence bands (uniform when available, else pointwise) as vertical bars.
ax9.vlines(plot_with_x["event_time"], plot_with_x["conf_low"], plot_with_x["conf_high"], color="#8A1C1C")
ax9.scatter(plot_with_x["event_time"], plot_with_x["estimate"], color="#111111", s=30)
ax9.axhline(0, linestyle="--", color="#2E2E2E", linewidth=1)
# Dashed line at e = -1, the reference period of the event study.
ax9.axvline(-1, linestyle="--", color="#2E2E2E", linewidth=1)
ax9.set_xlabel("Event Time")
ax9.set_ylabel("Treatment Effect\nMortality Per 100,000")
ax9.set_xticks(range(-5, 6))
# Annotate the overall post-period ATT (e in [0, 5]) in the top-right corner.
note_text = f"Estimate e in {{0,5}} = {att_post:.2f}\nStd. Error = {se_post:.2f}"
ax9.text(
    0.98,
    0.95,
    note_text,
    transform=ax9.transAxes,  # axes-fraction coordinates
    ha="right",
    va="top",
    fontsize=10,
    bbox=dict(facecolor="white", alpha=0.8, edgecolor="none", pad=3.0),
)
ax9.grid(False)
fig9.tight_layout()

fig9_path = save_fig(fig9, "figure9_event_covs.png")
Figure 9