import fast_causal_inference.lib.tools as ais_tools
from fast_causal_inference.dataframe import *
from fast_causal_inference.dataframe.dataframe import readClickHouse, DataFrame
from fast_causal_inference.dataframe.functions import (
DfFnColWrapper,
register_fn,
define_args,
FnArg,
DfFunction,
)
import matplotlib.pyplot as plt
from fast_causal_inference.util import ClickHouseUtils, SqlGateWayConn
import seaborn as sns
from matplotlib import rcParams
from fast_causal_inference import (
clickhouse_create_view,
clickhouse_drop_view,
)
from fast_causal_inference.dataframe.df_base import df_2_table
class CaliperMatching:
    """
    This class implements the Caliper Matching method for causal inference.

    The matching itself is delegated to the ClickHouse aggregate functions
    ``CaliperMatchingInfo`` / ``CaliperMatching``; this class only assembles
    and runs the SQL.

    Parameters
    ----------
    caliper : float, default=0.2
        The caliper width for matching. Units are in terms of the standard
        deviation of the logit of the propensity score.

    Methods
    -------
    fit(dataframe, treatment, score, exacts=None, alias='matching_index'):
        Apply the Caliper Matching method to the input dataframe.

    Example
    -------

    .. code-block:: python

        import fast_causal_inference
        import fast_causal_inference.dataframe.match as Match
        df = fast_causal_inference.readClickHouse('test_data_small')
        model = Match.CaliperMatching(0.5)
        tmp = model.fit(df, treatment='treatment', score='weight', exacts=['x_cat1'])
        match_df = tmp.filter("matching_index!=0") # filter out the unmatched records

    >>> print('sample size Before match: ')
    >>> df.count().show()
    >>> print('sample size After match: ')
    >>> match_df.count().show()
    sample size Before match:
    10000
    sample size After match:
    9652
    >>> import fast_causal_inference.dataframe.match as Match
    >>> d1 = Match.smd(df, 'treatment', ['x1','x2'])
    >>> print(d1)
         Control  Treatment       SMD
    x1 -0.012658  -0.023996 -0.011482
    x2  0.005631   0.037718  0.016156
    >>> import fast_causal_inference.dataframe.match as Match
    >>> d2 = Match.smd(match_df, 'treatment', ['x1','x2'])
    >>> print(d2)
         Control  Treatment       SMD
    x1 -0.015521  -0.025225 -0.009821
    x2  0.004834   0.039698  0.017551
    >>> Match.matching_plot(df_score,'treatment','prob')
    >>> Match.matching_plot(match_df,'treatment','prob')
    """

    def __init__(self, caliper=0.2):
        self.caliper = caliper

    @staticmethod
    def _exact_match_suffix(exacts):
        """Render the exact-match columns as trailing SQL function arguments.

        Accepts ``None``, a '+'-separated string, or an iterable of column
        names (each of which may itself be '+'-separated). Empty entries are
        dropped so the result never contains a dangling comma.

        Returns ``""`` when there are no columns, otherwise ``",c1,c2,..."``.
        """
        if not exacts:
            return ""
        if isinstance(exacts, str):
            # A bare string must not be iterated character by character.
            exacts = [exacts]
        names = [
            name.strip()
            for item in exacts
            for name in item.split("+")
            if name.strip()
        ]
        return "".join(f",{name}" for name in names)

    def fit(self, dataframe, treatment, score, exacts=None, alias="matching_index"):
        """
        Apply the Caliper Matching method to the input dataframe.

        Parameters
        ----------
        dataframe : DataFrame
            The input dataframe.
        treatment : str
            The treatment column name.
        score : str
            The propensity score column name.
        exacts : list of str or str, optional
            The column names for exact matching, e.g. ``['x_cat1']``. A
            '+'-separated string such as ``'x_cat1+x_cat2'`` is accepted as
            well. ``None`` (the default) or an empty list means no exact
            matching columns.
        alias : str, default='matching_index'
            The alias for the matching index column in the output dataframe.

        Returns
        -------
        DataFrame
            The output dataframe with an additional column for the matching
            index, read back from the newly created result table.

        Notes
        -----
        ``treatment``, ``score``, ``alias`` and the ``exacts`` names are
        interpolated into the SQL text unescaped, so they must be trusted
        identifiers.
        """
        result_table = DataFrame.createTableName()
        view_df = dataframe.materializedView(is_temp=True)
        source_view = view_df.getTableName()

        # `limit 0` selects no rows: this only creates the result table with
        # the input columns plus the Int64 matching-index column.
        schema_sql = f"""
        select *, toInt64(0) as {alias} from {source_view} limit 0
        """
        ClickHouseUtils.clickhouse_create_view_v2(
            table_name=result_table,
            select_statement=schema_sql,
            origin_table_name=source_view,
            is_physical_table=True,
        )

        # NOTE(review): presumably a non-distributed physical copy is needed so
        # the matching functions see all rows in one place - confirm.
        physical_df = view_df.materializedView(
            is_physical_table=True, is_distributed_create=False, is_temp=True
        )
        physical_table = physical_df.getTableName()

        extra_args = self._exact_match_suffix(exacts)
        # First compute the matching info over the whole table, then use it
        # to assign a matching index to every row.
        insert_sql = f""" insert into {result_table}
        with (select CaliperMatchingInfo({treatment}, {score}, {self.caliper}{extra_args}) from {physical_table}) as matching_info
        select *, CaliperMatching(matching_info, {treatment}, {score}, {self.caliper}{extra_args}) as {alias} from {physical_table}
        """
        clickhouse_utils = ClickHouseUtils()
        clickhouse_utils.execute(insert_sql)
        return readClickHouse(result_table)
