import fast_causal_inference.lib.tools as ais_tools
from fast_causal_inference.dataframe import *
from fast_causal_inference.dataframe.dataframe import readClickHouse, DataFrame
from fast_causal_inference.dataframe.functions import (
    DfFnColWrapper,
    register_fn,
    define_args,
    FnArg,
    DfFunction,
)
import matplotlib.pyplot as plt
from fast_causal_inference.util import ClickHouseUtils, SqlGateWayConn
import seaborn as sns
from matplotlib import rcParams
from fast_causal_inference import (
    clickhouse_create_view,
    clickhouse_drop_view,
)
from fast_causal_inference.dataframe.df_base import df_2_table
class CaliperMatching:
    """
    Caliper matching on a score column, executed inside ClickHouse.

    Parameters
    ----------
    caliper : float, default=0.2
        The caliper width for matching. Units are in terms of the standard
        deviation of the logit of the propensity score.

    Methods
    -------
    fit(dataframe, treatment, score, exacts=None, alias='matching_index'):
        Apply the Caliper Matching method to the input dataframe.

    Example
    -------

    .. code-block:: python

        import fast_causal_inference
        import fast_causal_inference.dataframe.match as Match
        df = fast_causal_inference.readClickHouse('test_data_small')
        model = Match.CaliperMatching(0.5)
        tmp = model.fit(df, treatment='treatment', score='weight', exacts=['x_cat1'])
        match_df = tmp.filter("matching_index!=0") # filter out the unmatched records

    >>> print('sample size Before match: ')
    >>> df.count().show()
    >>> print('sample size After match: ')
    >>> match_df.count().show()
    sample size Before match:
    10000
    sample size After match:
    9652

    >>> import fast_causal_inference.dataframe.match as Match
    >>> d1 = Match.smd(df, 'treatment', ['x1','x2'])
    >>> print(d1)
         Control  Treatment       SMD
    x1 -0.012658  -0.023996 -0.011482
    x2  0.005631   0.037718  0.016156

    >>> import fast_causal_inference.dataframe.match as Match
    >>> d2 = Match.smd(match_df, 'treatment', ['x1','x2'])
    >>> print(d2)
         Control  Treatment       SMD
    x1 -0.015521  -0.025225 -0.009821
    x2  0.004834   0.039698  0.017551

    >>> # df_score / match_df must carry the plotted score column ('prob' here)
    >>> Match.matching_plot(df_score,'treatment','prob')
    >>> Match.matching_plot(match_df,'treatment','prob')
    """

    def __init__(self, caliper=0.2):
        self.caliper = caliper

    @staticmethod
    def _exact_columns_sql(exacts):
        """Render the exact-matching columns as trailing SQL arguments.

        Accepts ``None``, a list of column names, or a single string; names
        joined with '+' are split apart. Empty names are dropped so the
        generated SQL never contains a dangling comma.

        Returns ``""`` when there is nothing to match exactly on, otherwise
        a string of the form ``",col1,col2"``.
        """
        if exacts is None:
            return ""
        if isinstance(exacts, str):
            # A bare string would otherwise be iterated character by character.
            exacts = [exacts]
        names = (part.strip() for item in exacts for part in item.split("+"))
        return "".join(f",{name}" for name in names if name)

    def fit(self, dataframe, treatment, score, exacts=None, alias="matching_index"):
        """
        Apply the Caliper Matching method to the input dataframe.

        Parameters
        ----------
        dataframe : DataFrame
            The input dataframe.
        treatment : str
            The treatment column name.
        score : str
            The propensity score column name.
        exacts : list of str or str, optional
            The column names for exact matching, e.g. ``['x_cat1']``. A
            '+'-separated string such as ``'x_cat1+x_cat2'`` is accepted as
            well. ``None`` or an empty list means no exact matching.
        alias : str, default='matching_index'
            The alias for the matching index column in the output dataframe.

        Returns
        -------
        DataFrame
            The output dataframe with an additional column for the matching index.
        """
        new_table_name = DataFrame.createTableName()
        view_df = dataframe.materializedView(is_temp=True)

        # `limit 0` selects no rows: the statement only defines the schema of
        # the result table (all input columns plus an Int64 `alias` column).
        sql = f"""
        select *, toInt64(0) as {alias} from {view_df.getTableName()} limit 0
        """
        ClickHouseUtils.clickhouse_create_view_v2(
            table_name=new_table_name,
            select_statement=sql,
            origin_table_name=view_df.getTableName(),
            is_physical_table=True,
        )

        physical_df = view_df.materializedView(
            is_physical_table=True, is_distributed_create=False, is_temp=True
        )
        source_table = physical_df.getTableName()
        exact_args = self._exact_columns_sql(exacts)

        # First aggregate the matching info over the whole table, then use it
        # to assign a matching index to every row while filling the result table.
        sql = f""" insert into {new_table_name}
        with (select CaliperMatchingInfo({treatment}, {score}, {self.caliper}{exact_args}) from {source_table}) as matching_info 
        select *, CaliperMatching(matching_info, {treatment}, {score}, {self.caliper}{exact_args}) as {alias} from {source_table}
        """
        ClickHouseUtils().execute(sql)
        return readClickHouse(new_table_name)
