# Source code for dataframe.match

import fast_causal_inference.lib.tools as ais_tools
from fast_causal_inference.dataframe import *
from fast_causal_inference.dataframe.dataframe import readClickHouse, DataFrame
from fast_causal_inference.dataframe.functions import (
    DfFnColWrapper,
    register_fn,
    define_args,
    FnArg,
    DfFunction,
)
import matplotlib.pyplot as plt
from fast_causal_inference.util import ClickHouseUtils, SqlGateWayConn
import seaborn as sns
from matplotlib import rcParams
from fast_causal_inference import (
    clickhouse_create_view,
    clickhouse_drop_view,
)
from fast_causal_inference.dataframe.df_base import df_2_table


class CaliperMatching:
    """
    This class implements the Caliper Matching method for causal inference.

    Parameters
    ----------
    caliper : float, default=0.2
        The caliper width for matching. Units are in terms of the standard
        deviation of the logit of the propensity score.

    Methods
    -------
    fit(dataframe, treatment, score, exacts=None, alias='matching_index'):
        Apply the Caliper Matching method to the input dataframe.

    Example
    -------

    .. code-block:: python

        import fast_causal_inference
        import fast_causal_inference.dataframe.match as Match
        df = fast_causal_inference.readClickHouse('test_data_small')
        model = Match.CaliperMatching(0.5)
        tmp = model.fit(df, treatment='treatment', score='weight', exacts=['x_cat1'])
        match_df = tmp.filter("matching_index!=0")  # filter out the unmatched records

    >>> print('sample size Before match: ')
    >>> df.count().show()
    >>> print('sample size After match: ')
    >>> match_df.count().show()
    sample size Before match:
    10000
    sample size After match:
    9652

    >>> import fast_causal_inference.dataframe.match as Match
    >>> d1 = Match.smd(df, 'treatment', ['x1','x2'])
    >>> print(d1)
         Control  Treatment       SMD
    x1 -0.012658  -0.023996 -0.011482
    x2  0.005631   0.037718  0.016156

    >>> import fast_causal_inference.dataframe.match as Match
    >>> d2 = Match.smd(match_df, 'treatment', ['x1','x2'])
    >>> print(d2)
         Control  Treatment       SMD
    x1 -0.015521  -0.025225 -0.009821
    x2  0.004834   0.039698  0.017551

    >>> Match.matching_plot(df_score,'treatment','prob')
    >>> Match.matching_plot(match_df,'treatment','prob')
    """

    def __init__(self, caliper=0.2):
        self.caliper = caliper

    @staticmethod
    def _exacts_clause(exacts):
        """Render exact-match columns as a SQL argument suffix, e.g. ``",a,b"``.

        Accepts ``None``, a list of column names, or a single string in which
        column names are separated by ``'+'``. Returns an empty string when
        there are no columns, so the result can be appended directly after the
        caliper argument of the SQL function call.
        """
        if not exacts:
            return ""
        if isinstance(exacts, str):
            # A bare string must be treated as one item; iterating it would
            # split it into single characters.
            exacts = [exacts]
        columns = [
            name.strip()
            for item in exacts
            for name in item.split("+")
            if name.strip()
        ]
        return "".join(f",{name}" for name in columns)

    def fit(self, dataframe, treatment, score, exacts=None, alias="matching_index"):
        """
        Apply the Caliper Matching method to the input dataframe.

        Parameters
        ----------
        dataframe : DataFrame
            The input dataframe.
        treatment : str
            The treatment column name.
        score : str
            The propensity score column name.
        exacts : list of str or str, optional
            The column names for exact matching, e.g. ``['x_cat1']``. A single
            string with names separated by ``'+'`` is also accepted. ``None``
            or an empty value means no exact-matching columns.
        alias : str, default='matching_index'
            The alias for the matching index column in the output dataframe.

        Returns
        -------
        DataFrame
            The output dataframe with an additional column for the matching
            index.
        """
        new_table_name = DataFrame.createTableName()
        view_df = dataframe.materializedView(is_temp=True)
        view_table = view_df.getTableName()

        # Create an empty physical result table: ``limit 0`` copies only the
        # schema of the input plus the extra Int64 matching-index column.
        schema_sql = f"""
            select *, toInt64(0) as {alias}
            from {view_table}
            limit 0
        """
        ClickHouseUtils.clickhouse_create_view_v2(
            table_name=new_table_name,
            select_statement=schema_sql,
            origin_table_name=view_table,
            is_physical_table=True,
        )

        physical_df = view_df.materializedView(
            is_physical_table=True, is_distributed_create=False, is_temp=True
        )
        physical_table = physical_df.getTableName()

        exacts_clause = self._exacts_clause(exacts)
        # CaliperMatchingInfo is evaluated once as a scalar subquery and its
        # result is passed to CaliperMatching for every row of the input.
        insert_sql = f"""
            insert into {new_table_name}
            with (
                select CaliperMatchingInfo({treatment}, {score}, {self.caliper}{exacts_clause})
                from {physical_table}
            ) as matching_info
            select *,
                CaliperMatching(matching_info, {treatment}, {score}, {self.caliper}{exacts_clause}) as {alias}
            from {physical_table}
        """
        clickhouse_utils = ClickHouseUtils()
        clickhouse_utils.execute(insert_sql)
        return readClickHouse(new_table_name)
