import math
import pickle
import textwrap
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from graphviz import Digraph
from scipy.stats import norm  # used for z-test p-values in ComputePvalueAndCI
from statsmodels.stats.multitest import fdrcorrection

from fast_causal_inference.dataframe.dataframe import DataFrame, readClickHouse
from fast_causal_inference.lib.ols import *
from fast_causal_inference.lib.ols import Ols
from fast_causal_inference.util import (
    ClickHouseUtils,
    SqlGateWayConn,
    create_sql_instance,
)
def df_2_table(df):
    # Materialize the DataFrame and return the name of its backing table
    new_df = df.materializedView()
    return new_df.getTableName()
def table_2_df(table):
    # Wrap an existing ClickHouse table in a DataFrame
    return readClickHouse(table)
warnings.filterwarnings("ignore")
# Define a class to store the results of Lift Gain Curve
class LiftGainCurveResult:
    def __init__(self, data):
        # Initialize the result as a DataFrame with specified column names
        self.result = pd.DataFrame(
            data, columns=["ratio", "lift", "gain", "ate", "ramdom_gain"]
        )
    def __str__(self):
        # Return the string representation of the result
        return str(self.result)
    def summary(self):
        # Print the summary of the result
        print(self.result)
    def get_result(self):
        # Return the result
        return self.result
# Function to calculate lift gain
def get_lift_gain(ITE, Y, T, df, normalize=True, K=1000, discrete_treatment=True):
    """
    Calculate the uplift & gain.
    Parameters
    ----------
    ITE : str
        The Individual Treatment Effect column.
    Y : str
        The outcome variable column.
    T : str
        The treatment variable column.
    df : DataFrame
        The input data.
    normalize : bool, optional
        Whether to normalize the result, default is True.
    K : int, optional
        The number of bins for discretization, default is 1000.
    discrete_treatment : bool, optional
        Whether the treatment is discrete, default is True.
    Returns
    -------
    LiftGainCurveResult
        An object containing the result of the uplift & gain calculation.
    Example
    -------
    .. code-block:: python
        import fast_causal_inference
        from fast_causal_inference.dataframe.uplift import *
        Y='y'
        T='treatment'
        table = 'test_data_small'
        X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        needcut_X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        df = readClickHouse(table)
        df_train, df_test = df.split(0.5)
        hte = CausalTree(depth = 3,min_sample_ratio_leaf=0.001)
        hte.fit(Y,T,X,needcut_X,df_train)
        df_train_pred = hte.effect(df=df_train,keep_col='*')
        df_test_pred = hte.effect(df=df_test,keep_col='*')
        lift_train = get_lift_gain("effect", Y, T, df_train_pred,discrete_treatment=True, K=100)
        lift_test = get_lift_gain("effect", Y, T, df_test_pred,discrete_treatment=True, K=100)
        print(lift_train,lift_test)
        hte_plot([lift_train,lift_test],labels=['train','test'])
        # auuc: 0.6624369283393814
        # auuc: 0.6532554148698826
        #        ratio      lift      gain  ate  ramdom_gain
        # 0   0.009990  2.164241  0.021621  1.0     0.009990
        # 1   0.019980  2.131245  0.042582  1.0     0.019980
        # 2   0.029970  2.056440  0.061632  1.0     0.029970
        # 3   0.039960  2.177768  0.087024  1.0     0.039960
        # 4   0.049950  2.175329  0.108658  1.0     0.049950
        # ..       ...       ...       ...  ...          ...
        # 95  0.959241  1.015223  0.973843  1.0     0.959241
        # 96  0.969431  1.010023  0.979147  1.0     0.969431
        # 97  0.979620  1.006843  0.986324  1.0     0.979620
        # 98  0.989810  1.003508  0.993283  1.0     0.989810
        # 99  1.000000  1.000000  1.000000  1.0     1.000000
        # [100 rows x 5 columns]        ratio      lift      gain  ate  ramdom_gain
        # 0   0.009810  1.948220  0.019112  1.0     0.009810
        # 1   0.019620  2.221654  0.043588  1.0     0.019620
        # 2   0.029429  2.419752  0.071212  1.0     0.029429
        # 3   0.039239  2.288460  0.089797  1.0     0.039239
        # 4   0.049049  2.343432  0.114943  1.0     0.049049
        # ..       ...       ...       ...  ...          ...
        # 95  0.959960  1.014897  0.974260  1.0     0.959960
        # 96  0.969970  1.011624  0.981245  1.0     0.969970
        # 97  0.979980  1.009358  0.989150  1.0     0.979980
        # 98  0.989990  1.006340  0.996267  1.0     0.989990
        # 99  1.000000  1.000000  1.000000  1.0     1.000000
        # [100 rows x 5 columns]
    """
    # Construct the SQL query
    sql = (
        f"select lift({ITE},{Y},{T},{K},{str(discrete_treatment).lower()})"
        f" from {df_2_table(df)} limit 100000"
    )
    # Create an SQL instance
    sql_instance = SqlGateWayConn.create_default_conn()
    # Execute the SQL query and get the result
    result = sql_instance.sql(sql)
    # Select specific columns from the result
    result = result[["ratio", "lift", "gain", "ate", "ramdom_gain"]]
    # Replace 'nan' with np.nan
    result = result.replace("nan", np.nan)
    # Convert the data type to float
    result = result.astype(float)
    # Drop rows with missing values
    result = result.dropna()
    # Normalize the result if required
    if normalize:
        result = result.div(np.abs(result.iloc[-1, :]), axis=1)
    # Calculate AUUC
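    # (mean of the normalized cumulative gain, a Riemann approximation of the
    # area under the gain curve)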
    auuc = result["gain"].sum() / result.shape[0]
    print("auuc:", auuc)
    # Return the result as a LiftGainCurveResult object
    return LiftGainCurveResult(result) 
# Function to plot HTE
def hte_plot(results, labels=None):
    """
    Plot the uplift & gain.
    Parameters
    ----------
    results : list
        A list of LiftGainCurveResult objects to be plotted.
    labels : list, optional
        A list of labels for the results, default is an empty list.
    Returns
    -------
    None
        This function will display a plot.
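
    Example
    -------
    A minimal sketch; it assumes ``lift_train`` and ``lift_test`` are
    LiftGainCurveResult objects produced by ``get_lift_gain``, as in that
    function's docstring example.

    .. code-block:: python

        hte_plot([lift_train, lift_test], labels=['train', 'test'])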
    """
    # Create a figure with two subplots
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(12, 4.8))
    # Generate default labels if none were provided
    if not labels:
        labels = [f"model_{i + 1}" for i in range(len(results))]
    # Plot the results
    for i in range(len(results)):
        result = results[i].get_result()
        auuc = round(result["gain"].sum() / result.shape[0], 2)
        ax1.plot(result["ratio"], result["lift"], label=labels[i])
        ax2.plot(
            [0] + list(result["ratio"]),
            [0] + list(result["gain"]),
            label=labels[i] + f"(auuc:{auuc})",
        )
    # Plot the ATE baseline and the random-targeting gain
    # (both reuse the ratio grid of the last plotted result)
    ax1.plot(result["ratio"], result["ate"])
    ax2.plot([0] + list(result["ratio"]), [0] + list(result["ramdom_gain"]))
    # Set the titles and legends
    ax1.set_title("Cumulative Lift Curve")
    ax1.legend()
    ax2.set_title("Cumulative Gain Curve")
    ax2.legend()
    # Set the title for the figure
    fig.suptitle("Lift and Gain Curves")
    # Display the figure
    plt.show() 
# CausalTree class
class CausalTreeclass:
    def __init__(
        self,
        dfName="dat",
        threshold=0.01,
        maxDepth=3,
        whereCond="",
        nodePosition="L",
        impurity=0,
        father_split_feature="",
        father_split_feature_Categories=None,
        whereCond_new=None,
        nodesSet=None,
    ):
        # avoid shared mutable defaults: nodesSet in particular is appended to
        self.dfName = dfName
        self.threshold = threshold
        self.maxDepth = maxDepth
        self.whereCond = whereCond
        self.nodePosition = nodePosition
        self.impurity = impurity
        self.father_split_feature = father_split_feature
        self.father_split_feature_Categories = (
            father_split_feature_Categories
            if father_split_feature_Categories is not None
            else []
        )
        self.whereCond_new = whereCond_new if whereCond_new is not None else []
        self.nodesSet = nodesSet if nodesSet is not None else []
        self.leftNode = 0
        self.rightNode = 0
        self.nodeSize = 0
        self.controlcount = 0
        self.treatcount = 0
        self.splitFeat = ""
        self.splitIndex = 0
        self.splitpoint = ""
        self.prediction = 0
        self.maxImpurGain = 0
        self.isLeaforNot = False
        self.splitpoint_pdf = pd.DataFrame()
        self.allsplitpoint_pdf = pd.DataFrame()
        self.featvalues_dict = dict()
        self.inferenceSet = [1.0] * 27  # placeholder; predictNode() fills the real row
        self.dat_cnt = 0
        self.depth = 3
        self.sql_instance = create_sql_instance()
    def get_global_values(self, dat_cnt, depth, featNames):
        self.dat_cnt = dat_cnt
        self.depth = depth
        self.featNames = featNames
    def compute_df(self):
        threshold = self.threshold
        if self.maxDepth == 0:
            self.isLeaforNot = True
            print("Reach the maxDepth, stop as a leaf node")
            return
        if self.whereCond != "":
            whereCond_ = "AND" + self.whereCond
        else:
            whereCond_ = ""
        dfName = self.dfName
        sql_instance = self.sql_instance
        row = sql_instance.sql(
            f"""SELECT treatment as T,count(Y) as cnt,avg(Y) as mean,varSamp(Y) as var FROM {dfName} 
                                WHERE if_test = 0 {whereCond_}
                                group by treatment order by treatment  """
        )
        self.controlcount = int(row["cnt"][0])
        self.treatcount = int(row["cnt"][1])
        self.nodeSize = self.controlcount + self.treatcount
        if self.nodeSize / self.dat_cnt < threshold:
            self.isLeaforNot = True
            print("sample size is too small, stop as a leaf node")
            return
        if self.nodeSize == 0:
            self.isLeaforNot = True
            print("sample size = 0, stop as a leaf node")
            return
        if float(row["var"][0]) == 0 and float(row["var"][1]) == 0:
            self.isLeaforNot = True
            print("var = 0, stop as a leaf node")
            return
        t1 = time.time()
        cols = [
            "featName",
            "featValue",
            "cnt1",
            "y1",
            "y1_square",
            "cnt0",
            "y0",
            "y0_square",
        ]
        splitpoint_data_pdf = pd.DataFrame([], columns=cols)
        featNames = self.featNames
        for i in range(len(featNames) // 100 + 1):
            x_list = featNames[i * 100 : min((i + 1) * 100, len(featNames))]
            x_list_string1 = "'" + "', '".join(x_list) + "'"
            x_list_string2 = ",".join(x_list)
            try:
                res = sql_instance.sql(
                    f"""SELECT arrayJoin(GroupSet({x_list_string1})(Y, treatment, {x_list_string2})) as a 
                FROM {dfName} WHERE if_test = 0 {whereCond_} order by a.1,a.3,a.2 asc  limit 100000"""
                )["a"].tolist()
                data_list = [
                    i.replace("[", "").replace("]", "").replace(" ", "").split(",")
                    for i in res
                ]
                df = pd.DataFrame(
                    data_list,
                    columns=[
                        "featName",
                        "treatment",
                        "featValue",
                        "cnt",
                        "y",
                        "y_square",
                    ],
                )
            except Exception:
                # print the raw backend response for debugging, then re-raise:
                # without the parsed frame the split search cannot continue
                print(
                    sql_instance.sql(
                        f"""SELECT arrayJoin(GroupSet({x_list_string1})(Y, treatment, {x_list_string2})) as a 
                FROM {dfName} WHERE if_test = 0 {whereCond_} order by a.1,a.3,a.2 asc"""
                    )
                )
                raise
            df[["treatment", "cnt", "y", "y_square"]] = df[
                ["treatment", "cnt", "y", "y_square"]
            ].astype(float)
            df[["featName", "featValue"]] = df[["featName", "featValue"]].astype(str)
            df["featName"] = [i.replace("'", "") for i in list(df["featName"])]
            df0 = df[df["treatment"] == 0].drop("treatment", axis=1)
            df1 = df[df["treatment"] == 1].drop("treatment", axis=1)
            df_new = pd.merge(df1, df0, on=["featName", "featValue"])
            df_new.columns = cols
            splitpoint_data_pdf = pd.concat([splitpoint_data_pdf, df_new], axis=0)
        splitpoint_data_pdf["tau"] = (
            splitpoint_data_pdf["y1"] / splitpoint_data_pdf["cnt1"]
            - splitpoint_data_pdf["y0"] / splitpoint_data_pdf["cnt0"]
        )
        splitpoint_data_pdf = splitpoint_data_pdf.sort_values(
            by=["featName", "tau"], ascending=False
        )
        tmp = splitpoint_data_pdf.groupby(by=["featName"], as_index=False)[
            "featValue"
        ].count()
        tmp.columns = ["featName", "featpoint_num"]
        splitpoint_data_pdf = pd.merge(splitpoint_data_pdf, tmp, on=["featName"])
        splitpoint_data_pdf = splitpoint_data_pdf.dropna()
        splitpoint_data_pdf_copy = splitpoint_data_pdf.copy()
        self.allsplitpoint_pdf = splitpoint_data_pdf_copy
        splitpoint_data_pdf = splitpoint_data_pdf[splitpoint_data_pdf.featpoint_num > 1]
        featNames_ = list(set(splitpoint_data_pdf["featName"]))
        featValuesall_list = []
        featName_list = []
        featValue_list = []
        splitpoint_list = []
        cnt0_list = []
        cnt1_list = []
        for featName_ in featNames_:
            featValuesall = list(
                splitpoint_data_pdf[(splitpoint_data_pdf["featName"] == featName_)][
                    "featValue"
                ]
            )
            cnt0s = list(
                splitpoint_data_pdf[(splitpoint_data_pdf["featName"] == featName_)][
                    "cnt0"
                ]
            )
            cnt1s = list(
                splitpoint_data_pdf[(splitpoint_data_pdf["featName"] == featName_)][
                    "cnt1"
                ]
            )
            featValuesall_list.append(featValuesall)
            for i in range(len(featValuesall) - 1):
                cnt0 = np.sum(cnt0s[0 : i + 1])
                cnt1 = np.sum(cnt1s[0 : i + 1])
                cnt0_list.append(cnt0)
                cnt1_list.append(cnt1)
                featName_list.append(featName_)
                splitpoint_list.append(dict({featName_: featValuesall[0 : i + 1]}))
                featValue_list.append(featValuesall[0 : i + 1])
        splitpoint_pdf = pd.DataFrame(
            zip(featName_list, featValue_list, splitpoint_list, cnt0_list, cnt1_list),
            columns=["featName", "featValue", "splitpoint", "cnt0", "cnt1"],
        )
        #         print("splitpoint_pdf_before:\n",splitpoint_pdf)
        splitpoint_pdf = splitpoint_pdf[
            (splitpoint_pdf["cnt0"] > 100)
            & (splitpoint_pdf["cnt1"] > 100)
            & (
                splitpoint_pdf["cnt0"] + splitpoint_pdf["cnt1"]
                > threshold * self.dat_cnt
            )
            & (
                self.nodeSize - splitpoint_pdf["cnt0"] - splitpoint_pdf["cnt1"]
                > threshold * self.dat_cnt
            )
        ]
        self.splitpoint_pdf = splitpoint_pdf
        self.featvalues_dict = dict(zip(featNames_, featValuesall_list))
        t2 = time.time()
    def splitcond_sql(self, splitpoint):
        featName = list(splitpoint.keys())[0]
        featValue = list(splitpoint.values())[0]
        x = self.sql_instance.sql(f"desc {self.dfName};")
        cols_type = dict(zip(x["name"], x["type"]))
        if cols_type[featName] == "String":
            featValue = ",".join(["'" + str(i) + "'" for i in featValue])
        else:
            featValue = ",".join([str(i) for i in featValue])
        left_condition_tree = f"""({featName} in ({featValue}))"""
        right_condition_tree = f"""not ({featName} in ({featValue}))"""
        return left_condition_tree, right_condition_tree
    def calculate_impurity_new(self, x):
        allsplitpoint_pdf = self.allsplitpoint_pdf  # TODO
        featName = x["featName"]
        featValue = x["featValue"]
        # left impurity
        (y1, y1_square, cnt1, y0, y0_square, cnt0) = list(
            allsplitpoint_pdf[
                (allsplitpoint_pdf["featValue"].isin(featValue))
                & (allsplitpoint_pdf["featName"] == featName)
            ][["y1", "y1_square", "cnt1", "y0", "y0_square", "cnt0"]].sum()
        )
        y1 = y1 / cnt1
        y0 = y0 / cnt0
        y1_square = y1_square / cnt1
        y0_square = y0_square / cnt0
        tau = y1 - y0
        tr_var = y1_square - y1**2
        con_var = y0_square - y0**2
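        # Honest-tree style splitting criterion (cf. Athey & Imbens, 2016):
        # reward the squared effect estimate weighted by the child's sample
        # size and penalise the variance of that estimate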
        left_effect = 0.5 * tau * tau * (cnt1 + cnt0) - 0.5 * 2 * (cnt1 + cnt0) * (
            tr_var / cnt1 + con_var / cnt0
        )
        # right impurity
        (y1, y1_square, cnt1, y0, y0_square, cnt0) = list(
            allsplitpoint_pdf[
                (~(allsplitpoint_pdf["featValue"].isin(featValue)))
                & (allsplitpoint_pdf["featName"] == featName)
            ][["y1", "y1_square", "cnt1", "y0", "y0_square", "cnt0"]].sum()
        )
        y1 = y1 / cnt1
        y0 = y0 / cnt0
        y1_square = y1_square / cnt1
        y0_square = y0_square / cnt0
        tau = y1 - y0
        tr_var = y1_square - y1**2
        con_var = y0_square - y0**2
        right_effect = 0.5 * tau * tau * (cnt1 + cnt0) - 0.5 * 2 * (cnt1 + cnt0) * (
            tr_var / cnt1 + con_var / cnt0
        )
        return left_effect, right_effect
    def calculate_impurity_original(self):
        sql_instance = self.sql_instance
        sql = f""" select\
                sum(if(treatment=1,1,0)) as cnt1,  \
                sum(if(treatment=1,Y,0))/sum(if(treatment=1,1,0)) as y1, \
                sum(if(treatment=1,Y*Y,0))/sum(if(treatment=1,1,0)) as y1_square, \
                sum(if(treatment=0,1,0)) as cnt0,  \
                sum(if(treatment=0,Y,0))/sum(if(treatment=0,1,0)) as y0, \
                sum(if(treatment=0,Y*Y,0))/sum(if(treatment=0,1,0)) as y0_square \
         from {self.dfName} where if_test = 0 """
        row = sql_instance.sql(sql).iloc[0, :]
        cnt1, y1, y1_square, cnt0, y0, y0_square = (
            int(row.cnt1),
            float(row.y1),
            float(row.y1_square),
            int(row.cnt0),
            float(row.y0),
            float(row.y0_square),
        )
        tau = y1 - y0
        tr_var = y1_square - y1**2
        con_var = y0_square - y0**2
        effect = 0.5 * tau * tau * (cnt1 + cnt0) - 0.5 * 2 * (cnt1 + cnt0) * (
            tr_var / cnt1 + con_var / cnt0
        )
        return effect
    def getTreeID(self):
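        # Level-order (heap) numbering derived from the L/R path string:
        # the root "L" is node 1, its children "LL"/"LR" are 2/3, and in
        # general id = (number of nodes in the levels above) + the binary
        # value of the path, reading L as 0 and R as 1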
        nodePosition = self.nodePosition
        lengthStr = len(nodePosition)
        if lengthStr == 1:
            result = 1
        else:
            result = sum([math.pow(2, i) for i in range(lengthStr - 1)]) + 1
        for i in range(lengthStr):
            x = 1 if nodePosition[i] == "R" else 0
            result += math.pow(2, lengthStr - i - 1) * x
        return result
    def get_node_type(self):
        return "leaf" if self.isLeaforNot else "internal"
    def getBestSplit(self):
        #         print("-----getBestSplit-----")
        splitpoint_pdf = self.splitpoint_pdf
        if splitpoint_pdf.shape[0] == 0:
            self.isLeaforNot = True
            print(
                "no split point satisfies the condition, stop splitting as a leaf node"
            )
            return
        splitpoint_pdf[["leftImpurity", "rightImpurity"]] = list(
            splitpoint_pdf.apply(self.calculate_impurity_new, axis=1)
        )
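        # split gain = sum of the two children's criteria minus the parent's;
        # only strictly positive gains are kept below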
        splitpoint_pdf["ImpurityGain"] = (
            splitpoint_pdf["leftImpurity"]
            + splitpoint_pdf["rightImpurity"]
            - self.impurity
        )
        splitpoint_pdf = splitpoint_pdf.sort_values(by="ImpurityGain", ascending=False)
        splitpoint_pdf = splitpoint_pdf[(splitpoint_pdf["ImpurityGain"] > 0)]
        self.splitpoint_pdf = splitpoint_pdf[
            ["featName", "featValue", "splitpoint", "ImpurityGain"]
        ]
        if splitpoint_pdf.shape[0] == 0:
            self.isLeaforNot = True
            print(
                "no split point satisfies the condition, stop splitting as a leaf node"
            )
            return
        bestFeatureName = splitpoint_pdf.iloc[0]["featName"]
        bestFeatureIndex = splitpoint_pdf.iloc[0]["featValue"]
        bestleftImpurity = splitpoint_pdf.iloc[0]["leftImpurity"]
        bestrightImpurity = splitpoint_pdf.iloc[0]["rightImpurity"]
        maxImpurityGain = splitpoint_pdf.iloc[0]["ImpurityGain"]
        bestsplitpoint = splitpoint_pdf.iloc[0]["splitpoint"]
        self.maxImpurGain = maxImpurityGain
        self.splitFeat = bestFeatureName
        self.splitIndex = bestFeatureIndex
        self.splitpoint = bestsplitpoint
        self.leftImpurity = bestleftImpurity
        self.rightImpurity = bestrightImpurity
    def buildTree(self):
        #         global nodesSet
        self.nodesSet.append(self)
        #         print(self.nodesSet)
        if self.whereCond == "":
            self.impurity = self.calculate_impurity_original()
        #             print("root impurity:",self.impurity)
        self.compute_df()
        if self.isLeaforNot:
            return
        else:
            self.getBestSplit()
            if self.isLeaforNot:
                return
        #         print("end split")
        whereConditionSql = "" if (self.whereCond == "") else self.whereCond + " AND "
        leftcondition_tree, rightcondition_tree = self.splitcond_sql(self.splitpoint)
        leftCond = f"{whereConditionSql} ( {leftcondition_tree} ) "
        rightCond = f"{whereConditionSql} ( {rightcondition_tree} ) "
        leftChildDfName = self.nodePosition + "L"
        rightChildDfName = self.nodePosition + "R"
        # TODO:
        father_split_feature = self.splitFeat
        Categories = self.featvalues_dict[father_split_feature]
        left_Categories = self.splitpoint[father_split_feature]
        right_Categories = []
        for i in Categories:
            if i not in left_Categories:
                right_Categories.append(i)
        left_tmp = {}
        right_tmp = {}
        #         print("self.whereCond_new",self.whereCond_new)
        #         print(father_split_feature,left_Categories,right_Categories)
        left_tmp[father_split_feature] = left_Categories
        right_tmp[father_split_feature] = right_Categories
        leftCond_new = self.whereCond_new.copy()
        leftCond_new.append(left_tmp)
        rightCond_new = self.whereCond_new.copy()
        rightCond_new.append(right_tmp)
        #         print("leftCond_new,rightCond_new",leftCond_new,rightCond_new)
        leftChild = CausalTreeclass(
            self.dfName,
            self.threshold,
            self.maxDepth - 1,
            leftCond,
            self.nodePosition + "L",
            self.leftImpurity,
            father_split_feature,
            left_Categories,
            leftCond_new,
            self.nodesSet,
        )
        rightChild = CausalTreeclass(
            self.dfName,
            self.threshold,
            self.maxDepth - 1,
            rightCond,
            self.nodePosition + "R",
            self.rightImpurity,
            father_split_feature,
            right_Categories,
            rightCond_new,
            self.nodesSet,
        )
        leftChild.get_global_values(self.dat_cnt, self.depth, self.featNames)
        rightChild.get_global_values(self.dat_cnt, self.depth, self.featNames)
        self.leftNode = leftChild
        self.rightNode = rightChild
        #         print("-----split result----")
        #         print("maxImpurGain:",self.maxImpurGain)
        #         print("splitpoint:",self.splitpoint)
        #         print("leftsplitcond,rightsplitcond:",leftcondition_tree,rightcondition_tree)
        #         print("leftImpurity,rightImpurity:",self.leftImpurity,self.rightImpurity)
        print(
            "--------start leftNode - build -- depth: {depth}, nodePosition: {nodePosition}--------".format(
                depth=self.depth - self.maxDepth, nodePosition=self.nodePosition + "L"
            )
        )
        self.leftNode.buildTree()
        print(
            "--------start rightNode - build -- depth: {depth}, nodePosition: {nodePosition}--------".format(
                depth=self.depth - self.maxDepth, nodePosition=self.nodePosition + "R"
            )
        )
        self.rightNode.buildTree()
    def visualization(self):
        if self.isLeaforNot:
            return (
                "\n" + "Prediction is: " + str(self.prediction) + " " + self.whereCond
            )
        else:
            return (
                "Level"
                + str(self.maxDepth)
                + "  "
                + str(self.splitpoint)
                + "  "
                + str(self.maxImpurGain)
                + "\n"
                + self.leftNode.visualization()
                + "\n"
                + self.rightNode.visualization()
            )
    def ComputePvalueAndCI(self, zValue, prediction, meanStd):
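        # two-sided p-value under a standard normal null; the 95% CI uses
        # 1.96, the 0.975 quantile of the standard normal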
        pvalue = 2 * (1 - norm.cdf(x=abs(zValue), loc=0, scale=1))
        lowerCI = prediction - 1.96 * (meanStd)
        upperCI = prediction + 1.96 * (meanStd)
        return (pvalue, lowerCI, upperCI)
    def predictNode(self, ate=0, cnt=1):
        if self.whereCond != "":
            whereCond_ = "AND" + self.whereCond
        else:
            whereCond_ = ""
        level = self.depth - self.maxDepth
        TreeID = self.getTreeID()
        isLeaf = bool(self.isLeaforNot)
        whereCond_new = self.whereCond_new
        sql = f"""SELECT \
        treatment, count(*) as cnt, avg(Y) as mean, varSamp(Y) as var \
        FROM {self.dfName} \
        WHERE treatment in (0,1) {whereCond_} and if_test=1 \
        GROUP BY treatment \
        """
        sql_instance = self.sql_instance
        statData_pdf = sql_instance.sql(sql)
        try:
            statData_pdf = statData_pdf.astype(float)
            y1_column = statData_pdf[statData_pdf["treatment"] == 1]
            y0_column = statData_pdf[statData_pdf["treatment"] == 0]
            treatedCount = int(list(y1_column["cnt"])[0])
            controlCount = int(list(y0_column["cnt"])[0])
            treatedLabelAvg = float(list(y1_column["mean"])[0])
            controlLabelAvg = float(list(y0_column["mean"])[0])
            treatedLabelVar = float(list(y1_column["var"])[0])
            controlLabelVar = float(list(y0_column["var"])[0])
            self.prediction = treatedLabelAvg - controlLabelAvg
            ratio = (treatedCount + controlCount) / cnt
        except Exception:
            # one of the treatment arms has no data in the estimation split:
            # emit a null inference row (37 fields, p-values set to 1)
            self.inferenceSet = (
                [TreeID, level, isLeaf, whereCond_, whereCond_new]
                + [0] * 8
                + [0, 0, 1, 0, 0, 0] * 4
            )
            return
        # if treatedLabelAvg == 0 or controlLabelAvg == 0:
        #     self.inferenceSet = list(
        #         [TreeID, level, isLeaf, whereCond_, whereCond_new, ratio, self.prediction, treatedCount, controlCount,
        #          treatedLabelAvg, controlLabelAvg, math.sqrt(treatedLabelVar), math.sqrt(controlLabelVar),
        #          0, 0, 1, 0, 0, 0,
        #          0, 0, 1, 0, 0, 0,
        #          0, 0, 1, 0, 0, 0,
        #          0, 0, 1, 0, 0, 0
        #          ])  # TODO:fix zero bug
        try:
            estPoint1 = treatedLabelAvg - controlLabelAvg
            std1 = math.sqrt(
                treatedLabelVar / treatedCount + controlLabelVar / controlCount
            )
            zValue1 = estPoint1 / std1
            (pvalue1, lowerCI1, upperCI1) = self.ComputePvalueAndCI(
                zValue1, estPoint1, std1
            )
            estPoint2 = treatedLabelAvg - controlLabelAvg - ate
            std2 = std1
            zValue2 = estPoint2 / std2
            (pvalue2, lowerCI2, upperCI2) = self.ComputePvalueAndCI(
                zValue2, estPoint2, std2
            )
            estPoint3 = treatedLabelAvg / controlLabelAvg - 1
            std3 = std1
            zValue3 = zValue1
            pvalue3 = pvalue1
            lowerCI3 = estPoint3 - 1.96 * (std1) * estPoint3 / estPoint1
            upperCI3 = estPoint3 + 1.96 * (std1) * estPoint3 / estPoint1
            estPoint4 = (treatedLabelAvg - ate) / controlLabelAvg - 1
            std4 = std2
            zValue4 = zValue2
            pvalue4 = pvalue2
            if TreeID == 1:
                estPoint2, estPoint4 = 0, 0
                lowerCI4 = 0
                upperCI4 = 0
            else:
                lowerCI4 = estPoint4 - 1.96 * (std2) * estPoint4 / estPoint2
                upperCI4 = estPoint4 + 1.96 * (std2) * estPoint4 / estPoint2
            self.inferenceSet = list(
                [
                    TreeID,
                    level,
                    isLeaf,
                    whereCond_,
                    whereCond_new,
                    ratio,
                    self.prediction,
                    treatedCount,
                    controlCount,
                    treatedLabelAvg,
                    controlLabelAvg,
                    math.sqrt(treatedLabelVar),
                    math.sqrt(controlLabelVar),
                    estPoint1,
                    std1,
                    pvalue1,
                    zValue1,
                    lowerCI1,
                    upperCI1,
                    estPoint2,
                    std2,
                    pvalue2,
                    zValue2,
                    lowerCI2,
                    upperCI2,
                    estPoint3,
                    std3,
                    pvalue3,
                    zValue3,
                    lowerCI3,
                    upperCI3,
                    estPoint4,
                    std4,
                    pvalue4,
                    zValue4,
                    lowerCI4,
                    upperCI4,
                ]
            )
        except Exception:
            # degenerate estimates (e.g. a zero denominator): keep the observed
            # counts and means but null out the four effect estimates
            self.inferenceSet = [
                TreeID,
                level,
                isLeaf,
                whereCond_,
                whereCond_new,
                ratio,
                self.prediction,
                treatedCount,
                controlCount,
                treatedLabelAvg,
                controlLabelAvg,
                math.sqrt(treatedLabelVar),
                math.sqrt(controlLabelVar),
            ] + [0, 0, 1, 0, 0, 0] * 4  # TODO: fix zero bug
# define tree structure
class CTDecisionNode:
    def __init__(
        self,
        node_id=0,
        nodeType="",
        node_order="",
        nodePosition="",
        count_ratio=0,
        controlCount=0,
        treatedCount=0,
        treatedAvg=0,
        controlAvg=0,
        splitType="",
        gain=0,
        impurity=0,
        prediction=0,
        prediction_new=[],
        featureName="",
        featureIndex="",
        father_split_feature="",
        father_split_feature_Categories=[],
        pvalues=[],
        qvalues=[],
        children=None,
        whereCond="",
        whereCond_new=[],
    ):
        self.node_id = node_id
        self.nodeType = nodeType
        self.node_order = node_order
        self.nodePosition = nodePosition
        self.count_ratio = count_ratio
        self.controlCount = controlCount
        self.treatedCount = treatedCount
        self.treatedAvg = treatedAvg
        self.controlAvg = controlAvg
        self.splitType = splitType
        self.gain = gain
        self.impurity = impurity
        self.prediction = prediction
        self.prediction_new = prediction_new
        self.featureName = featureName
        self.featureIndex = featureIndex
        self.father_split_feature = father_split_feature
        self.father_split_feature_Categories = father_split_feature_Categories
        self.pvalues = pvalues
        self.qvalues = qvalues
        self.whereCond = whereCond
        self.whereCond_new = whereCond_new
        self.children = children
    def get_dict(self):
        child_dic = (
            []
            if self.children is None
            else (self.children[0].get_dict(), self.children[1].get_dict())
        )
        return {
            "node_id": self.node_id,
            "nodeType": self.nodeType,
            "father_split_feature_Categories": self.father_split_feature_Categories,
            "count_ratio": self.count_ratio,
            "controlCount": self.controlCount,
            "treatedCount": self.treatedCount,
            "controlAvg": self.controlAvg,
            "treatedAvg": self.treatedAvg,
            "father_split_feature": self.father_split_feature,
            "pvalues": self.pvalues,
            "tau_i": self.prediction,
            #                 "node_order": self.node_order,
            #                 "nodePosition":self.nodePosition,
            #                 "tau_i_new":self.prediction_new,
            #                 "splitType": self.splitType,
            #                 "gain": self.gain,
            #                 "impurity": self.impurity,
            #                 "featureName": self.featureName,
            #                 "featureIndex": self.featureIndex,
            #                 "qvalues":self.qvalues,
            #                 "whereCond":self.whereCond,
            #                 "whereCond_new":self.whereCond_new,
            "children": child_dic,
        }
# CTRegressionTree converts a fitted CausalTreeclass into the nested
# CTDecisionNode structure consumed by the tree plot
class CTRegressionTree:
    def __init__(self, tree=0, total_count=1):
        self.tree = tree
        self.total_count = total_count
    def get_decision_rules(self, node_order="root"):
        tree = self.tree
        node_id = tree.getTreeID()
        node_type = tree.get_node_type()
        # sort the category values of the feature the parent node split on
        x = tree.father_split_feature_Categories
        x.sort()
        if node_type == "internal":
            gain = tree.maxImpurGain
            feature_name = tree.splitFeat
            feature_index = tree.splitIndex
            split_type = "categorical"
            node_impurity = tree.impurity
            #             Categories = tree.featvalues_dict[feature_name]
            #             left_Categories = list(tree.splitIndex.split(','))
            #             right_Categories = []
            #             for i in Categories:
            #                 if i not in left_Categories:
            #                     right_Categories.append(i)
            left = CTRegressionTree(tree=tree.leftNode, total_count=self.total_count)
            right = CTRegressionTree(tree=tree.rightNode, total_count=self.total_count)
            children = (
                left.get_decision_rules("left"),
                right.get_decision_rules("right"),
            )
        else:
            gain = None
            feature_name = None
            feature_index = None
            split_type = None
            node_impurity = None
            #             left_Categories = None
            #             right_Categories = None
            children = None
        ctDecisionNode = CTDecisionNode(
            node_id=node_id,
            nodeType=node_type,
            node_order=node_order,
            count_ratio=round(tree.inferenceSet["ratio"] * 100, 2),
            treatedCount=tree.inferenceSet["treatedCount"],
            controlCount=tree.inferenceSet["controlCount"],
            treatedAvg=round(tree.inferenceSet["treatedAvg"], 4),
            controlAvg=round(tree.inferenceSet["controlAvg"], 4),
            nodePosition=tree.nodePosition,
            prediction=round(tree.prediction, 4),
            prediction_new=[
                float(tree.inferenceSet["estPoint" + str(i)]) for i in range(1, 5)
            ],
            splitType=split_type,
            gain=gain,
            impurity=node_impurity,
            featureName=feature_name,
            featureIndex=feature_index,
            father_split_feature=tree.father_split_feature,
            father_split_feature_Categories=tree.father_split_feature_Categories,
            pvalues=[
                round(float(tree.inferenceSet["pvalue" + str(i)]), 2)
                for i in range(1, 5)
            ],
            qvalues=[
                round(float(tree.inferenceSet["qvalue" + str(i)]), 2)
                for i in range(1, 5)
            ],
            children=children,
            whereCond=tree.whereCond,
            whereCond_new=tree.whereCond_new,
        )
        return ctDecisionNode
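# Schema helpers: FeatNames splits an "a+b+c" feature string into a list of
# names, e.g. FeatNames("x1+x2") -> ["x1", "x2"]; FilterSchema turns such a
# list into an SQL predicate, e.g. "x1 is not null and x2 is not null".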
def SelectSchema(schema):
    return schema
def FeatNames(schema):
    if schema == "":
        schemaArray = []
    elif "+" in schema:
        schemaArray = schema.split("+")
    else:
        schemaArray = [schema]
    return schemaArray
def FilterSchema(schemaArray):
    tmp = [i + " is not null" for i in schemaArray]
    return " and ".join(tmp)
def auto_wrap_text(text, max_line_length):
    wrapped_text = textwrap.fill(text, max_line_length)
    return wrapped_text
def check_table(table):
    sql_instance = create_sql_instance()
    x = sql_instance.sql(f"select count(*) as cnt from {table} ")
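    # "Code: 60" is ClickHouse's UNKNOWN_TABLE error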
    if "Code: 60" in x:
        print(x)
        raise ValueError
    elif int(x["cnt"][0]) == 0:
        print("There's no data in the table")
        raise ValueError
    else:
        return 1
# ClickHouse column types treated as numeric
NUMERIC_TYPES = [
    "UInt8",
    "UInt16",
    "UInt32",
    "UInt64",
    "UInt128",
    "UInt256",
    "Int8",
    "Int16",
    "Int32",
    "Int64",
    "Int128",
    "Int256",
    "Float32",
    "Float64",
]
def check_numeric_type(table, col):
    sql_instance = create_sql_instance()
    x = sql_instance.sql(f"desc {table}")
    cols_type = dict(zip(x["name"], x["type"]))
    if cols_type[col] not in NUMERIC_TYPES:
        print(f"The type of {col} is not numeric")
        return 1
def check_columns(table, cols, cols_nume):
    sql_instance = create_sql_instance()
    x = sql_instance.sql(f"desc {table} ")
    cols_type = dict(zip(x["name"], x["type"]))
    col_list = list(cols_type.keys())
    # check that every requested column exists
    other_variables = set(cols) - set(col_list)
    if len(other_variables) != 0:
        print(f"variable {other_variables} can't be found in the table {table}")
        raise ValueError
    # check that the columns required to be numeric really are
    for col in cols_nume:
        if cols_type[col] not in NUMERIC_TYPES:
            print(f"The type of {col} is not numeric")
            raise ValueError
class CausalTree:
    """
    This class implements a Causal Tree for uplift/HTE analysis.
    Parameters
    ----------
    depth : int
        The maximum depth of the tree.
    threshold : float
        The minimum sample ratio for a leaf.
    bin_num : int
        The number of bins for the need_cut_X to cut.
    Example
    -------
    .. code-block:: python
        import fast_causal_inference
        from fast_causal_inference.dataframe.uplift import *
        Y='y'
        T='treatment'
        table = 'test_data_small'
        X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        needcut_X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        df = fast_causal_inference.readClickHouse(table)
        df_train, df_test = df.split(0.5)
        hte = CausalTree(depth = 3,min_sample_ratio_leaf=0.001)
        hte.fit(Y,T,X,needcut_X,df_train)
        treeplot = hte.treeplot() # causal tree plot
        treeplot.render('digraph.gv', view=False) # the full tree image can be viewed in digraph.gv.pdf and downloaded
        print(hte.feature_importance)
        # Output:
        #                    featName    importance
        #     1            x2_buckets  1.015128e+06
        #     0            x1_buckets  2.181346e+05
        #     3            x4_buckets  1.023273e+05
        #     5  x_long_tail1_buckets  5.677131e+04
        #     2            x3_buckets  2.537835e+04
        #     6  x_long_tail2_buckets  2.536951e+04
        #     4            x5_buckets  7.259992e+03
        df_train_pred = hte.effect(df=df_train,keep_col='*')
        df_test_pred = hte.effect(df=df_test,keep_col='*')
        lift_train = get_lift_gain("effect", Y, T, df_train_pred,discrete_treatment=True, K=100)
        lift_test = get_lift_gain("effect", Y, T, df_test_pred,discrete_treatment=True, K=100)
        print(lift_train,lift_test)
        hte_plot([lift_train,lift_test],labels=['train','test'])
        # auuc: 0.6624369283393814
        # auuc: 0.6532554148698826
        #        ratio      lift      gain  ate  ramdom_gain
        # 0   0.009990  2.164241  0.021621  1.0     0.009990
        # 1   0.019980  2.131245  0.042582  1.0     0.019980
        # 2   0.029970  2.056440  0.061632  1.0     0.029970
        # 3   0.039960  2.177768  0.087024  1.0     0.039960
        # 4   0.049950  2.175329  0.108658  1.0     0.049950
        # ..       ...       ...       ...  ...          ...
        # 95  0.959241  1.015223  0.973843  1.0     0.959241
        # 96  0.969431  1.010023  0.979147  1.0     0.969431
        # 97  0.979620  1.006843  0.986324  1.0     0.979620
        # 98  0.989810  1.003508  0.993283  1.0     0.989810
        # 99  1.000000  1.000000  1.000000  1.0     1.000000
        # [100 rows x 5 columns]        ratio      lift      gain  ate  ramdom_gain
        # 0   0.009810  1.948220  0.019112  1.0     0.009810
        # 1   0.019620  2.221654  0.043588  1.0     0.019620
        # 2   0.029429  2.419752  0.071212  1.0     0.029429
        # 3   0.039239  2.288460  0.089797  1.0     0.039239
        # 4   0.049049  2.343432  0.114943  1.0     0.049049
        # ..       ...       ...       ...  ...          ...
        # 95  0.959960  1.014897  0.974260  1.0     0.959960
        # 96  0.969970  1.011624  0.981245  1.0     0.969970
        # 97  0.979980  1.009358  0.989150  1.0     0.979980
        # 98  0.989990  1.006340  0.996267  1.0     0.989990
        # 99  1.000000  1.000000  1.000000  1.0     1.000000
        # [100 rows x 5 columns]
    """
    def __init__(self, depth=3, min_sample_ratio_leaf=0.001, bin_num=10):
        self.depth = depth
        self.threshold = min_sample_ratio_leaf
        self.bin_num = bin_num
        self.Y = ""
        self.T = ""
        self.X = ""
        self.needcut_X = ""
        self.table = ""
        self.tree_structure = []
        self.result_df = []
        self.__sql_instance = SqlGateWayConn.create_default_conn()
        self.__cutbinstring = ""
        self.feature_importance = pd.DataFrame([])
    def __params_input_check(self):
        sql_instance = self.__sql_instance
        if self.Y == "":
            print("missing Y. You should check out the input.")
            raise ValueError
        if self.T == "":
            print("missing T. You should check out the input.")
            raise ValueError
        if self.X == "":
            print("missing X. You should check out the input.")
            raise ValueError
        if self.table == "":
            print("missing table. You should check out the input.")
            raise ValueError
        else:
            sql_instance.sql(f"select {self.Y},{self.T},{self.X} from {self.table}")
    def __table_variables_check(self):
        sql_instance = self.__sql_instance
        table = self.variables["table"]
        Y = self.variables["Y"]
        T = self.variables["T"]
        x_names = self.variables["x_names"]
        cut_x_names = self.variables["cut_x_names"]
        variables = Y + T + x_names + cut_x_names
        check_table(table=table)
        check_columns(table=table, cols=variables, cols_nume=Y + T + cut_x_names)
        res = sql_instance.sql(
            f"select {T[0]} as T from  {table} group by {T[0]} limit 10"
        )["T"].tolist()
        T_value = set([float(i) for i in res])
        if T_value != {0, 1}:
            print("The value of T must be either 0 or 1 ")
            raise ValueError
    def fit(self, Y, T, X, needcut_X, df):
        sql_instance = self.__sql_instance
        table = df_2_table(df)
        table_new = f"{table}_{int(time.time())}_new"
        self.Y = Y
        self.T = T
        self.table = table
        self.X = X
        self.needcut_X = needcut_X
        depth = self.depth
        bin_num = self.bin_num
        self.__params_input_check()
        
        x_names = list(set(X))
        cut_x_names = list(set(needcut_X))
        no_cut_x_names = list(set(x_names) - set(cut_x_names))
        cut_x_names_new = [i + "_buckets" for i in cut_x_names]
        x_names = no_cut_x_names + cut_x_names
        x_names_new = no_cut_x_names + cut_x_names_new
        featNames = x_names_new
        self.variables = {
            "T": [T],
            "Y": [Y],
            "x_names": x_names,
            "cut_x_names": cut_x_names,
            "table": table,
        }
        print("****STEP1.  Table check.")
        self.__table_variables_check()
        # get bins for cut_x_names
        print("****STEP2.  Bucket the continuous variables(cut_x_names).")
        bins_dict = {}
        quantiles = ",".join(
            [str(i) for i in list(np.linspace(0, 1, bin_num + 1)[1:-1])]
        )
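        # interior cut points 1/bin_num, ..., (bin_num-1)/bin_num; the outer
        # 0 and 1 quantiles are dropped so the buckets cover the full range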
        if len(cut_x_names_new) != 0:
            string = ",".join(
                [f"quantiles({quantiles})({i}) as {i}" for i in cut_x_names]
            )
            result = sql_instance.sql(f"""select {string} from  {table}""")
            for i in range(len(cut_x_names)):
                col = cut_x_names[i]
                x = result[col][0]
                bins = x.replace("[", "").replace("]", "").split(",")
                if len(bins) == 0:
                    bins = [0]
                bins = list(np.sort(list(set([float(x) for x in bins]))))
                bins_dict[col] = bins
            strings = []
            for i in bins_dict:
                string = f"CutBins({i},{bins_dict[i]},False) as {i}_buckets"
                strings.append(string)
            cutbinstring = ",".join(strings) + ","
        else:
            cutbinstring = ""
        self.bins_dict = bins_dict
        for i in bins_dict:
            bins_dict[i] = [-float("inf")] + bins_dict[i] + [float("inf")]
        if len(no_cut_x_names) != 0:
            no_cut_x_names_string = (
                ",".join([f"{i} as {i}" for i in no_cut_x_names]) + ","
            )
        else:
            no_cut_x_names_string = ""
        self.__cutbinstring = cutbinstring
        self.__no_cut_x_names_string = no_cut_x_names_string
        print("****STEP3.  Create new table for training causaltree: ", table_new, ".")
        ClickHouseUtils.clickhouse_create_view(
            clickhouse_view_name=table_new,
            sql_statement=f"""
                 {Y} as Y, {T} as treatment,{cutbinstring}{no_cut_x_names_string} if(rand()/pow(2,32)<0.5,0,1) as if_test
          """,
            sql_table_name=table,
            primary_column="if_test",
            is_force_materialize=True,
        )
        # check if empty data for new table
        allcnt = int(
            sql_instance.sql(
                f"select count(*) as cnt from {table_new} where {FilterSchema(x_names_new)}"
            )["cnt"][0]
        )
        if allcnt == 0:
            print("Sample size is 0, check for null values")
            raise ValueError
        res = sql_instance.sql(
            f"select treatment from  {table_new} group by treatment limit 10"
        )["treatment"].tolist()
        treatments = set([float(i) for i in res])
        if treatments != {0, 1}:
            print("The value of T can only be 0 or 1")
            raise ValueError
        # compute ate before build tree
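        # treatment is encoded as z = +1 and control as z = -1, so summing
        # z * sum(Y) / cnt over the two groups yields mean(Y|T=1) - mean(Y|T=0)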
        train = sql_instance.sql(
            f"SELECT if(treatment=1,1,-1) as z, sum(Y) as sum,count(*) as cnt FROM {table_new} where if_test=0 group by if(treatment=1,1,-1)"
        )
        train = pd.DataFrame(train, columns=["z", "sum", "cnt"])
        train = train[["z", "sum", "cnt"]].astype(float)
        test = sql_instance.sql(
            f"SELECT if(treatment=1,1,-1) as z, sum(Y) as sum,count(*) as cnt FROM {table_new} where if_test=1 group by if(treatment=1,1,-1)"
        )
        test = pd.DataFrame(test, columns=["z", "sum", "cnt"])
        test = test[["z", "sum", "cnt"]].astype(float)
        dat_cnt = train["cnt"].sum()
        estData_cnt = test["cnt"].sum()
        data_all = pd.merge(train, test, on="z")
        ate = (
            data_all["z"]
            * (data_all["sum_x"] + data_all["sum_y"])
            / (data_all["cnt_x"] + data_all["cnt_y"])
        ).sum()
        estData_ate = ((test["z"] * (test["sum"])) / (test["cnt"])).sum()
        print(
            f"\t train data samples: {int(dat_cnt)},predict data samples: {int(estData_cnt)}"
        )
        # build tree
        print("****STEP4.  Build tree.")
        t1 = time.time()
        modelTree = CausalTreeclass(
            dfName=table_new,
            threshold=self.threshold,
            maxDepth=depth,
            whereCond="",
            nodePosition="L",
            impurity=0,
            father_split_feature="",
            father_split_feature_Categories=[],
            whereCond_new=[],
            nodesSet=[],
        )
        modelTree.get_global_values(dat_cnt, depth, featNames)
        print(
            "================================== start buildTree -- maxDepth: {maxDepth}, nodePosition: {nodePosition}==================================".format(
                maxDepth=depth, nodePosition="root"
            )
        )
        modelTree.buildTree()
        print(
            "============================================== build Tree Successfully====================================================="
        )
        result_list = []
        print("****STEP5.  Estimate CATE.")
        nodesSet = modelTree.nodesSet
        splitpoint_pdf_all = pd.DataFrame(
            [], columns=["featName", "featValue", "splitpoint", "ImpurityGain"]
        )
        for node in nodesSet:
            node.predictNode(ate=estData_ate, cnt=estData_cnt)  # TODO
            result = node.inferenceSet
            splitpoint_pdf_all = pd.concat(
                [splitpoint_pdf_all, node.splitpoint_pdf], axis=0
            )
            result_list.append(result)
        self.feature_importance = splitpoint_pdf_all.groupby(
            by=["featName"], as_index=False
        )["ImpurityGain"].sum()
        self.feature_importance = self.feature_importance.sort_values(
            by=["ImpurityGain"], ascending=False
        )
        self.feature_importance.columns = ["featName", "importance"]
        columns = [
            "TreeID",
            "level",
            "isLeaf",
            "whereCond",
            "whereCond_new",
            "ratio",
            "prediction",
            "treatedCount",
            "controlCount",
            "treatedAvg",
            "controlAvg",
            "treatedStd",
            "controlStd",
            "estPoint1",
            "std1",
            "pvalue1",
            "zValue1",
            "lowerCI1",
            "upperCI1",
            "estPoint2",
            "std2",
            "pvalue2",
            "zValue2",
            "lowerCI2",
            "upperCI2",
            "estPoint3",
            "std3",
            "pvalue3",
            "zValue3",
            "lowerCI3",
            "upperCI3",
            "estPoint4",
            "std4",
            "pvalue4",
            "zValue4",
            "lowerCI4",
            "upperCI4",
        ]
        result_df = pd.DataFrame(result_list, columns=columns)
        print(
            "============================================== compute Tree Successfully====================================================="
        )
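        # Benjamini-Hochberg FDR adjustment of the per-node p-values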
        result_df["qvalue1"] = fdrcorrection(result_df["pvalue1"])[1]
        result_df["qvalue2"] = fdrcorrection(result_df["pvalue2"])[1]
        result_df["qvalue3"] = fdrcorrection(result_df["pvalue3"])[1]
        result_df["qvalue4"] = fdrcorrection(result_df["pvalue4"])[1]
        self.result_df = result_df
        # for continuous variables, map bucket indices back to interval labels,
        # for example:  {"x_continuous_1_buckets":[4]} >>> {"x_continuous_1_buckets":'[89.62761587688087,95.04490061748047)'}
        if result_df.shape[0] > 1:
            Category_value_dicts = {}
            for i in x_names_new:
                try:
                    values_ = list(
                        modelTree.allsplitpoint_pdf[
                            modelTree.allsplitpoint_pdf["featName"] == i
                        ]["featValue"]
                    )
                    values_ = list(np.array(values_))
                    values_.sort()
                except Exception:
                    print(f"failed to collect split values for feature {i}")
                    values_ = []
                if i in cut_x_names_new:
                    cut_bins_dict = {}
                    bins = bins_dict[i.split("_buckets")[0]]
                    intervals = []
                    for j in range(len(bins) - 1):
                        intervals.append(
                            "["
                            + str(round(bins[j], 4))
                            + ","
                            + str(round(bins[j + 1], 4))
                            + ")"
                        )
                    for value in values_:
                        cut_bins_dict[value] = intervals[int(value) - 1]  # todo
                    Category_value_dicts[i] = cut_bins_dict
                else:
                    Category_value_dicts[i] = dict(zip(values_, values_))
            self.Category_value_dicts = Category_value_dicts
            whereCond_new_list = []
            for i in range(1, len(nodesSet)):
                whereCond_new_ = {}
                for whereCond_new in result_df.loc[i, "whereCond_new"]:
                    key = list(whereCond_new.keys())[0]
                    value = list(whereCond_new.values())[0]
                    value.sort()
                    Category_value_dict = Category_value_dicts[key]
                    whereCond_new_[key] = [Category_value_dict[j] for j in value]
                whereCond_new_list.append(whereCond_new_)
            result_df["whereCond_new"] = [""] + whereCond_new_list
        nodesSet[0].inferenceSet = dict(result_df.loc[0, :])
        for i in range(1, len(nodesSet)):
            nodesSet[i].inferenceSet = dict(
                result_df.loc[i, :]
            )  # inferenceSet: list >> dict
            nodesSet[i].whereCond_new = result_df.loc[i, "whereCond_new"]
            Category_value_dict = Category_value_dicts[nodesSet[i].father_split_feature]
            father_split_feature_Categories = nodesSet[
                i
            ].father_split_feature_Categories.copy()
            nodesSet[i].father_split_feature_Categories = [
                Category_value_dict[j] for j in father_split_feature_Categories
            ]
        # tree_structure
        ctRegressionTree = CTRegressionTree(
            tree=modelTree, total_count=test["cnt"].sum()
        )  # estData total cnt
        tree_structure = ctRegressionTree.get_decision_rules("root")
        self.tree_structure = tree_structure
        self.result_df = result_df
        self.estimate = list(self.result_df["prediction"])
        self.estimate_stderr = list(self.result_df["std1"])
        self.pvalue = list(self.result_df["pvalue1"])
        self.estimate_interval = np.array((self.result_df[["lowerCI1", "upperCI1"]]))
        ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=table_new)
        self.modelTree = modelTree
    def __add_nodes_edges(self, tree, dot=None):
        if dot is None:
            # lazily create the graph on the first (root) call
            dot = Digraph()
            dot.node_attr.update(width="1", fontsize="9")
            node_id = int(tree["node_id"])
            node_id = f"Node : {node_id}"
            ate = f'CATE : {tree["tau_i"]}  (p_value:{tree["pvalues"][0]})'
            sample = f'sample size: control {tree["controlCount"]};  treatment {tree["treatedCount"]}'
            sample_ratio = f'sample_ratio:{tree["count_ratio"]}%'
            mean = (
                f'mean: control {tree["controlAvg"]};  treatment {tree["treatedAvg"]}'
            )
            dot.node(
                str(tree["node_id"]),
                "\n".join([node_id, sample_ratio, ate, sample, mean]),
            )
        for child in tree["children"]:
            if child:
                node_id = int(child["node_id"])
                node_id = f"Node : {node_id}"
                ate = f'CATE : {child["tau_i"]}  (p_value:{child["pvalues"][0]})'
                sample = f'sample size: control {child["controlCount"]};  treatment {child["treatedCount"]}'
                sample_ratio = f'sample_ratio:{child["count_ratio"]}%'
                mean = f'mean: control {child["controlAvg"]};  treatment {child["treatedAvg"]}'
                split_criterion = f'split_criterion:   {child["father_split_feature"]} in {child["father_split_feature_Categories"]}'
                dot.node(
                    str(child["node_id"]),
                    "\n".join(
                        [
                            node_id,
                            auto_wrap_text(split_criterion, 60),
                            sample_ratio,
                            ate,
                            sample,
                            mean,
                        ]
                    ),
                )
                dot.edge(str(tree["node_id"]), str(child["node_id"]))
                self.__add_nodes_edges(child, dot)
        return dot
    def treeplot(self):
        tree_structure = self.tree_structure.get_dict()
        dot = self.__add_nodes_edges(tree_structure)
        return dot
    def hte_plot(self):
        # toc curve
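        # Leaves are sorted by predicted effect (descending); cumulating treated/control
        # sums and counts then yields the cumulative lift (TOC) and Qini curves.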
        result_df = self.result_df
        toc_data_df = result_df[result_df["isLeaf"] == True][
            ["prediction", "treatedCount", "controlCount", "treatedAvg", "controlAvg"]
        ].sort_values(by="prediction", ascending=False)
        toc_data_df.reset_index(drop=True, inplace=True)
        toc_data_df["cnt"] = toc_data_df["treatedCount"] + toc_data_df["controlCount"]
        toc_data_df["treatedsum"] = (
            toc_data_df["treatedCount"] * toc_data_df["treatedAvg"]
        )
        toc_data_df["controlsum"] = (
            toc_data_df["controlCount"] * toc_data_df["controlAvg"]
        )
        cnt = toc_data_df["cnt"].sum()
        toc_data_df["treatedcumsum"] = toc_data_df["treatedsum"].cumsum()
        toc_data_df["controlcumsum"] = toc_data_df["controlsum"].cumsum()
        toc_data_df["treatedcumcnt"] = toc_data_df["treatedCount"].cumsum()
        toc_data_df["controlcumcnt"] = toc_data_df["controlCount"].cumsum()
        toc_data_df["ratio"] = toc_data_df["cnt"] / cnt
        toc_data_df["ratio_sum"] = toc_data_df["ratio"].cumsum()
        toc_data_df["toc"] = (
            toc_data_df["treatedcumsum"] / toc_data_df["treatedcumcnt"]
            - toc_data_df["controlcumsum"] / toc_data_df["controlcumcnt"]
        )
        toc_data_df["qini"] = (
            toc_data_df["treatedcumsum"]
            - toc_data_df["controlcumsum"]
            / toc_data_df["controlcumcnt"]
            * toc_data_df["treatedcumcnt"]
        )
        toc_data_df["order"] = toc_data_df.index + 1
        toc_data_df["x"] = toc_data_df["ratio_sum"]
        ate = list(toc_data_df["toc"])[-1]
        toc_data_df["toc1"] = ate
        toc_data_df["qini1"] = toc_data_df["x"] * list(toc_data_df["qini"])[-1]
        toc_data = toc_data_df[["x", "toc", "toc1", "qini", "qini1"]]
        toc_data.columns = [
            "ratio",
            "toc_tree",
            "toc_random",
            "qini_tree",
            "qini_random",
        ]
        toc_data = toc_data.sort_values(by=["ratio"])
        result = toc_data
        fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(12, 4.8))
        ax1.plot(result["ratio"], result["toc_tree"], label="CausalTree")
        ax2.plot(
            [0] + list(result["ratio"]),
            [0] + list(result["qini_tree"]),
            label="CausalTree",
        )
        ax1.plot(result["ratio"], result["toc_random"], label="Random Model")
        ax2.plot(
            [0] + list(result["ratio"]),
            [0] + list(result["qini_random"]),
            label="Random Model",
        )
        ax1.set_title("Cumulative Lift Curve")
        ax1.legend()
        ax2.set_title("Cumulative Gain Curve")
        ax2.legend()
        fig.suptitle("CausalTree Lift and Gain Curves")
        plt.show()
[docs]    def effect(self, df, keep_col="*"):
        """
        Calculate the individual treatment effect.
        Parameters
        ----------
        df : DataFrame
            The input data.
        keep_col : str, optional
            The columns to keep. Defaults to '*'.
        Returns
        -------
        DataFrame
            The result data.
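        Example
        -------
        A minimal sketch, assuming a fitted model named `hte` and an existing `df_test`:
        .. code-block:: python
            # adds an `effect` column holding the predicted treatment effect of each row's leaf
            df_test_pred = hte.effect(df=df_test, keep_col='*')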
        """
        table_input = df_2_table(df)
        cutbinstring = self.__cutbinstring
        table_tmp = f"{table_input}_{int(time.time())}_foreffect_2_clickhouse"
        table_tmp1 = table_tmp + "1"
        ClickHouseUtils.clickhouse_create_view(
            clickhouse_view_name=table_tmp,
            sql_statement=f"""
                  *,{cutbinstring}1 as index
            """,
            sql_table_name=table_input,
            primary_column="index",
            is_force_materialize=True,
        )
        leaf_effect = np.array(
            self.result_df[self.result_df["isLeaf"] == True][
                ["whereCond", "prediction"]
            ]
        )
        string = " ".join([f"when True {x[0]} then {x[1]} " for x in leaf_effect])
        ClickHouseUtils.clickhouse_create_view(
            clickhouse_view_name=table_tmp1,
            sql_statement=f"""
                 {keep_col},
                case 
                    {string}
                else 4294967295
                end as effect
                
            """,
            sql_table_name=table_tmp,
            primary_column="effect",
            is_force_materialize=True,
        )
        ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=table_tmp)
        return readClickHouse(table_tmp1)  
def save_model(model, file):
    # Serialize the model instance and save it to a local pickle file,
    # e.g. save_model(model, 'file_name.pkl').
    with open(file, "wb") as f:
        pickle.dump(model, f)
        print(f"model {model} is stored at '{file}'")
def load_model(file):
    # Load and deserialize a model instance from a local pickle file.
    with open(file, "rb") as f:
        model = pickle.load(f)
        return model
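# Usage sketch for the pickle helpers above (assumes `hte` is a fitted model):
#     save_model(hte, 'causal_tree.pkl')
#     hte_loaded = load_model('causal_tree.pkl')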
def save_model2ch(hte):
    cutbin_string = hte._CausalTree__cutbinstring[:-1]
    leaf_effect = np.array(
        hte.result_df[hte.result_df["isLeaf"] == True][["whereCond", "prediction"]]
    )
    effect_string = " ".join([f"when True {x[0]} then {x[1]} " for x in leaf_effect])
    model_table = f"tmp_{int(time.time())}"
    ClickHouseUtils.clickhouse_create_view(
        clickhouse_view_name=model_table,
        sql_statement=f"""
          '{cutbin_string}' as cutbin_string, '{effect_string}' as effect_string
    """,
        sql_table_name=hte.table,
        sql_limit=1,
        is_force_materialize=True,
    )
    return readClickHouse(model_table)
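# Usage sketch: persist a fitted CausalTree inside ClickHouse, then score new data
# with load_chmodel_predict (defined below); `hte` and `df_new` are assumed to exist:
#     df_model = save_model2ch(hte)
#     df_scored = load_chmodel_predict(df_model, df_new, keep_col='*')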
[docs]class CausalForest:
    """
    This class implements the Causal Forest method for causal inference.
    Parameters
    ----------
    depth : int, default=7
        The maximum depth of the tree.
    min_node_size : int, default=-1
        The minimum node size.
    mtry : int, default=3
        The number of variables randomly sampled as candidates at each split.
    num_trees : int, default=10
        The number of trees to grow in the forest.
    sample_fraction : float, default=0.7
        The fraction of observations to consider when fitting the forest.
    weight_index : str, default=''
        The weight index.
    honesty : bool, default=False
        Whether to use honesty when fitting the forest.
    honesty_fraction : float, default=0.5
        The fraction of observations to use for determining splits if honesty is used.
    quantile_num : int, default=50
        The number of quantiles.
    Methods
    -------
    fit(Y, T, X, df):
        Fit the Causal Forest model to the input data.
    effect(input_df=None, X=[]):
        Estimate the causal effect using the fitted model.
    Example
    -------
    .. code-block:: python
        import fast_causal_inference
        from fast_causal_inference.dataframe.uplift import *
        Y='y'
        T='treatment'
        table = 'test_data_small'
        X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        df = fast_causal_inference.readClickHouse(table)
        df_train, df_test = df.split(0.5)
        from fast_causal_inference.dataframe.uplift import CausalForest
        model = CausalForest(depth=7, min_node_size=-1, mtry=3, num_trees=10, sample_fraction=0.7)
        model.fit(Y, T, X, df_train)
    """
    def __init__(
        self,
        depth=7,
        min_node_size=-1,
        mtry=3,
        num_trees=10,
        sample_fraction=0.7,
        weight_index="",
        honesty=False,
        honesty_fraction=0.5,
        quantile_num=50,
    ):
        self.depth = depth
        self.min_node_size = min_node_size
        self.mtry = mtry
        self.num_trees = num_trees
        self.sample_fraction = sample_fraction
        self.weight_index = weight_index
        self.honesty = 1 if honesty else 0
        self.honesty_fraction = honesty_fraction
        self.quantile_num = max(1, min(100, quantile_num))  # clamp to [1, 100]
[docs]    def fit(self, Y, T, X, df):
        """
        Fit the Causal Forest model to the input data.
        Parameters
        ----------
        Y : str
            The outcome variable.
        T : str
            The treatment variable.
        X : list
            The numeric covariates. Strings are not supported. ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        df : DataFrame
            The input dataframe.
        Returns
        -------
        None
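        Example
        -------
        A minimal sketch, reusing the names from the class-level example:
        .. code-block:: python
            model = CausalForest(depth=7, num_trees=10)
            model.fit('y', 'treatment', ['x1', 'x2', 'x3', 'x4', 'x5'], df_train)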
        """
        self.table = df.getTableName()
        self.origin_table = df.getTableName()
        self.Y = Y
        self.T = T
        self.X = X
        self.len_x = len(X)
        self.ts = int(time.time() * 1000)  # millisecond timestamp used to version the model table
        # create model df.getTableName()
        self.model_table = "model_" + self.table + str(self.ts)
        sql_instance = SqlGateWayConn.create_default_conn()
        count = sql_instance.sql("select count() as cnt from " + self.table)
        if isinstance(count, str):
            print(count)
            return
        self.table_count = count["cnt"][0]
        self.mtry = min(30, self.mtry)
        self.num_trees = min(200, self.num_trees)
        calc_min_node_size = int(max(int(self.table_count) / 128, 1))
        self.min_node_size = max(self.min_node_size, calc_min_node_size)
        # insert into {self.model_table}
        self.config = f""" select '{{"max_centroids":1024,"max_unmerged":2048,"honesty":{self.honesty},"honesty_fraction":{self.honesty_fraction}, "quantile_size":{self.quantile_num}, "weight_index":2, "outcome_index":0, "treatment_index":1, "min_node_size":{self.min_node_size}, "sample_fraction":{self.sample_fraction}, "mtry":{self.mtry}, "num_trees":{self.num_trees}}}' as model, {self.ts} as ver"""
        ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=self.model_table)
        ClickHouseUtils.clickhouse_create_view(
            clickhouse_view_name=self.model_table,
            sql_statement=self.config,
            is_sql_complete=True,
            sql_table_name=self.table,
            primary_column="ver",
            is_force_materialize=False,
        )
        sql_instance.sql(self.config)
        if self.weight_index == "":
            self.weight_index = "1 / " + str(self.table_count)
        self.xs = ",".join(X)
        self.init_sql = f"""
        insert into {self.model_table} (model, ver)  
        WITH
        (select max(ver) from {self.model_table}) as ver0,
            (
                SELECT model
                FROM {self.model_table}
                WHERE ver = ver0 limit 1
            ) AS model
        SELECT CausalForest(model)({Y}, {T}, {self.weight_index}, {self.xs}), ver0 + 1
        FROM {self.table}
        """
        res = sql_instance.sql(self.init_sql)
        self.train_sql = f"""
        
        insert into {self.model_table} (model, ver)  
        WITH
        (select max(ver) from {self.model_table}) as ver0,
            (
                SELECT model
                FROM {self.model_table}
                WHERE ver = ver0 limit 1
            ) AS model,
        (
        SELECT CausalForest(model)({Y}, {T}, {self.weight_index}, {self.xs})
        FROM {self.table}
        ) as calcnumerdenom,
        (
        SELECT CausalForest(calcnumerdenom)({Y}, {T}, {self.weight_index}, {self.xs})
        FROM {self.table}
        ) as split_pre
        SELECT CausalForest(split_pre)({Y}, {T}, {self.weight_index}, {self.xs}), ver0 + 1
        FROM {self.table}
        """
        for i in range(self.depth):
            res = sql_instance.execute(self.train_sql)
            print("depth " + str(i + 1) + " train over")
            if isinstance(res, str) and res.find("train over") != -1:
                print("---------- training finished ----------")
                break
[docs]    def effect(self, df=None, X=[]):
        """
        Estimate the causal effect using the fitted model.
        Parameters
        ----------
        df : DataFrame, default=None
            The input dataframe for which to estimate the causal effect. If None, use the dataframe from the fit method.
        X : list, default=[]
            The covariates to use when estimating the causal effect. ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
        Returns
        -------
        DataFrame
            The output dataframe with the estimated causal effect.
        Example
        -------
        .. code-block:: python
            df_test_effect_cf = model.effect(df=df_test, X=['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2'])
            df_train_effect_cf = model.effect(df=df_train, X=['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2'])
            lift_train = get_lift_gain("effect", Y, T, df_train_effect_cf,discrete_treatment=True, K=100)
            lift_test = get_lift_gain("effect", Y, T, df_test_effect_cf,discrete_treatment=True, K=100)
            print(lift_train,lift_test)
            hte_plot([lift_train,lift_test],labels=['train','test'])
        """
        if X:
            if len(X) != self.len_x:
                print("The number of x is not equal to the number of x in the model")
                return
            self.xs = ",".join(X)
        if df is not None:
            self.effect_table = df.getTableName()
        else:
            self.effect_table = self.table
        self.output_table = DataFrame.createTableName()
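        # Scoring: pick the latest model version, wrap it into a predict state via
        # CausalForestPredict / CausalForestPredictState, then evaluate it on every
        # row with evalMLMethod to produce the `effect` column.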
        self.predict_sql = f"""
            WITH
                             (
                                 SELECT max(ver)
                                 FROM {self.model_table}
                             ) AS ver0,
                             (
                                 SELECT model
                                 FROM {self.model_table}
                                 WHERE ver = ver0 limit 1
                             ) AS pure,
                         (SELECT CausalForestPredict(pure)({self.Y}, 0, {self.weight_index}, {self.xs}) FROM {self.origin_table}) as model,
                         (SELECT CausalForestPredictState(model)(number) FROM numbers(0)) as predict_model
                         select *, evalMLMethod(predict_model, 0, {self.weight_index}, {self.xs}) as effect
            FROM {self.effect_table} 
        """
        ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=self.output_table)
        ClickHouseUtils.clickhouse_drop_view(
            clickhouse_view_name=self.output_table + "_local"
        )
        ClickHouseUtils.clickhouse_create_view(
            clickhouse_view_name=self.output_table,
            sql_statement=self.predict_sql,
            is_sql_complete=True,
            sql_table_name=self.model_table,
            primary_column="effect",
            is_use_local=False,
        )
        print("succ")
        return readClickHouse(self.output_table)  
class LinearDML:
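    """
    Linear Double Machine Learning executed in ClickHouse via the linearDML SQL
    function: nuisance models for the outcome (model_y) and the treatment
    (model_t) are cross-fitted with `cv` folds, and a linear final stage
    estimates the treatment effect.
    Example
    -------
    A minimal sketch, assuming the column names used in the examples above:
    .. code-block:: python
        model = LinearDML(model_y='ols', model_t='ols', cv=3)
        model.fit(df_train, Y='y', T='treatment', X='x1+x2+x3')
        print(model.summary())
        df_effect = model.effect(df_test)
    """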
    def __init__(
        self,
        model_y="ols",
        model_t="ols",
        fit_cate_intercept=True,
        discrete_treatment=True,
        categories=[0, 1],
        cv=3,
        treatment_featurizer="",
    ):
        self.model_y = model_y
        self.model_t = model_t
        self.treatment_featurizer = ""
        self.effect_ph = ""
        self.marginal_effect_ph = ""
        self.cv = cv
        if treatment_featurizer != "":
            self.treatment_featurizer = treatment_featurizer[0]
            self.effect_ph = treatment_featurizer[1]
            self.marginal_effect_ph = treatment_featurizer[2]
        self.sql_instance = SqlGateWayConn.create_default_conn()
    def fit(self, df, Y, T, X, W=""):
        self.table = df.getTableName()
        self.Y = Y
        self.T = T
        self.X = X
        self.dml_sql = self.get_dml_sql(
            self.table,
            Y,
            T,
            X,
            W,
            self.model_y,
            self.model_t,
            self.cv,
            self.treatment_featurizer,
        )
        self.forward_sql = self.sql_instance.sql(
            sql=self.dml_sql, is_calcite_parse=True
        )
        self.result = self.sql_instance.sql(self.dml_sql)
        if isinstance(self.result, str) and self.result.find("error") != -1:
            self.success = False
            self.ols = self.result
            return
        self.ols = Ols(self.result["final_model"][0])
        self.success = isinstance(self.ols, Ols)
    def get_dml_sql(
        self, table, Y, T, X, W, model_y, model_t, cv, treatment_featurizer
    ):
        sql = "select linearDML("
        sql += Y + "," + T + "," + X + ","
        if W.strip() != "":
            sql += W + ","
        sql += (
            "model_y='"
            + model_y
            + "',"
            + "model_t='"
            + model_t
            + "',"
            + "cv="
            + str(cv)
        )
        sql += " ) from " + table
        sql = sql.replace("ols", "Ols")
        return sql
    def __str__(self):
        return str(self.ols)
    def summary(self):
        if not self.success:
            return str(self.ols)
        return self.ols.get_dml_summary()
    def exchange_dml_sql(self, sql, use_interval=False):
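        # Truncate the generated SQL at the final_model alias and swap the last
        # `Ols` aggregate for `OlsState` (or `OlsIntervalState` when confidence
        # intervals are requested) so the fitted final stage can later be fed to
        # evalMLMethod for prediction.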
        pos = sql.find("final_model")
        if pos == -1:
            raise Exception("Logical Error: final_model not found in sql")
        sql = sql[0 : pos + len("final_model")]
        pos = sql.rfind("Ols")
        if pos == -1:
            raise Exception("Logical Error: Ols not found in sql")
        if use_interval:
            sql = sql[0:pos] + "OlsIntervalState" + sql[pos + len("Ols") :]
        else:
            sql = sql[0:pos] + "OlsState" + sql[pos + len("Ols") :]
        return sql
    def effect(self, df, X="", T0=0, T1=1):
        table_predict = df.getTableName()
        table_output = DataFrame.createTableName()
        if table_predict == "":
            table_predict = self.table
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        sql = self.exchange_dml_sql(self.forward_sql) + "\n"
        X = X.replace("+", ",")
        sql += "select "
        if table_output == "":
            sql += self.Y + " as Y," + self.T + " as T,"
        sql += (
            X
            + ", evalMLMethod(final_model, "
            + X
            + ", "
            + str(T1 - T0)
            + ") as predict from "
            + table_predict
        )
        if table_output != "":
            if X.count(",") >= 30:
                print("The number of x exceeds the limit 30")
            ClickHouseUtils.clickhouse_create_view(
                clickhouse_view_name=table_output,
                sql_statement=sql,
                primary_column="predict",
                is_force_materialize=True,
                is_sql_complete=True,
                sql_table_name=self.table,
                is_use_local=True,
            )
        return readClickHouse(table_output)
    def ate(self, X="", T0=0, T1=1):
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        sql = self.exchange_dml_sql(self.forward_sql) + "\n"
        X = X.replace("+", ",")
        sql += (
            "select avg(evalMLMethod(final_model, "
            + X
            + ", "
            + str(T1 - T0)
            + ")) from "
            + self.table
        )
        sql += " limit 100"
        t = self.sql_instance.sql(sql)
        return t
    def effect_interval(self, df, X="", T0=0, T1=1, alpha=0.05):
        table_output = DataFrame.createTableName()
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        sql = self.exchange_dml_sql(self.forward_sql, use_interval=True) + "\n"
        X = X.replace("+", ",")
        X1 = X.split(",")
        X1 = [x + " as " + x for x in X1]
        sql += (
            "select "
            + ", ".join(X1)
            + ", evalMLMethod(final_model,'confidence',"
            + str(1 - alpha)
            + ", "
            + X
            + ", "
            + str(T1 - T0)
            + ") as predict from "
            + df.getTableName()
        )
        ClickHouseUtils.clickhouse_create_view(
            clickhouse_view_name=table_output,
            sql_statement=sql,
            primary_column="predict",
            is_force_materialize=True,
            is_sql_complete=True,
            sql_table_name=self.table,
            is_use_local=True,
        )
        return readClickHouse(table_output)
    def ate_interval(self, X="", T0=0, T1=1, alpha=0.05):
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        sql = self.exchange_dml_sql(self.forward_sql, use_interval=True) + "\n"
        X = X.replace("+", ",")
        Xs = X.split(",")
        for i in range(len(Xs)):
            Xs[i] = "avg(" + Xs[i] + ")"
        X = ",".join(Xs)
        sql += (
            "select evalMLMethod(final_model,'confidence',"
            + str(1 - alpha)
            + ", "
            + X
            + ", "
            + str(T1 - T0)
            + ") from "
            + self.table
        )
        sql += " limit 100"
        t = self.sql_instance.sql(sql)
        if str(t).find("DB::Exception") != -1:
            return t
        s = "mean_point\tci_mean_lower\tci_mean_upper\t\n"
        t = t.iloc[0, 0]
        t = t[1:-1]
        t = t.split(",")
        for i in range(len(t)):
            s += str(round(float(t[i]), 10)) + "\t"
        return s
    def get_sql(self, X):
        sql = self.exchange_dml_sql(self.forward_sql) + "\n"
        x_with_space = X.split("+")
        x_with_space.append("1")
        model_arguments = ""
        for x in x_with_space:
            for y in self.treatment_featurizer:
                model_arguments += x + "*" + y + ","
        model_arguments = "(False)(" + self.Y + "," + model_arguments[0:-1] + ")"
        last_olsstate_index = sql.rfind("OlsState")
        last_from_index = sql.rfind("FROM")
        sql = (
            sql[: last_olsstate_index + len("OlsState")]
            + model_arguments
            + " "
            + sql[last_from_index:]
        )
        tmp_eval = "evalMLMethod(final_model"
        for x in X.split("+"):
            for y in self.marginal_effect_ph:
                tmp_eval += "," + str(x) + "*" + str(y)
        for y in self.marginal_effect_ph:
            tmp_eval += "," + str(y)
        tmp_eval += ")"
        evals = []
        for i in range(0, len(self.marginal_effect_ph) + 1):
            evals.append(tmp_eval)
        for i in range(1, len(self.marginal_effect_ph) + 1):
            for j in range(0, len(evals)):
                if i != j:
                    evals[j] = evals[j].replace("@PH" + str(i), "0")
                else:
                    evals[j] = evals[j].replace("@PH" + str(i), "1")
        sql_const = sql + " select "
        sql_effect = sql_const
        sql_ate = sql_const + " avg( "
        for i in range(1, len(self.marginal_effect_ph) + 1):
            sql_const += evals[i] + " - " + evals[0] + " as predict" + str(i) + ","
            sql_effect += evals[i] + " - " + evals[0] + " +"
            sql_ate += evals[i] + " - " + evals[0] + " +"
        sql_const = sql_const[:-1] + " from " + self.table
        sql_effect = sql_effect[:-1] + " as predict from " + self.table
        sql_ate = sql_ate[:-1] + ") as predict from " + self.table
        return [sql_const, sql_effect, sql_ate]
    def const_marginal_effect(self, X="", table_output=""):
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        if self.marginal_effect_ph == "":
            return "Error: treatment featurizer is empty!"
        sql = self.get_sql(X)[0]
        if table_output == "":
            sql += " limit 100"
        if table_output != "":
            ClickHouseUtils.clickhouse_create_view(
                clickhouse_view_name=table_output,
                sql_statement=sql,
                primary_column="predict1",
                is_force_materialize=True,
                is_sql_complete=True,
                sql_table_name=self.table,
            )
            return
        t = self.sql_instance.sql(sql)
        return t
    def marginal_effect(self, X="", table_output=""):
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        if not self.marginal_effect_ph:
            return "Error: treatment featurizer is empty!"
        sql = self.get_sql(X)[1]
        if table_output == "":
            sql += " limit 100"
        if table_output != "":
            if sql.count("+") >= 30 or sql.count(",") >= 30:
                print("The number of x exceeds the limit 40")
            ClickHouseUtils.clickhouse_create_view(
                clickhouse_view_name=table_output,
                sql_statement=sql,
                primary_column="predict",
                is_force_materialize=True,
                is_sql_complete=True,
                sql_table_name=self.table,
            )
        t = self.sql_instance.sql(sql)
        return t
    def marginal_ate(self, X="", table_output=""):
        if not self.success:
            return str(self.ols)
        if not X:
            X = self.X
        if not self.marginal_effect_ph:
            return "Error: treatment featurizer is empty!"
        sql = self.get_sql(X)[2]
        if not table_output:
            sql += " limit 100"
        if table_output:
            ClickHouseUtils.clickhouse_create_view(
                clickhouse_view_name=table_output,
                sql_statement=sql,
                primary_column="predict",
                is_force_materialize=True,
                is_sql_complete=True,
                sql_table_name=self.table,
            )
            return
        t = self.sql_instance.sql(sql)
        return t
def load_chmodel_predict(df_model, df, keep_col="*"):
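    # Score `df` with a CausalTree model exported by save_model2ch: rebuild the
    # bucket index from cutbin_string, then map each row to its leaf's effect via
    # the CASE expression stored in effect_string.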
    table_input = df_2_table(df)
    model_table = df_2_table(df_model)
    sql_instance = SqlGateWayConn.create_default_conn()
    tmp = sql_instance.sql(
        f"select cutbin_string,effect_string from {model_table} limit 1"
    )
    cutbinstring = tmp["cutbin_string"][0]
    effect_string = tmp["effect_string"][0]
    table_tmp = f"{table_input}_{int(time.time())}_foreffect_2_clickhouse"
    table_output = f"tmp_{int(time.time())}"
    ClickHouseUtils.clickhouse_create_view(
        clickhouse_view_name=table_tmp,
        sql_statement=f"""
              *,{cutbinstring},1 as index
        """,
        sql_table_name=table_input,
        is_force_materialize=True,
    )
    ClickHouseUtils.clickhouse_create_view(
        clickhouse_view_name=table_output,
        sql_statement=f"""
             {keep_col},
            case 
                {effect_string}
            else 4294967295
            end as effect
        """,
        sql_table_name=table_tmp,
        is_force_materialize=True,
    )
    ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=table_tmp)
    return table_2_df(table_output)