def smd(df, T, cols):
    """
    Compute the Standardized Mean Difference (SMD) of ``cols`` between groups.

    The dataframe is first materialized as a temporary view; the view's table
    name is then handed to ``ais_tools.SMD`` together with ``T`` and ``cols``.

    Parameters
    ----------
    df : DataFrame
        The input dataframe.
    T : str
        The treatment column name.
    cols : list of str or str
        The columns to calculate the SMD for, forwarded unchanged to
        ``ais_tools.SMD``. The example below passes a list of names.
        NOTE(review): earlier documentation described a '+'-separated
        string - confirm which forms ``ais_tools.SMD`` accepts.

    Returns
    -------
    DataFrame
        Whatever ``ais_tools.SMD`` returns - presumably a pandas DataFrame
        with the SMD results (TODO confirm).

    Example
    -------
    >>> import fast_causal_inference.dataframe.match as Match
    >>> d2 = Match.smd(match_df, 'treatment', ['x1','x2'])
    >>> print(d2)
         Control  Treatment       SMD
    x1 -0.015521  -0.025225 -0.009821
    x2  0.004834   0.039698  0.017551
    """
    temp_table = df.materializedView(is_temp=True).getTableName()
    return ais_tools.SMD(temp_table, T, cols)
def matching_plot(
    df,
    T,
    col,
    xlim=(0, 1),
    figsize=(8, 8),
    xlabel="",
    ylabel="density",
    legend=("Control", "Treatment"),
):
    """Plot the overlaid distribution of ``col`` for the control and
    treatment groups of ``df``.

    Up to 10000 randomly ordered rows are queried for each group (``T=0`` and
    ``T=1``) and drawn with ``sns.distplot``, control first.

    Parameters
    ----------
    df : DataFrame
        The dataframe to plot; its table name (``df.getTableName()``) is
        queried directly.
    T : str
        The name of the treatment indicator column in the table.
    col : str
        The name of the column that corresponds to the variable to plot.
    xlim : tuple, optional
        The tuple of xlim of the plot. (0,1) by default.
    figsize : tuple, optional
        The size of the figure; (8,8) by default. It is applied only while
        this function draws, so the global matplotlib figure size is left
        unchanged afterwards.
    xlabel : str, optional
        The name of xlabel; ``col`` when empty (the default) or ``None``.
    ylabel : str, optional
        The name of ylabel; `density` by default.
    legend : iterable, optional
        The legend labels; `Control` and `Treatment` by default.

    Returns
    -------
    None
        The overlaid histogram is drawn on the current matplotlib axes.

    Example
    -------
    >>> import fast_causal_inference.dataframe.match as Match
    >>> Match.matching_plot(df,'treatment','x1')
    """
    table = df.getTableName()
    sql_instance = SqlGateWayConn.create_default_conn()
    # `order by rand() limit 10000` draws a random sample per group.
    treated = sql_instance.sql(
        f"select {col} from {table} where {T}=1 order by rand() limit 10000"
    )
    control = sql_instance.sql(
        f"select {col} from {table} where {T}=0 order by rand() limit 10000"
    )

    # Scope the figure size to this call instead of overwriting the global
    # rcParams, which would silently resize every later plot in the session.
    # NOTE(review): sns.distplot is deprecated since seaborn 0.11; kept here to
    # preserve the rendered output.
    with plt.rc_context({"figure.figsize": (figsize[0], figsize[1])}):
        ax = sns.distplot(control)
        sns.distplot(treated)

    ax.set_xlim(xlim[0], xlim[1])
    ax.set_xlabel(xlabel or col)
    ax.set_ylabel(ylabel)
    ax.legend(legend)