def smd(df, T, cols):
    """
    Calculate the Standardized Mean Difference (SMD) for the input dataframe.

    The dataframe is first materialized as a temporary view, and the
    computation is delegated to ``ais_tools.SMD`` on that view's table name.

    Parameters
    ----------
    df : DataFrame
        The input dataframe.
    T : str
        The treatment column name.
    cols : list of str
        The column names to calculate the SMD for, e.g. ``['x1', 'x2']`` as
        in the example below. NOTE(review): earlier documentation described
        a '+'-separated string; the accepted forms are decided by
        ``ais_tools.SMD`` -- confirm before relying on either.

    Returns
    -------
    The value returned by ``ais_tools.SMD`` unchanged; per the example below
    this prints as a table with ``Control``, ``Treatment`` and ``SMD``
    columns, one row per covariate.

    Example
    -------
    >>> import fast_causal_inference.dataframe.match as Match
    >>> d2 = Match.smd(match_df, 'treatment', ['x1','x2'])
    >>> print(d2)
         Control  Treatment       SMD
    x1 -0.015521  -0.025225 -0.009821
    x2  0.004834   0.039698  0.017551

    """
    new_df = df.materializedView(is_temp=True)
    return ais_tools.SMD(new_df.getTableName(), T, cols)
def matching_plot(
    df,
    T,
    col,
    xlim=(0, 1),
    figsize=(8, 8),
    xlabel="",
    ylabel="density",
    legend=("Control", "Treatment"),
):
    """Plot the overlaid distribution of ``col`` for the control and
    treatment groups of ``df``.

    Up to 10000 randomly ordered rows are queried per group (``T = 0`` and
    ``T = 1``) and drawn with ``seaborn.distplot``.

    Parameters
    ----------
    df : DataFrame
        The dataframe whose underlying table is queried.
    T : str
        The name of the treatment indicator column in the table.
    col : str
        The name of the column that corresponds to the variable to plot.
    xlim : tuple, optional
        The tuple of xlim of the plot. (0,1) by default.
    figsize : tuple, optional
        The size of the histogram; (8,8) by default.
    xlabel : str, optional
        The name of xlabel; col by default.
    ylabel : str, optional
        The name of ylabel; `density` by default.
    legend : iterable, optional
        The legend; `Control` and  `Treatment` by default.

    Returns
    -------
    None
        The overlaid histogram is drawn on the current matplotlib axes.

    Example
    -------
    >>> import fast_causal_inference.dataframe.match as Match
    >>> Match.matching_plot(df,'treatment','x1')
    """
    table = df.getTableName()
    sql_instance = SqlGateWayConn.create_default_conn()
    treated = sql_instance.sql(
        f"select {col} from {table} where {T}=1 order by rand() limit 10000"
    )
    control = sql_instance.sql(
        f"select {col} from {table} where {T}=0 order by rand() limit 10000"
    )

    # NOTE: this writes to matplotlib's global rcParams, so the figure size
    # stays in effect for figures created after this call as well.
    rcParams["figure.figsize"] = figsize[0], figsize[1]

    # Control is drawn first so the legend order (Control, Treatment) matches.
    ax = sns.distplot(control)
    sns.distplot(treated)

    ax.set_xlim(xlim[0], xlim[1])
    # An empty (or missing) xlabel falls back to the plotted column name.
    ax.set_xlabel(xlabel or col)
    ax.set_ylabel(ylabel)
    ax.legend(list(legend))