from typing import List, Dict
from fast_causal_inference.dataframe.functions import (
    DfFnColWrapper,
    register_fn,
    define_args,
    FnArg,
    DfFunction,
    aggregrate,
    OlapEngineType,
    DfContext,
)
from fast_causal_inference.dataframe.df_base import (
    df_2_table,
)
from fast_causal_inference.util import create_sql_instance
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="DeltaMethod")
@register_fn(engine=OlapEngineType.STARROCKS, name="DeltaMethod")
@define_args(
    FnArg(name="expr", is_param=True), FnArg(name="std", default="True", is_param=True)
)
@aggregrate
class AggDeltaMethodDfFunction(DfFunction):
    # @classmethod
    # def _extract_cols_from_expr(cls, expr):
    #     matches = re.findall(r'avg\((.*?)\)', expr)
    #     unique_matches = list(set(matches))
    #     encoded_matches = [(match, f'X{i + 1}') for i, match in enumerate(unique_matches)]
    #     result = expr
    #     for key, value in encoded_matches:
    #         result = result.replace(f'avg({key})', f'avg({value})')
    #     return result, tuple(col for col, _ in encoded_matches)
    def sql_impl_default(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        expr_arg: FnArg = arg_dict["expr"]
        std_arg: FnArg = arg_dict["std"]
        expr = expr_arg.sql(ctx)
        std = std_arg.sql(ctx)
        sql = self.fn_name(ctx) + f"({expr}, {std})"
        return sql
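    # Hedged illustration (assumed inputs, not emitted by the library itself): with
    # expr='avg(x1)/avg(x2)' and std='True', sql_impl_default renders an aggregate
    # call of roughly this shape, using the engine-specific function name:
    #   DeltaMethod(avg(x1)/avg(x2), True)
    # The exact quoting of `expr` depends on how FnArg.sql renders the parameter.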
def delta_method(expr=None, std=True):
    """
    Compute the delta method on the given expression.
    :param expr: Expression of the form f(avg(x1), avg(x2), ...), where f is an arbitrary function expression and x1, x2 are column names; the columns involved must be numeric.
    :type expr: str, optional
    :param std: Whether to return standard deviation, default is True.
    :type std: bool, optional
    :return: DataFrame containing the var or std computed by the delta method.
    :rtype: DataFrame
    Example
    ----------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
        df.groupBy('treatment').delta_method('avg(x1)', False).show()
        df.groupBy('treatment').agg(S.delta_method('avg(x1)')).show()
    This will output:
    .. code-block:: text
        treatment             std
        0         0         1.934587277675054E-4
        1         1        1.9646284055862068E-4
        treatment             var
        0         0        0.013908944164367954
        1         1        0.014016520272828797
    """
    return DfFnColWrapper(AggDeltaMethodDfFunction(), {"expr": expr, "std": std}, []) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="ttest_1samp")
@register_fn(engine=OlapEngineType.STARROCKS, name="ttest_1samp")
@define_args(
    FnArg(name="Y", is_param=True),
    FnArg(name="alternative", default="two-sided", is_param=True),
    FnArg(name="mu", default="0", is_param=True),
    FnArg(name="X", default="", is_param=True),
)
@aggregrate
class AggTTest1SampDfFunction(DfFunction):
    def sql_impl_default(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        Y = arg_dict["Y"].sql(ctx)
        alternative = arg_dict["alternative"].sql(ctx)
        mu = arg_dict["mu"].sql(ctx)
        X = arg_dict["X"].sql(ctx)
        x_str = "" if not X else f", {X}"
        sql = self.fn_name(ctx) + f"({Y}, {alternative}, {mu}{x_str})"
        return sql
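    # Hedged illustration (assumed inputs): for Y='avg(numerator)', the default
    # alternative and mu, and X='avg(numerator_pre)', the rendered aggregate call
    # is expected to look roughly like
    #   ttest_1samp(avg(numerator), 'two-sided', 0, avg(numerator_pre))
    # and the trailing covariate clause is omitted entirely when X is empty.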
def ttest_1samp(
    Y,
    alternative="two-sided",
    mu=0,
    X="",
):
    """
    This function is used to calculate the t-test for the mean of one group of scores. It returns the calculated t-statistic and the two-tailed p-value.
    :param Y: str, an expression of the form f(avg(x1), avg(x2), ...), where f is an arbitrary function expression and x1, x2 are column names; the columns involved must be numeric.
    :type Y: str, required
    :param alternative: str, use 'two-sided' for two-tailed test, 'greater' for one-tailed test in the positive direction, and 'less' for one-tailed test in the negative direction.
    :type alternative: str, optional
    :param mu: the mean of the null hypothesis.
    :type mu: float, optional
    :param X: str, an expression used as continuous covariates for CUPED variance reduction. It follows the regression approach and can be a simple form like 'avg(x1)/avg(x2)','avg(x3)','avg(x1)/avg(x2)+avg(x3)'.
    :type X: str, optional
    :return: DataFrame containing the following columns:
    estimate: the mean value of the statistic to be tested.
    stderr: the standard error of the statistic to be tested.
    t-statistic: the calculated t-statistic.
    p-value: the calculated p-value.
    lower: the lower bound of the confidence interval.
    upper: the upper bound of the confidence interval.
    Example:
    ----------------
    .. code-block:: python
        import fast_causal_inference.dataframe.statistics as S
        import fast_causal_inference
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.groupBy('x_cat1').ttest_1samp('avg(numerator)/avg(denominator)', alternative = 'two-sided', mu = 0).show()
    >>> df.groupBy('x_cat1').agg(S.ttest_1samp('avg(numerator)', alternative = 'two-sided', mu = 0, X = 'avg(numerator_pre)/avg(denominator_pre)')).show()
    x_cat1  estimate    stderr t-statistic   p-value     lower     upper
    0      B  1.455223  0.041401   35.149887  0.000000  1.374029  1.536417
    1      E  1.753613  0.042083   41.670491  0.000000  1.671082  1.836143
    2      D  1.752348  0.043173   40.589377  0.000000  1.667680  1.837016
    3      C  1.804776  0.046642   38.694122  0.000000  1.713303  1.896249
    4      A  2.108937  0.042558   49.554601  0.000000  2.025477  2.192398
    x_cat1   estimate    stderr t-statistic   p-value      lower      upper
    0      B  10.220695  0.261317   39.112304  0.000000   9.708205  10.733185
    1      E  12.407975  0.267176   46.441156  0.000000  11.884002  12.931947
    2      D  11.924641  0.258935   46.052716  0.000000  11.416831  12.432451
    3      C  12.274732  0.281095   43.667495  0.000000  11.723457  12.826006
    4      A  14.824860  0.241133   61.480129  0.000000  14.351972  15.297748
    """
    return DfFnColWrapper(
        AggTTest1SampDfFunction(),
        {"Y": Y, "alternative": f"'{alternative}'", "mu": mu, "X": X},
        [],
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="ttest_2samp")
@register_fn(engine=OlapEngineType.STARROCKS, name="ttest_2samp")
@define_args(
    FnArg(name="Y", is_param=True),
    FnArg(name="alternative", default="two-sided", is_param=True),
    FnArg(name="X", default="", is_param=True),
    FnArg(name="pse", default="", is_param=True),
    FnArg(name="index"),
)
@aggregrate
class AggTTest2SampDfFunction(DfFunction):
    def sql_impl_default(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        Y = arg_dict["Y"].sql(ctx)
        alternative = arg_dict["alternative"].sql(ctx)
        index = arg_dict["index"].sql(ctx)
        X = arg_dict["X"].sql(ctx)
        pse = arg_dict["pse"].sql(ctx)
        x_str = "" if not X else f", {X}"
        x_str = x_str if not pse else f", pse = {pse}"
        sql = self.fn_name(ctx) + f"({Y}, {index}, {alternative}{x_str})"
        return sql
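    # Hedged illustration (assumed inputs): the rendered call places the treatment
    # index before the alternative, e.g.
    #   ttest_2samp(avg(numerator)/avg(denominator), treatment, 'two-sided', avg(numerator_pre))
    # and a non-empty `pse` takes precedence, replacing the CUPED clause:
    #   ttest_2samp(avg(numerator)/avg(denominator), treatment, 'two-sided', pse = x_cat1)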
def ttest_2samp(Y, index, alternative="two-sided", X="", pse=""):
    """
    This function is used to calculate the t-test for the means of two independent samples of scores. It returns the calculated t-statistic and the two-tailed p-value.
    :param Y: str, an expression of the form f(avg(x1), avg(x2), ...), where f is an arbitrary function expression and x1, x2 are column names; the columns involved must be numeric.
    :type Y: str, required
    :param index: str, the treatment variable.
    :type index: str, required
    :param alternative: str, use 'two-sided' for two-tailed test, 'greater' for one-tailed test in the positive direction, and 'less' for one-tailed test in the negative direction.
    :type alternative: str, optional
    :param X: str, an expression used as continuous covariates for CUPED variance reduction. It follows the regression approach and can be a simple form like 'avg(x1)/avg(x2)','avg(x3)','avg(x1)/avg(x2)+avg(x3)'.
    :type X: str, optional
    :param pse: str, an expression used as discrete covariates for post-stratification variance reduction. It involves grouping by a covariate, calculating variances separately, and then weighting them. It can be any complex function form, such as 'x_cat1'.
    :type pse: str, optional
    :return: DataFrame containing the following columns:
    estimate: the mean value of the statistic to be tested.
    stderr: the standard error of the statistic to be tested.
    t-statistic: the calculated t-statistic.
    p-value: the calculated p-value.
    lower: the lower bound of the confidence interval.
    upper: the upper bound of the confidence interval.
    Example:
    ----------------
    .. code-block:: python
        import fast_causal_inference.dataframe.statistics as S
        import fast_causal_inference
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.agg(S.ttest_2samp('avg(numerator)/avg(denominator)', 'treatment', alternative = 'two-sided', pse = 'x_cat1')).show()
    >>> df.agg(S.ttest_2samp('avg(numerator)/avg(denominator)', 'treatment', alternative = 'two-sided', X = 'avg(numerator_pre)/avg(denominator_pre)')).show()
    >>> df.groupBy('x_cat1').ttest_2samp('avg(numerator)', 'treatment', alternative = 'two-sided', X = 'avg(numerator_pre)').show()
    >>> df.groupBy('x_cat1').agg(S.ttest_2samp('avg(numerator)/avg(denominator)', 'treatment', alternative = 'two-sided', X = 'avg(numerator_pre)/avg(denominator_pre)')).show()
    
    .. code-block:: text
            mean0     mean1  estimate    stderr t-statistic   p-value     lower  \
        0  0.791139  2.487152  1.696013  0.032986   51.416725  0.000000  1.631355   
            upper  
        0  1.760672  
            mean0     mean1  estimate    stderr t-statistic   p-value     lower  \
        0  0.793732  2.486118  1.692386  0.026685   63.419925  0.000000  1.640077   
            upper  
        0  1.744694  
        x_cat1     mean0      mean1   estimate    stderr t-statistic   p-value  \
        0      B  2.481226  17.787127  15.305901  0.365716   41.851896  0.000000   
        1      E  4.324137  19.437071  15.112935  0.370127   40.831785  0.000000   
        2      D  4.582961  19.156961  14.574000  0.373465   39.023766  0.000000   
        3      C  4.579375  19.816027  15.236652  0.419183   36.348422  0.000000   
        4      A  7.518409  22.195092  14.676682  0.342147   42.895825  0.000000   
            lower      upper  
        0  14.588665  16.023138  
        1  14.387062  15.838808  
        2  13.841579  15.306421  
        3  14.414564  16.058739  
        4  14.005694  15.347671  
        x_cat1     mean0     mean1  estimate    stderr t-statistic   p-value  \
        0      B  0.409006  2.202847  1.793841  0.053917   33.270683  0.000000   
        1      E  0.714211  2.435665  1.721455  0.056144   30.661265  0.000000   
        2      D  0.781435  2.455767  1.674332  0.058940   28.407344  0.000000   
        3      C  0.778977  2.562364  1.783388  0.065652   27.164280  0.000000   
        4      A  1.242126  2.766098  1.523972  0.060686   25.112311  0.000000   
            lower     upper  
        0  1.688101  1.899581  
        1  1.611348  1.831562  
        2  1.558742  1.789923  
        3  1.654633  1.912142  
        4  1.404959  1.642984  
    """
    return DfFnColWrapper(
        AggTTest2SampDfFunction(),
        {"Y": Y, "alternative": f"'{alternative}'", "X": X, "pse": pse},
        [index],
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="xexpt_ttest_2samp")
@register_fn(engine=OlapEngineType.STARROCKS, name="xexpt_ttest_2samp")
@define_args(
    FnArg(name="numerator"),
    FnArg(name="denominator"),
    FnArg(name="index"),
    FnArg(name="uin"),
    FnArg(name="metric_type", default="avg", is_param=True),
    FnArg(name="group_buckets", default="[1,1]", is_param=True),
    FnArg(name="alpha", default="0.05", is_param=True),
    FnArg(name="MDE", default="0.005", is_param=True),
    FnArg(name="power", default="0.8", is_param=True),
    FnArg(name="X", default="", is_param=True),
)
@aggregrate
class AggXexptTTest2SampDfFunction(DfFunction):
    def sql_impl_default(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        numerator = arg_dict["numerator"].sql(ctx)
        denominator = arg_dict["denominator"].sql(ctx)
        index = arg_dict["index"].sql(ctx)
        metric_type = arg_dict["metric_type"].sql(ctx)
        group_buckets = arg_dict["group_buckets"].sql(ctx)
        alpha = arg_dict["alpha"].sql(ctx)
        MDE = arg_dict["MDE"].sql(ctx)
        power = arg_dict["power"].sql(ctx)
        X = arg_dict["X"].sql(ctx)
        uin = arg_dict["uin"].sql(ctx)
        if metric_type == "avg":
            group_buckets = ""
            metric_type = ""
        else:
            group_buckets = "," + group_buckets
            metric_type = ",'" + metric_type + "'"
        if X != "":
            X = "," + X
        sql = (
            self.fn_name(ctx)
            + "("
            + numerator
            + ","
            + denominator
            + ","
            + index
            + ","
            + uin
            + metric_type
            + group_buckets
            + ","
            + str(alpha)
            + ","
            + str(MDE)
            + ","
            + str(power)
            + X
            + ")"
        )
        return sql
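    # Hedged illustration (assumed inputs): with metric_type='avg' the metric type and
    # group_buckets are dropped from the call, giving roughly
    #   xexpt_ttest_2samp(numerator, denominator, treatment, rand(), 0.05, 0.005, 0.8)
    # while metric_type='sum' keeps both:
    #   xexpt_ttest_2samp(numerator, denominator, treatment, rand(), 'sum', [1,1], 0.05, 0.005, 0.8)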
def xexpt_ttest_2samp(
    numerator,
    denominator,
    index,
    uin,
    metric_type="avg",
    group_buckets="[1,1]",
    alpha=0.05,
    MDE=0.005,
    power=0.8,
    X="",
):
    """
    This function is used to calculate the t-test for the means of two independent samples of scores. It returns the calculated t-statistic and the two-tailed p-value.
    :param numerator: column name, the numerator of the metric, can use sql expression, the column must be numeric.
    :type numerator: str, required
    :param denominator: column name, the denominator of the metric, can use sql expression, the column must be numeric.
    :type denominator: str, required
    :param index: column name, used to represent the control group and the experimental group.
    :type index: str, required
    :param uin: column name, used to bucket samples, can use sql expression, int64 type.
    :type uin: str, required
    :param metric_type:
        avg: used to test the mean metric, avg(numerator)/avg(denominator); this is the default.
        sum: used to test the sum metric; in this case the denominator can be omitted or set to 1, otherwise an error will be reported.
    :type metric_type: str, optional
    :param group_buckets: the number of traffic buckets for each group, only effective when metric_type='sum'. The default is [1,1]; the number of elements must equal the number of groups, and only the relative ratio matters.
    :type group_buckets: list, optional
    :param alpha: numeric, significance level, default 0.05.
    :type alpha: float, optional
    :param MDE: numeric, minimum test difference, default 0.005.
    :type MDE: float, optional
    :param power: numeric, statistical power, default 0.8.
    :type power: float, optional
    :param X: str, an expression used as continuous covariates for CUPED variance reduction. It follows the regression approach and can be a simple form like 'avg(x1)/avg(x2)','avg(x3)','avg(x1)/avg(x2)+avg(x3)'.
    :type X: str, optional
    :return: DataFrame containing the following columns:
    groupname: the name of the group.
    numerator: the mean of the numerator.
    denominator: the mean of the denominator (only when metric_type=avg).
    numerator_pre: the mean of numerator before the experiment (only when metric_type=avg).
    denominator_pre: the mean of denominator before the experiment (only when metric_type=avg).
    mean: the mean of the metric (only when metric_type=avg).
    std_samp: the standard deviation of the metric (only when metric_type=avg).
    ratio: group_buckets (only when metric_type=sum).
    diff_relative: the relative difference between the two groups.
    95%_relative_CI: the 95% confidence interval of the relative difference.
    p-value: the calculated p-value.
    t-statistic: the calculated t-statistic.
    power: the calculated power.
    recommend_samples: the recommended sample size.
    Example:
    ----------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.xexpt_ttest_2samp('numerator', 'denominator', 'treatment', uin = 'rand()', metric_type = 'sum', group_buckets=[1,1]).show()
    groupname   numerator     ratio
    0           23058.627723  1
    1           100540.303112 1
    diff_relative 95%_relative_CI           p-value     t-statistic power       recommend_samples
    336.020323%   [320.514511%,351.526135%] 0.000000    42.478747   0.050458    24404575
    >>> df.xexpt_ttest_2samp('numerator', 'denominator', 'treatment', uin = 'rand()', metric_type = 'sum', group_buckets=[1,1], X = 'avg(numerator_pre)/avg(denominator_pre)').show()
    groupname   numerator     ratio       numerator_pre
    0           23058.627723  1           21903.112431
    1           100540.303112 1           23096.875608
    diff_relative 95%_relative_CI           p-value     t-statistic power       recommend_samples
    310.412514%   [299.416469%,321.408558%] 0.000000    55.335445   0.050911    12696830
    >>> df.xexpt_ttest_2samp('numerator', 'denominator', 'treatment', uin = 'rand()', metric_type = 'avg', X = 'avg(numerator_pre)/avg(denominator_pre)').show()
    groupname   numerator     denominator  numerator_pre denominator_pre mean        std_samp
    0           23058.627723  29023.233157 21903.112431  29131.831739    0.793678    1.253257
    1           100540.303112 40452.337656 23096.875608  30776.559777    2.486168    5.123161
    diff_relative 95%_relative_CI           p-value     t-statistic diff        95%_CI              power       recommend_samples
    213.246344%   [206.698202%,219.794486%] 0.000000    63.835777   1.692490    [1.640519,1.744461] 0.052570    14172490
    >>> df.agg(S.xexpt_ttest_2samp('numerator', 'denominator', 'treatment', uin = 'rand()', metric_type = 'avg', alpha = 0.05, MDE = 0.005, power = 0.8, X = 'avg(numerator_pre)+avg(x1)')).show()
    groupname   numerator     denominator  numerator_pre denominator_pre mean        std_samp
    0           23058.627723  29023.233157 21903.112431  -62.102593      1.057338    2.341991
    1           100540.303112 40452.337656 23096.875608  -122.234609     2.732950    5.918014
    diff_relative 95%_relative_CI           p-value     t-statistic diff        95%_CI              power       recommend_samples
    158.474659%   [152.453710%,164.495607%] 0.000000    51.593567   1.675612    [1.611950,1.739274] 0.053041    11982290
    """
    return DfFnColWrapper(
        AggXexptTTest2SampDfFunction(),
        {
            "metric_type": metric_type,
            "group_buckets": group_buckets,
            "alpha": alpha,
            "MDE": MDE,
            "power": power,
            "X": X,
        },
        [numerator, denominator, index, uin],
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="SRM")
@register_fn(engine=OlapEngineType.STARROCKS, name="srm")
@define_args(
    FnArg(name="x"),
    FnArg(name="groupby"),
    FnArg(name="ratio", default="[1,1]", is_param=True),
)
@aggregrate
class AggSRMDfFunction(DfFunction):
    def sql_impl_default(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        x = arg_dict["x"].sql(ctx)
        groupby = arg_dict["groupby"].sql(ctx)
        ratio = arg_dict["ratio"].sql(ctx)
        sql = self.fn_name(ctx) + f"({x},{groupby},{ratio})"
        return sql
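    # Hedged illustration (assumed inputs): for x='x1', groupby='treatment' and the
    # default ratio, the rendered aggregate call is expected to be
    #   SRM(x1,treatment,[1,1])
    # (the StarRocks registration uses the lower-case name `srm`).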
def srm(x, groupby, ratio="[1,1]"):
    """
    Perform the SRM (Sample Ratio Mismatch) test.
    :param x: column name, the numerator of the metric, can use SQL expression, the column must be numeric.
        If you are concerned about whether the sum of x1 meets expectations, you should fill in x1, then it will calculate sum(x1);
        If you are concerned about whether the sample size meets expectations, you should fill in 1, then it will calculate sum(1).
    :type x: str, required
    :param groupby: column name, representing the field for aggregation grouping, can support Integer/String.
    :type groupby: str, required
    :param ratio: the expected traffic ratio, given in the order of the groupby values. Each value must be > 0. For example, [1,1,2] means the expected ratio is 1:1:2.
    :type ratio: list, optional
    :return: DataFrame containing the following columns:
    groupname: the name of the group.
    f_obs: the observed traffic.
    ratio: the expected traffic ratio.
    chisquare: the calculated chi-square.
    p-value: the calculated p-value.
    Example:
    ----------------
    .. code-block:: python
        import fast_causal_inference.dataframe.statistics as S
    >>> df.srm('x1', 'treatment', '[1,2]').show()
            groupname   f_obs       ratio       chisquare   p-value
    0           23058.627723 1.000000    48571.698643 0.000000
    1           1.0054e+05  1.000000
    >>> df.agg(S.srm('x1', 'treatment', '[1,2]')).show()
            groupname   f_obs       ratio       chisquare   p-value
    0           23058.627723 1.000000    48571.698643 0.000000
    1           1.0054e+05  1.000000
    """
    return DfFnColWrapper(AggSRMDfFunction(), {"ratio": ratio}, [x, groupby]) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="mannWhitneyUTest")
@define_args(
    FnArg(name="alternative", is_param=True, default="two-sided"),
    FnArg(name="continuity_correction", is_param=True, default=1),
    FnArg(name="sample_data"),
    FnArg(name="sample_index"),
)
@aggregrate
class AggMannWhitneyUTestDfFunction(DfFunction):
    def sql_impl_clickhouse(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        alternative = f"'{arg_dict['alternative'].sql(ctx)}'"
        continuity_correction = arg_dict["continuity_correction"].sql(ctx)
        sample_data = arg_dict["sample_data"].sql(ctx)
        sample_index = arg_dict["sample_index"].sql(ctx)
        sql = (
            self.fn_name(ctx)
            + f"({alternative}, {continuity_correction})({sample_data}, {sample_index})"
        )
        return sql
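    # Hedged illustration (assumed inputs): the ClickHouse parametric-aggregate syntax
    # places the parameters before the column arguments, e.g.
    #   mannWhitneyUTest('two-sided', 1)(x1, treatment)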
def mann_whitney_utest(
    sample_data, sample_index, alternative="two-sided", continuity_correction=1
):
    """
    This function is used to calculate the Mann-Whitney U test. It returns the calculated U-statistic and the two-tailed p-value.
    :param sample_data: column name, the numerator of the metric, can use SQL expression, the column must be numeric.
    :type sample_data: str, required
    :param sample_index: column name, the index to represent the control group and the experimental group, 1 for the experimental group and 0 for the control group.
    :type sample_index: str, required
    :param alternative:
        'two-sided': the default value, two-sided test.
        'greater': one-tailed test in the positive direction.
        'less': one-tailed test in the negative direction.
    :type alternative: str, optional
    :param continuity_correction: bool, default 1, whether to apply continuity correction.
    :type continuity_correction: bool, optional
    :return: Tuple with two elements:
    U-statistic: Float64.
    p-value: Float64.
    Example:
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.mann_whitney_utest('x1', 'treatment').show()
    [2380940.0, 0.0]
    >>> df.agg(S.mann_whitney_utest('x1', 'treatment')).show()
    [2380940.0, 0.0]
    """
    return DfFnColWrapper(
        AggMannWhitneyUTestDfFunction(),
        {"alternative": alternative, "continuity_correction": continuity_correction},
        [sample_data, sample_index],
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="kolmogorovSmirnovTest")
@define_args(FnArg(name="sample_data"), FnArg(name="sample_index"))
@aggregrate
class AggKolmogorovSmirnovTestDfFunction(DfFunction):
    pass
def kolmogorov_smirnov_test(sample_data, sample_index):
    """
    This function is used to calculate the Kolmogorov-Smirnov test for goodness of fit. It returns the calculated statistic and the two-tailed p-value.
    :param sample_data: Sample data. Integer, Float or Decimal.
    :type sample_data: int, float or decimal, required
    :param sample_index: Sample index. Integer.
    :type sample_index: int, required
    :return: Tuple with two elements:
    calculated statistic: Float64.
    calculated p-value: Float64.
    Example:
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.kolmogorov_smirnov_test('y', 'treatment').show()
    [0.6382961593945475, 0.0]
    >>> df.agg(S.kolmogorov_smirnov_test('y', 'treatment')).show()
    [0.6382961593945475, 0.0]
    """
    return DfFnColWrapper(
        AggKolmogorovSmirnovTestDfFunction(), {}, [sample_data, sample_index]
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="studentTTest")
@define_args(FnArg(name="sample_data"), FnArg(name="sample_index"))
@aggregrate
class AggStudentTTestDfFunction(DfFunction):
    pass
def student_ttest(sample_data, sample_index):
    """
    This function is used to calculate Student's t-test for the means of two independent samples of scores (assuming equal variances). It returns the calculated t-statistic and the two-tailed p-value.
    :param sample_data: column name, the numerator of the metric, can use sql expression, the column must be numeric
    :type sample_data: str, required
    :param sample_index: column name, the index to represent the control group and the experimental group, 1 for the experimental group and 0 for the control group
    :type sample_index: str, required
    :return: Tuple with two elements:
    calculated statistic: Float64.
    calculated p-value: Float64.
    Example
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.student_ttest('y', 'treatment').show()
    [-72.8602591880598, 0.0]
    >>> df.agg(S.student_ttest('y', 'treatment')).show()
    [-72.8602591880598, 0.0]
    """
    return DfFnColWrapper(AggStudentTTestDfFunction(), {}, [sample_data, sample_index]) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="welchTTest")
@define_args(FnArg(name="sample_data"), FnArg(name="sample_index"))
@aggregrate
class AggWelchTTestDfFunction(DfFunction):
    pass
def welch_ttest(sample_data, sample_index):
    """
    This function is used to calculate Welch's t-test for the means of two independent samples of scores. It returns the calculated t-statistic and the two-tailed p-value.
    :param sample_data: column name, the numerator of the metric, can use sql expression, the column must be numeric
    :type sample_data: str, required
    :param sample_index: column name, the index to represent the control group and the experimental group, 1 for the experimental group and 0 for the control group
    :type sample_index: str, required
    :return: Tuple with two elements:
    calculated statistic: Float64.
    calculated p-value: Float64.
    Example
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.welch_ttest('y', 'treatment').show()
    [-73.78492246858345, 0.0]
    >>> df.agg(S.welch_ttest('y', 'treatment')).show()
    [-73.78492246858345, 0.0]
    """
    return DfFnColWrapper(AggWelchTTestDfFunction(), {}, [sample_data, sample_index]) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="meanZTest")
@define_args(
    FnArg(name="sample_data"),
    FnArg(name="sample_index"),
    FnArg(name="population_variance_x", is_param=True),
    FnArg(name="population_variance_y", is_param=True),
    FnArg(name="confidence_level", is_param=True),
)
@aggregrate
class AggMeanZTestDfFunction(DfFunction):
    pass
def mean_z_test(
    sample_data,
    sample_index,
    population_variance_x,
    population_variance_y,
    confidence_level,
):
    """
    This function is used to calculate the z-test for the mean of two independent samples of scores. It returns the calculated z-statistic and the two-tailed p-value.
    :param sample_data: column name, the numerator of the metric, can use sql expression, the column must be numeric
    :type sample_data: str, required
    :param sample_index: column name, the index to represent the control group and the experimental group, 1 for the experimental group and 0 for the control group
    :type sample_index: str, required
    :param population_variance_x: Variance for control group.
    :type population_variance_x: Float, required
    :param population_variance_y: Variance for experimental group.
    :type population_variance_y: Float, required
    :param confidence_level: Confidence level in order to calculate confidence intervals.
    :type confidence_level: Float, required
    :return:
    Example
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
        df.mean_z_test('y', 'treatment', 0.9, 0.9, 0.95).show()
        df.agg(S.mean_z_test('y', 'treatment', 0.9, 0.9, 0.95)).show()
    """
    return DfFnColWrapper(
        AggMeanZTestDfFunction(),
        {
            "population_variance_x": population_variance_x,
            "population_variance_y": population_variance_y,
            "confidence_level": confidence_level,
        },
        [sample_data, sample_index],
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="bootStrap")
@define_args(FnArg(name="func"), FnArg(name="sample_num"), FnArg(name="bs_num"))
class BootStrapDfFunction(DfFunction):
    pass
def boot_strap(func, sample_num, bs_num):
    """
    Compute a two-sided bootstrap confidence interval of a statistic:
    draw `sample_num` rows from the data `bs_num` times and compute `func` on each resample.
    :param func: function to apply.
    :type func: str, required
    :param sample_num: number of samples.
    :type sample_num: int, required
    :param bs_num: number of bootstrap samples.
    :type bs_num: int, required
    :return: list of calculated statistics.
    :type return: Array(Float64).
    Example
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
        df.boot_strap(func='avg(x1)', sample_num=10000, bs_num=5).show()
        df.agg(S.boot_strap(func="ttest_1samp(avg(x1), 'two-sided',0)", sample_num=100000, bs_num=3)).show()
        df.agg(S.boot_strap(func="ttest_2samp(avg(x1), treatment, 'two-sided')", sample_num=100000, bs_num=3)).show()
    """
    return DfFnColWrapper(
        BootStrapDfFunction(),
        {},
        ["'" + func.replace("'", "@") + "'", sample_num, bs_num],
    ) 
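# Note on quoting (illustrative, assumed input): boot_strap wraps `func` in single
# quotes and replaces any embedded single quotes with '@' before it reaches the
# engine, so func="ttest_1samp(avg(x1), 'two-sided', 0)" is passed as
#   'ttest_1samp(avg(x1), @two-sided@, 0)'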
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="Permutation")
@define_args(
    FnArg(name="func"),
    FnArg(name="permutation_num"),
    FnArg(name="mde", default=""),
    FnArg(name="mde_type", default=""),
)
class PermutationDfFunction(DfFunction):
    pass
def permutation(func, permutation_num, mde="0", mde_type="1"):
    """
    :param func: function to apply.
    :type func: str, required
    :param permutation_num: number of permutations.
    :type permutation_num: int, required
    :param mde: minimum test difference (MDE) passed to the permutation test, default '0'.
    :type mde: str, optional
    :param mde_type: type of the minimum test difference, default '1'.
    :type mde_type: str, optional
    :return: list of calculated statistics.
    :type return: Array(Float64).
    Example
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
        df.permutation('mannWhitneyUTest', 3, 'x1')
        df.agg(S.permutation('mannWhitneyUTest', 3, 'x1')).show()
    """
    return DfFnColWrapper(
        PermutationDfFunction(),
        {},
        ["'" + func.replace("'", "@") + "'", permutation_num, mde, mde_type],
    ) 
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="MatrixMultiplication")
@register_fn(engine=OlapEngineType.STARROCKS, name="matrix_multiplication")
@define_args(
    FnArg(name="std", default="False", is_param=True),
    FnArg(name="invert", default="False", is_param=True),
    FnArg(name="col", is_variadic=True),
)
@aggregrate
class AggMatrixMultiplicationDfFunction(DfFunction):
    def sql_impl_clickhouse(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        col = ", ".join(map(lambda x: x.sql(ctx), arg_dict["col"].column))
        std = arg_dict["std"].sql(ctx)
        invert = arg_dict["invert"].sql(ctx)
        sql = self.fn_name(ctx) + f"({std}, {invert})" + f"({col})"
        return sql
    def sql_impl_starrocks(
        self,
        ctx: DfContext,
        fn_args: List[FnArg],
        fn_params: List[FnArg],
        arg_dict: Dict,
    ) -> str:
        col = ", ".join(map(lambda x: x.sql(ctx), arg_dict["col"].column))
        std = arg_dict["std"].sql(ctx)
        invert = arg_dict["invert"].sql(ctx)
        sql = self.fn_name(ctx) + f"([{col}], {std}, {invert})"
        return sql
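    # Hedged illustration (assumed inputs x1, x2 with std=False, invert=False): the two
    # engines render the same call differently, roughly
    #   ClickHouse: MatrixMultiplication(False, False)(x1, x2)
    #   StarRocks:  matrix_multiplication([x1, x2], False, False)
    # The exact rendering of the boolean parameters depends on how FnArg.sql formats them.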
def matrix_multiplication(*col, std=False, invert=False):
    """
    :param col: columns to apply function to.
    :type col: int, float or decimal, required
    :param std: whether to return standard deviation.
    :type std: bool, optional
    :param invert: whether to invert the matrix.
    :type invert: bool, optional
    :return: list of calculated statistics.
    :type return: Array(Float64).
    Example
    ----------------
    .. code-block:: python
        import fast_causal_inference
        import fast_causal_inference.dataframe.statistics as S
        df = fast_causal_inference.readClickHouse('test_data_small')
        df.matrix_multiplication('x1', 'x2', std = False, invert = False).show()
        df.agg(S.matrix_multiplication('x1', 'x2', std = False, invert = False)).show()
        df.agg(S.matrix_multiplication('x1', 'x2', std = True, invert = True)).show()
    """
    return DfFnColWrapper(
        AggMatrixMultiplicationDfFunction(), {"std": std, "invert": invert}, col
    ) 
import scipy.stats as stats
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time
import seaborn as sns
from matplotlib import rcParams
import warnings
def IPWestimator(df, Y, T, P, B=500):
    """
    Estimate the Average Treatment Effect (ATE) using Inverse Probability of Treatment Weighting (IPTW).
    :param df: the input DataFrame.
    :type df: DataFrame, required
    :param Y: the column name of the outcome variable.
    :type Y: str, required
    :param T: the column name of the treatment variable.
    :type T: str, required
    :param P: the column name of the propensity score.
    :type P: str, required
    :param B: the number of bootstrap samples, default is 500.
    :type B: int, optional
    :return: dict, containing the following key-value pairs:
    'ATE': Average Treatment Effect.
    'stddev': Standard deviation.
    'p_value': p-value.
    '95% confidence_interval': 95% confidence interval.
    Example
    ----------
    .. code-block:: python
        import fast_causal_inference
        table = 'test_data_small'
        df = fast_causal_inference.readClickHouse(table)
        Y = 'numerator'
        T = 'treatment'
        P = 'weight'
        import fast_causal_inference.dataframe.statistics as S
        S.IPWestimator(df,Y,T,P,B=500)
    """
    table = df_2_table(df)
    # Create SQL instance
    sql_instance = create_sql_instance()
    # Get the number of rows in the table
    n = int(sql_instance.sql(f"select count(*) as cnt from {table}")["cnt"][0])
    # Execute SQL query to calculate IPTW estimates
    res = (
        sql_instance.sql(
            f"""WITH (
      SELECT DistributedNodeRowNumber(1)(0)
      FROM {table}
    ) AS pa
    SELECT
      BootStrapMulti('sum:1;sum:1;sum:1;sum:1',  {n}, {B}, pa)(
      {Y}*{T}/({P}+0.01), {T}/({P}+0.01), {Y}*(1-{T})/(1-{P}+0.01), (1-{T})/(1-{P}+0.01)) as res
    FROM 
    {table}
    ;
    """
        )["res"][0]
        .replace("]", "")
        .replace(" ", "")
        .split("[")
    )
    # Process the query results
    res = [i.split(",") for i in res if i != ""]
    res = np.array([[float(j) for j in i if j != ""] for i in res])
    # Calculate IPTW estimates
    result = res[0, :] / res[1, :] - res[2, :] / res[3, :]
    ATE = np.mean(result)
    # Calculate standard deviation
    std = np.std(result)
    # Calculate t-value
    t_value = ATE / std
    # Calculate p-value
    p_value = (1 - stats.t.cdf(abs(t_value), n - 1)) * 2
    # Calculate 95% confidence interval
    confidence_interval = [ATE - 1.96 * std, ATE + 1.96 * std]
    # Return results
    return {
        "ATE": ATE,
        "stddev": std,
        "p_value": p_value,
        "95% confidence_interval": confidence_interval,
    } 
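# Hedged, self-contained sketch (not part of the library API): the bootstrap SQL above
# returns, per replicate, sum(Y*T/(P+0.01)), sum(T/(P+0.01)), sum(Y*(1-T)/(1-P+0.01))
# and sum((1-T)/(1-P+0.01)); the IPW estimate is the ratio of the first pair minus the
# ratio of the second pair. The same computation on in-memory numpy arrays:
def _ipw_ate_sketch(y, t, p):
    # y: outcome, t: 0/1 treatment indicator, p: propensity score (all 1-d arrays).
    y, t, p = (np.asarray(a, dtype=float) for a in (y, t, p))
    treated = np.sum(y * t / (p + 0.01)) / np.sum(t / (p + 0.01))
    control = np.sum(y * (1 - t) / (1 - p + 0.01)) / np.sum((1 - t) / (1 - p + 0.01))
    return treated - control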
def ATEestimator(df, Y, T, B=500):
    """
    Estimate the Average Treatment Effect (ATE) using a simple difference in means approach.
    :param df: the input DataFrame.
    :type df: DataFrame, required
    :param Y: the column name of the outcome variable.
    :type Y: str, required
    :param T: the column name of the treatment variable.
    :type T: str, required
    :param B: the number of bootstrap samples, default is 500.
    :type B: int, optional
    :return: dict, containing the following key-value pairs:
    'ATE': Average Treatment Effect.
    'stddev': Standard deviation.
    'p_value': p-value.
    '95% confidence_interval': 95% confidence interval.
    Example
    ----------
    .. code-block:: python
        import fast_causal_inference
        table = 'test_data_small'
        df = fast_causal_inference.readClickHouse(table)
        Y = 'numerator'
        T = 'treatment'
        import fast_causal_inference.dataframe.statistics as S
        S.ATEestimator(df,Y,T,B=500)
    """
    table = df_2_table(df)
    # Create SQL instance
    sql_instance = create_sql_instance()
    # Get the number of rows in the table
    n = int(sql_instance.sql(f"select count(*) as cnt from {table}")["cnt"][0])
    # Execute SQL query to compute ATE estimator using a simple difference in means approach
    res = (
        sql_instance.sql(
            f"""WITH (
      SELECT DistributedNodeRowNumber(1)(0)
      FROM {table}
    ) AS pa
    SELECT
      BootStrapMulti('sum:1;sum:1;sum:1;sum:1',  {n}, {B}, pa)(
      {Y}*{T},{T},{Y}*(1-{T}),(1-{T})) as res
    FROM 
    {table}
    ;
    """
        )["res"][0]
        .replace("]", "")
        .replace(" ", "")
        .split("[")
    )
    # Process the query results
    res = [i.split(",") for i in res if i != ""]
    res = np.array([[float(j) for j in i if j != ""] for i in res])
    # Calculate the difference-in-means estimates
    result = res[0, :] / res[1, :] - res[2, :] / res[3, :]
    # Compute the ATE
    ATE = np.mean(result)
    # Compute standard deviation
    std = np.std(result)
    # Compute t-value
    t_value = ATE / std
    # Compute p-value
    p_value = (1 - stats.t.cdf(abs(t_value), n - 1)) * 2
    # Compute 95% confidence interval
    confidence_interval = [ATE - 1.96 * std, ATE + 1.96 * std]
    # Return the results
    return {
        "ATE": ATE,
        "stddev": std,
        "p_value": p_value,
        "95% confidence_interval": confidence_interval,
    }
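# Hedged, self-contained sketch (not part of the library API): ATEestimator's SQL
# returns sum(Y*T), sum(T), sum(Y*(1-T)) and sum(1-T) per bootstrap replicate, i.e. the
# treated-group mean minus the control-group mean. The same difference in means on
# in-memory numpy arrays:
def _difference_in_means_sketch(y, t):
    # y: outcome, t: 0/1 treatment indicator (both 1-d arrays).
    y, t = np.asarray(y, dtype=float), np.asarray(t, dtype=float)
    return np.sum(y * t) / np.sum(t) - np.sum(y * (1 - t)) / np.sum(1 - t)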