def smd(df, T, cols):
    """
    Calculate the Standardized Mean Difference (SMD) for the input dataframe.

    The dataframe is first materialized as a temporary view; the view's table
    name, the treatment column and ``cols`` are then handed to
    ``ais_tools.SMD``, whose result is returned unchanged.

    Parameters
    ----------
    df : DataFrame
        The input dataframe.
    T : str
        The treatment column name.
    cols : list of str or str
        The columns to calculate the SMD for. The value is forwarded as-is to
        ``ais_tools.SMD``. The example below passes a list of column names;
        earlier documentation described a single string with names separated
        by '+'. NOTE(review): confirm the accepted forms against
        ``ais_tools.SMD``.

    Returns
    -------
    object
        Whatever ``ais_tools.SMD`` returns; the example shows a table with
        ``Control``, ``Treatment`` and ``SMD`` columns (presumably a pandas
        DataFrame).

    Example
    -------

    >>> import fast_causal_inference.dataframe.match as Match
    >>> d2 = Match.smd(match_df, 'treatment', ['x1','x2'])
    >>> print(d2)
         Control  Treatment       SMD
    x1 -0.015521  -0.025225 -0.009821
    x2  0.004834   0.039698  0.017551
    """
    table_name = df.materializedView(is_temp=True).getTableName()
    return ais_tools.SMD(table_name, T, cols)
def matching_plot(
    df,
    T,
    col,
    xlim=(0, 1),
    figsize=(8, 8),
    xlabel="",
    ylabel="density",
    legend=("Control", "Treatment"),
):
    """Plot the overlaid distribution of ``col`` for the control and treatment groups.

    Up to 10000 randomly ordered rows are sampled per group from the table
    behind ``df``; the control sample (``T = 0``) is drawn first and the
    treatment sample (``T = 1``) is drawn on the same axes.

    Parameters
    ----------
    df : DataFrame
        The dataframe whose underlying table is queried.
    T : str
        The name of the treatment indicator column in the table.
    col : str
        The name of the column that corresponds to the variable to plot.
    xlim : tuple, optional
        The tuple of xlim of the plot. (0,1) by default.
    figsize : tuple, optional
        The size of the histogram; (8,8) by default.
    xlabel : str, optional
        The name of xlabel; ``col`` is used when this is empty or None.
    ylabel : str, optional
        The name of ylabel; `density` by default.
    legend : iterable, optional
        The legend labels; `Control` and `Treatment` by default.

    Returns
    -------
    None
        The overlaid histogram is drawn on the current matplotlib axes.

    Example
    -------

    >>> import fast_causal_inference.dataframe.match as Match
    >>> Match.matching_plot(df,'treatment','x1')
    """
    table = df.getTableName()
    sql_instance = SqlGateWayConn.create_default_conn()
    x1 = sql_instance.sql(
        f"select {col} from {table} where {T}=1 order by rand() limit 10000"
    )
    x0 = sql_instance.sql(
        f"select {col} from {table} where {T}=0 order by rand() limit 10000"
    )

    # This sets the process-wide default figure size, so it also applies to
    # figures created after this call returns.
    rcParams["figure.figsize"] = figsize[0], figsize[1]

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # migrating to histplot/kdeplot would change the rendering, so it is kept.
    ax = sns.distplot(x0)
    # Draw on the same axes explicitly instead of relying on the implicit
    # "current axes" state.
    sns.distplot(x1, ax=ax)

    ax.set_xlim(xlim[0], xlim[1])
    ax.set_xlabel(xlabel if xlabel else col)
    ax.set_ylabel(ylabel)
    # Pass a fresh list so any iterable of labels is accepted and the default
    # is never shared with the caller.
    ax.legend(list(legend))