import math
import pickle
import textwrap
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from graphviz import Digraph
from scipy.stats import norm
from statsmodels.stats.multitest import fdrcorrection

from fast_causal_inference.dataframe.dataframe import DataFrame, readClickHouse
from fast_causal_inference.lib.ols import Ols
from fast_causal_inference.util import (
    ClickHouseUtils,
    SqlGateWayConn,
    create_sql_instance,
)

warnings.filterwarnings("ignore")


def df_2_table(df):
    new_df = df.materializedView()
    return new_df.getTableName()


def table_2_df(table):
    return readClickHouse(table)
# Define a class to store the results of Lift Gain Curve
class LiftGainCurveResult:
def __init__(self, data):
# Initialize the result as a DataFrame with specified column names
self.result = pd.DataFrame(
data, columns=["ratio", "lift", "gain", "ate", "ramdom_gain"]
)
def __str__(self):
# Return the string representation of the result
return str(self.result)
def summary(self):
# Print the summary of the result
print(self.result)
def get_result(self):
# Return the result
return self.result
# Function to calculate lift gain
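# For a population sorted by predicted ITE (descending), lift(p) is the average
# treatment effect within the top-p fraction divided by the overall ATE, and
# gain(p) = lift(p) * p.  "ramdom_gain" (sic: this spelling matches the column
# name returned by the backend lift() function) is the diagonal that a random
# ranking would produce.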
def get_lift_gain(ITE, Y, T, df, normalize=True, K=1000, discrete_treatment=True):
"""
Calculate the uplift & gain.
Parameters
----------
ITE : str
The Individual Treatment Effect column.
Y : str
The outcome variable column.
T : str
The treatment variable column.
df : DataFrame
The input data.
normalize : bool, optional
Whether to normalize the result, default is True.
K : int, optional
The number of bins for discretization, default is 1000.
discrete_treatment : bool, optional
Whether the treatment is discrete, default is True.
Returns
-------
LiftGainCurveResult
An object containing the result of the uplift & gain calculation.
Example
-------
.. code-block:: python
import fast_causal_inference
from fast_causal_inference.dataframe.uplift import *
Y='y'
T='treatment'
table = 'test_data_small'
X = 'x1+x2+x3+x4+x5+x_long_tail1+x_long_tail2'
needcut_X = 'x1+x2+x3+x4+x5+x_long_tail1+x_long_tail2'
df = readClickHouse(table)
df_train, df_test = df.split(0.5)
hte = CausalTree(depth = 3,min_sample_ratio_leaf=0.001)
hte.fit(Y,T,X,needcut_X,df_train)
df_train_pred = hte.effect(df=df_train,keep_col='*')
df_test_pred = hte.effect(df=df_test,keep_col='*')
lift_train = get_lift_gain("effect", Y, T, df_train_pred,discrete_treatment=True, K=100)
lift_test = get_lift_gain("effect", Y, T, df_test_pred,discrete_treatment=True, K=100)
print(lift_train,lift_test)
hte_plot([lift_train,lift_test],labels=['train','test'])
# auuc: 0.6624369283393814
# auuc: 0.6532554148698826
# ratio lift gain ate ramdom_gain
# 0 0.009990 2.164241 0.021621 1.0 0.009990
# 1 0.019980 2.131245 0.042582 1.0 0.019980
# 2 0.029970 2.056440 0.061632 1.0 0.029970
# 3 0.039960 2.177768 0.087024 1.0 0.039960
# 4 0.049950 2.175329 0.108658 1.0 0.049950
# .. ... ... ... ... ...
# 95 0.959241 1.015223 0.973843 1.0 0.959241
# 96 0.969431 1.010023 0.979147 1.0 0.969431
# 97 0.979620 1.006843 0.986324 1.0 0.979620
# 98 0.989810 1.003508 0.993283 1.0 0.989810
# 99 1.000000 1.000000 1.000000 1.0 1.000000
# [100 rows x 5 columns] ratio lift gain ate ramdom_gain
# 0 0.009810 1.948220 0.019112 1.0 0.009810
# 1 0.019620 2.221654 0.043588 1.0 0.019620
# 2 0.029429 2.419752 0.071212 1.0 0.029429
# 3 0.039239 2.288460 0.089797 1.0 0.039239
# 4 0.049049 2.343432 0.114943 1.0 0.049049
# .. ... ... ... ... ...
# 95 0.959960 1.014897 0.974260 1.0 0.959960
# 96 0.969970 1.011624 0.981245 1.0 0.969970
# 97 0.979980 1.009358 0.989150 1.0 0.979980
# 98 0.989990 1.006340 0.996267 1.0 0.989990
# 99 1.000000 1.000000 1.000000 1.0 1.000000
# [100 rows x 5 columns]
"""
# Construct the SQL query
    sql = (
        f"select lift({ITE},{Y},{T},{K},{str(discrete_treatment).lower()})"
        f" from {df_2_table(df)} limit 100000"
    )
# Create an SQL instance
sql_instance = SqlGateWayConn.create_default_conn()
# Execute the SQL query and get the result
result = sql_instance.sql(sql)
# Select specific columns from the result
result = result[["ratio", "lift", "gain", "ate", "ramdom_gain"]]
# Replace 'nan' with np.nan
result = result.replace("nan", np.nan)
# Convert the data type to float
result = result.astype(float)
# Drop rows with missing values
result = result.dropna()
# Normalize the result if required
if normalize:
result = result.div(np.abs(result.iloc[-1, :]), axis=1)
# Calculate AUUC
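    # The mean of the (normalized) gain values is a Riemann approximation of
    # the area under the cumulative gain curve; a random ranking scores ~0.5.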
auuc = result["gain"].sum() / result.shape[0]
print("auuc:", auuc)
# Return the result as a LiftGainCurveResult object
return LiftGainCurveResult(result)
# Function to plot HTE
def hte_plot(results, labels=None):
"""
Plot the uplift & gain.
Parameters
----------
results : list
A list of LiftGainCurveResult objects to be plotted.
labels : list, optional
A list of labels for the results, default is an empty list.
Returns
-------
None
This function will display a plot.
"""
# Create a figure with two subplots
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(12, 4.8))
    # Generate default labels if none are provided
    if not labels:
        labels = [f"model_{i + 1}" for i in range(len(results))]
# Plot the results
for i in range(len(results)):
result = results[i].get_result()
auuc = round(result["gain"].sum() / result.shape[0], 2)
ax1.plot(result["ratio"], result["lift"], label=labels[i])
ax2.plot(
[0] + list(result["ratio"]),
[0] + list(result["gain"]),
label=labels[i] + f"(auuc:{auuc})",
)
# Plot the ATE and random gain
ax1.plot(result["ratio"], result["ate"])
ax2.plot([0] + list(result["ratio"]), [0] + list(result["ramdom_gain"]))
# Set the titles and legends
ax1.set_title("Cumulative Lift Curve")
ax1.legend()
ax2.set_title("Cumulative Gain Curve")
ax2.legend()
# Set the title for the figure
fig.suptitle("Lift and Gain Curves")
# Display the figure
plt.show()
def SelectSchema(schema):
return schema
def FeatNames(schema):
if "," in schema:
schemaArray = schema.split(",")
else:
schemaArray = [schema]
return schemaArray
def FilterSchema(schemaArray):
tmp = [i + " is not null" for i in schemaArray]
return " and ".join(tmp)
# CausalTree class
class CausalTreeclass:
    def __init__(
        self,
        dfName="dat",
        threshold=0.01,
        maxDepth=3,
        whereCond="",
        nodePosition="L",
        impurity=0,
        father_split_feature="",
        father_split_feature_Categories=None,
        whereCond_new=None,
        nodesSet=None,
    ):
        # Default to None rather than mutable [] so lists are not shared
        # across instances.
        self.dfName = dfName
        self.threshold = threshold
        self.maxDepth = maxDepth
        self.whereCond = whereCond
        self.nodePosition = nodePosition
        self.impurity = impurity
        self.father_split_feature = father_split_feature
        self.father_split_feature_Categories = (
            [] if father_split_feature_Categories is None else father_split_feature_Categories
        )
        self.whereCond_new = [] if whereCond_new is None else whereCond_new
        self.nodesSet = [] if nodesSet is None else nodesSet
self.leftNode = 0
self.rightNode = 0
self.nodeSize = 0
self.controlcount = 0
self.treatcount = 0
self.splitFeat = ""
self.splitIndex = 0
self.splitpoint = ""
self.prediction = 0
self.maxImpurGain = 0
self.isLeaforNot = False
self.splitpoint_pdf = pd.DataFrame()
self.allsplitpoint_pdf = pd.DataFrame()
self.featvalues_dict = dict()
        # Placeholder inference results; overwritten by predictNode().
        self.inferenceSet = [1.0] * 27
self.dat_cnt = 0
self.depth = 3
self.sql_instance = create_sql_instance()
def get_global_values(self, dat_cnt, depth, featNames):
self.dat_cnt = dat_cnt
self.depth = depth
self.featNames = featNames
def compute_df(self):
threshold = self.threshold
if self.maxDepth == 0:
self.isLeaforNot = True
print("Reach the maxDepth, stop as a leaf node")
return
if self.whereCond != "":
whereCond_ = "AND" + self.whereCond
else:
whereCond_ = ""
dfName = self.dfName
sql_instance = self.sql_instance
row = sql_instance.sql(
f"""SELECT treatment as T,count(Y) as cnt,avg(Y) as mean,varSamp(Y) as var FROM {dfName}
WHERE if_test = 0 {whereCond_}
group by treatment order by treatment """
)
self.controlcount = int(row["cnt"][0])
self.treatcount = int(row["cnt"][1])
self.nodeSize = self.controlcount + self.treatcount
if self.nodeSize / self.dat_cnt < threshold:
self.isLeaforNot = True
print("sample size is too small, stop as a leaf node")
return
if self.nodeSize == 0:
self.isLeaforNot = True
print(" sample size = 0, stop as a leaf node")
return
if float(row["var"][0]) == 0 and float(row["var"][1]) == 0:
self.isLeaforNot = True
print(" var = 0, stop as a leaf node ")
return
cols = [
"featName",
"featValue",
"cnt1",
"y1",
"y1_square",
"cnt0",
"y0",
"y0_square",
]
splitpoint_data_pdf = pd.DataFrame([], columns=cols)
featNames = self.featNames
for i in range(len(featNames) // 100 + 1):
x_list = featNames[i * 100 : min((i + 1) * 100, len(featNames))]
x_list_string1 = "'" + "', '".join(x_list) + "'"
x_list_string2 = ",".join(x_list)
            try:
                res = sql_instance.sql(
                    f"""SELECT arrayJoin(GroupSet({x_list_string1})(Y, treatment, {x_list_string2})) as a
                FROM {dfName} WHERE if_test = 0 {whereCond_} order by a.1,a.3,a.2 asc limit 100000"""
                )["a"].tolist()
            except Exception:
                # Surface the raw server response for debugging, then fail
                # loudly instead of continuing with undefined data.
                print(
                    sql_instance.sql(
                        f"""SELECT arrayJoin(GroupSet({x_list_string1})(Y, treatment, {x_list_string2})) as a
                FROM {dfName} WHERE if_test = 0 {whereCond_} order by a.1,a.3,a.2 asc"""
                    )
                )
                raise
            data_list = [
                i.replace("[", "").replace("]", "").replace(" ", "").split(",")
                for i in res
            ]
            df = pd.DataFrame(
                data_list,
                columns=[
                    "featName",
                    "treatment",
                    "featValue",
                    "cnt",
                    "y",
                    "y_square",
                ],
            )
df[["treatment", "cnt", "y", "y_square"]] = df[
["treatment", "cnt", "y", "y_square"]
].astype(float)
df[["featName", "featValue"]] = df[["featName", "featValue"]].astype(str)
df["featName"] = [i.replace("'", "") for i in list(df["featName"])]
df0 = df[df["treatment"] == 0].drop("treatment", axis=1)
df1 = df[df["treatment"] == 1].drop("treatment", axis=1)
df_new = pd.merge(df1, df0, on=["featName", "featValue"])
df_new.columns = cols
splitpoint_data_pdf = pd.concat([splitpoint_data_pdf, df_new], axis=0)
splitpoint_data_pdf["tau"] = (
splitpoint_data_pdf["y1"] / splitpoint_data_pdf["cnt1"]
- splitpoint_data_pdf["y0"] / splitpoint_data_pdf["cnt0"]
)
splitpoint_data_pdf = splitpoint_data_pdf.sort_values(
by=["featName", "tau"], ascending=False
)
tmp = splitpoint_data_pdf.groupby(by=["featName"], as_index=False)[
"featValue"
].count()
tmp.columns = ["featName", "featpoint_num"]
splitpoint_data_pdf = pd.merge(splitpoint_data_pdf, tmp, on=["featName"])
splitpoint_data_pdf = splitpoint_data_pdf.dropna()
splitpoint_data_pdf_copy = splitpoint_data_pdf.copy()
self.allsplitpoint_pdf = splitpoint_data_pdf_copy
splitpoint_data_pdf = splitpoint_data_pdf[splitpoint_data_pdf.featpoint_num > 1]
featNames_ = list(set(splitpoint_data_pdf["featName"]))
featValuesall_list = []
featName_list = []
featValue_list = []
splitpoint_list = []
cnt0_list = []
cnt1_list = []
for featName_ in featNames_:
featValuesall = list(
splitpoint_data_pdf[(splitpoint_data_pdf["featName"] == featName_)][
"featValue"
]
)
cnt0s = list(
splitpoint_data_pdf[(splitpoint_data_pdf["featName"] == featName_)][
"cnt0"
]
)
cnt1s = list(
splitpoint_data_pdf[(splitpoint_data_pdf["featName"] == featName_)][
"cnt1"
]
)
featValuesall_list.append(featValuesall)
for i in range(len(featValuesall) - 1):
cnt0 = np.sum(cnt0s[0 : i + 1])
cnt1 = np.sum(cnt1s[0 : i + 1])
cnt0_list.append(cnt0)
cnt1_list.append(cnt1)
featName_list.append(featName_)
splitpoint_list.append(dict({featName_: featValuesall[0 : i + 1]}))
featValue_list.append(featValuesall[0 : i + 1])
splitpoint_pdf = pd.DataFrame(
zip(featName_list, featValue_list, splitpoint_list, cnt0_list, cnt1_list),
columns=["featName", "featValue", "splitpoint", "cnt0", "cnt1"],
)
# print("splitpoint_pdf_before:\n",splitpoint_pdf)
splitpoint_pdf = splitpoint_pdf[
(splitpoint_pdf["cnt0"] > 100)
& (splitpoint_pdf["cnt1"] > 100)
& (
splitpoint_pdf["cnt0"] + splitpoint_pdf["cnt1"]
> threshold * self.dat_cnt
)
& (
self.nodeSize - splitpoint_pdf["cnt0"] - splitpoint_pdf["cnt1"]
> threshold * self.dat_cnt
)
]
self.splitpoint_pdf = splitpoint_pdf
self.featvalues_dict = dict(zip(featNames_, featValuesall_list))
def splitcond_sql(self, splitpoint):
featName = list(splitpoint.keys())[0]
featValue = list(splitpoint.values())[0]
x = self.sql_instance.sql(f"desc {self.dfName};")
cols_type = dict(zip(x["name"], x["type"]))
if cols_type[featName] == "String":
featValue = ",".join(["'" + str(i) + "'" for i in featValue])
else:
featValue = ",".join([str(i) for i in featValue])
left_condition_tree = f"""({featName} in ({featValue}))"""
right_condition_tree = f"""not ({featName} in ({featValue}))"""
return left_condition_tree, right_condition_tree
def calculate_impurity_new(self, x):
allsplitpoint_pdf = self.allsplitpoint_pdf # TODO
featName = x["featName"]
featValue = x["featValue"]
# left impurity
(y1, y1_square, cnt1, y0, y0_square, cnt0) = list(
allsplitpoint_pdf[
(allsplitpoint_pdf["featValue"].isin(featValue))
& (allsplitpoint_pdf["featName"] == featName)
][["y1", "y1_square", "cnt1", "y0", "y0_square", "cnt0"]].sum()
)
y1 = y1 / cnt1
y0 = y0 / cnt0
y1_square = y1_square / cnt1
y0_square = y0_square / cnt0
tau = y1 - y0
tr_var = y1_square - y1**2
con_var = y0_square - y0**2
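        # Node score: 0.5 * n * tau^2 rewards effect heterogeneity, minus a
        # variance penalty 2 * n * (Var_t/n_t + Var_c/n_c).  This mirrors the
        # EMSE-style splitting criterion of honest causal trees (Athey & Imbens).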
left_effect = 0.5 * tau * tau * (cnt1 + cnt0) - 0.5 * 2 * (cnt1 + cnt0) * (
tr_var / cnt1 + con_var / cnt0
)
# right impurity
(y1, y1_square, cnt1, y0, y0_square, cnt0) = list(
allsplitpoint_pdf[
(~(allsplitpoint_pdf["featValue"].isin(featValue)))
& (allsplitpoint_pdf["featName"] == featName)
][["y1", "y1_square", "cnt1", "y0", "y0_square", "cnt0"]].sum()
)
y1 = y1 / cnt1
y0 = y0 / cnt0
y1_square = y1_square / cnt1
y0_square = y0_square / cnt0
tau = y1 - y0
tr_var = y1_square - y1**2
con_var = y0_square - y0**2
right_effect = 0.5 * tau * tau * (cnt1 + cnt0) - 0.5 * 2 * (cnt1 + cnt0) * (
tr_var / cnt1 + con_var / cnt0
)
return left_effect, right_effect
def calculate_impurity_original(self):
sql_instance = self.sql_instance
sql = f""" select\
sum(if(treatment=1,1,0)) as cnt1, \
sum(if(treatment=1,Y,0))/sum(if(treatment=1,1,0)) as y1, \
sum(if(treatment=1,Y*Y,0))/sum(if(treatment=1,1,0)) as y1_square, \
sum(if(treatment=0,1,0)) as cnt0, \
sum(if(treatment=0,Y,0))/sum(if(treatment=0,1,0)) as y0, \
sum(if(treatment=0,Y*Y,0))/sum(if(treatment=0,1,0)) as y0_square \
from {self.dfName} where if_test = 0 """
row = sql_instance.sql(sql).iloc[0, :]
cnt1, y1, y1_square, cnt0, y0, y0_square = (
int(row.cnt1),
float(row.y1),
float(row.y1_square),
int(row.cnt0),
float(row.y0),
float(row.y0_square),
)
tau = y1 - y0
tr_var = y1_square - y1**2
con_var = y0_square - y0**2
effect = 0.5 * tau * tau * (cnt1 + cnt0) - 0.5 * 2 * (cnt1 + cnt0) * (
tr_var / cnt1 + con_var / cnt0
)
return effect
def getTreeID(self):
nodePosition = self.nodePosition
lengthStr = len(nodePosition)
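        # Node IDs follow a binary-heap layout: first offset by the number of
        # nodes in the levels above, then read the L/R path (R = 1) as binary.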
if lengthStr == 1:
result = 1
else:
result = sum([math.pow(2, i) for i in range(lengthStr - 1)]) + 1
for i in range(lengthStr):
x = 1 if nodePosition[i] == "R" else 0
result += math.pow(2, lengthStr - i - 1) * x
return result
    def get_node_type(self):
        return "leaf" if self.isLeaforNot else "internal"
def getBestSplit(self):
# print("-----getBestSplit-----")
splitpoint_pdf = self.splitpoint_pdf
if splitpoint_pdf.shape[0] == 0:
self.isLeaforNot = True
            print(
                "no split points satisfy the condition; stopping as a leaf node"
            )
return
splitpoint_pdf[["leftImpurity", "rightImpurity"]] = list(
splitpoint_pdf.apply(self.calculate_impurity_new, axis=1)
)
splitpoint_pdf["ImpurityGain"] = (
splitpoint_pdf["leftImpurity"]
+ splitpoint_pdf["rightImpurity"]
- self.impurity
)
splitpoint_pdf = splitpoint_pdf.sort_values(by="ImpurityGain", ascending=False)
splitpoint_pdf = splitpoint_pdf[(splitpoint_pdf["ImpurityGain"] > 0)]
self.splitpoint_pdf = splitpoint_pdf[
["featName", "featValue", "splitpoint", "ImpurityGain"]
]
if splitpoint_pdf.shape[0] == 0:
self.isLeaforNot = True
            print(
                "no split points satisfy the condition; stopping as a leaf node"
            )
return
bestFeatureName = splitpoint_pdf.iloc[0]["featName"]
bestFeatureIndex = splitpoint_pdf.iloc[0]["featValue"]
bestleftImpurity = splitpoint_pdf.iloc[0]["leftImpurity"]
bestrightImpurity = splitpoint_pdf.iloc[0]["rightImpurity"]
maxImpurityGain = splitpoint_pdf.iloc[0]["ImpurityGain"]
bestsplitpoint = splitpoint_pdf.iloc[0]["splitpoint"]
self.maxImpurGain = maxImpurityGain
self.splitFeat = bestFeatureName
self.splitIndex = bestFeatureIndex
self.splitpoint = bestsplitpoint
self.leftImpurity = bestleftImpurity
self.rightImpurity = bestrightImpurity
def buildTree(self):
# global nodesSet
self.nodesSet.append(self)
# print(self.nodesSet)
if self.whereCond == "":
self.impurity = self.calculate_impurity_original()
# print("root impurity:",self.impurity)
self.compute_df()
if self.isLeaforNot:
return
else:
self.getBestSplit()
if self.isLeaforNot:
return
# print("end split")
whereConditionSql = "" if (self.whereCond == "") else self.whereCond + " AND "
leftcondition_tree, rightcondition_tree = self.splitcond_sql(self.splitpoint)
leftCond = f"{whereConditionSql} ( {leftcondition_tree} ) "
rightCond = f"{whereConditionSql} ( {rightcondition_tree} ) "
leftChildDfName = self.nodePosition + "L"
rightChildDfName = self.nodePosition + "R"
# TODO:
father_split_feature = self.splitFeat
Categories = self.featvalues_dict[father_split_feature]
left_Categories = self.splitpoint[father_split_feature]
        right_Categories = [i for i in Categories if i not in left_Categories]
left_tmp = {}
right_tmp = {}
# print("self.whereCond_new",self.whereCond_new)
# print(father_split_feature,left_Categories,right_Categories)
left_tmp[father_split_feature] = left_Categories
right_tmp[father_split_feature] = right_Categories
leftCond_new = self.whereCond_new.copy()
leftCond_new.append(left_tmp)
rightCond_new = self.whereCond_new.copy()
rightCond_new.append(right_tmp)
# print("leftCond_new,rightCond_new",leftCond_new,rightCond_new)
leftChild = CausalTreeclass(
self.dfName,
self.threshold,
self.maxDepth - 1,
leftCond,
self.nodePosition + "L",
self.leftImpurity,
father_split_feature,
left_Categories,
leftCond_new,
self.nodesSet,
)
rightChild = CausalTreeclass(
self.dfName,
self.threshold,
self.maxDepth - 1,
rightCond,
self.nodePosition + "R",
self.rightImpurity,
father_split_feature,
right_Categories,
rightCond_new,
self.nodesSet,
)
leftChild.get_global_values(self.dat_cnt, self.depth, self.featNames)
rightChild.get_global_values(self.dat_cnt, self.depth, self.featNames)
self.leftNode = leftChild
self.rightNode = rightChild
# print("-----split result----")
# print("maxImpurGain:",self.maxImpurGain)
# print("splitpoint:",self.splitpoint)
# print("leftsplitcond,rightsplitcond:",leftcondition_tree,rightcondition_tree)
# print("leftImpurity,rightImpurity:",self.leftImpurity,self.rightImpurity)
print(
"--------start leftNode - build -- depth: {depth}, nodePosition: {nodePosition}--------".format(
depth=self.depth - self.maxDepth, nodePosition=self.nodePosition + "L"
)
)
self.leftNode.buildTree()
print(
"--------start rightNode - build -- depth: {depth}, nodePosition: {nodePosition}--------".format(
depth=self.depth - self.maxDepth, nodePosition=self.nodePosition + "R"
)
)
self.rightNode.buildTree()
    def visualization(self):
        if self.isLeaforNot:
            return (
                "\n" + "Prediction is: " + str(self.prediction) + " " + self.whereCond
            )
        else:
            return (
                "Level"
                + str(self.maxDepth)
                + " "
                + str(self.splitpoint)
                + " gain: "
                + str(self.maxImpurGain)
                + "\n"
                + self.leftNode.visualization()
                + "\n"
                + self.rightNode.visualization()
            )
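    # Two-sided z-test p-value and a 95% normal-approximation confidence
    # interval (1.96 ~= z_{0.975}) around the point estimate.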
def ComputePvalueAndCI(self, zValue, prediction, meanStd):
pvalue = 2 * (1 - norm.cdf(x=abs(zValue), loc=0, scale=1))
lowerCI = prediction - 1.96 * (meanStd)
upperCI = prediction + 1.96 * (meanStd)
return (pvalue, lowerCI, upperCI)
    def predictNode(self, ate=0, cnt=1):
        if self.whereCond != "":
            whereCond_ = "AND" + self.whereCond
        else:
            whereCond_ = ""
        level = self.depth - self.maxDepth
        TreeID = self.getTreeID()
        isLeaf = bool(self.isLeaforNot)
        whereCond_new = self.whereCond_new
sql = f"""SELECT \
treatment, count(*) as cnt, avg(Y) as mean, varSamp(Y) as var \
FROM {self.dfName} \
WHERE treatment in (0,1) {whereCond_} and if_test=1 \
GROUP BY treatment \
"""
sql_instance = self.sql_instance
statData_pdf = sql_instance.sql(sql)
try:
statData_pdf = statData_pdf.astype(float)
y1_column = statData_pdf[statData_pdf["treatment"] == 1]
y0_column = statData_pdf[statData_pdf["treatment"] == 0]
treatedCount = int(list(y1_column["cnt"])[0])
controlCount = int(list(y0_column["cnt"])[0])
treatedLabelAvg = float(list(y1_column["mean"])[0])
controlLabelAvg = float(list(y0_column["mean"])[0])
treatedLabelVar = float(list(y1_column["var"])[0])
controlLabelVar = float(list(y0_column["var"])[0])
self.prediction = treatedLabelAvg - controlLabelAvg
ratio = (treatedCount + controlCount) / cnt
        except Exception:
            # The stats query returned no usable rows for one of the arms:
            # record a degenerate inference set (all stats 0, p-values 1).
            self.inferenceSet = (
                [TreeID, level, isLeaf, whereCond_, whereCond_new]
                + [0] * 8
                + [0, 0, 1, 0, 0, 0] * 4
            )
            return
# if treatedLabelAvg == 0 or controlLabelAvg == 0:
# self.inferenceSet = list(
# [TreeID, level, isLeaf, whereCond_, whereCond_new, ratio, self.prediction, treatedCount, controlCount,
# treatedLabelAvg, controlLabelAvg, math.sqrt(treatedLabelVar), math.sqrt(controlLabelVar),
# 0, 0, 1, 0, 0, 0,
# 0, 0, 1, 0, 0, 0,
# 0, 0, 1, 0, 0, 0,
# 0, 0, 1, 0, 0, 0
# ]) # TODO:fix zero bug
try:
estPoint1 = treatedLabelAvg - controlLabelAvg
std1 = math.sqrt(
treatedLabelVar / treatedCount + controlLabelVar / controlCount
)
zValue1 = estPoint1 / std1
(pvalue1, lowerCI1, upperCI1) = self.ComputePvalueAndCI(
zValue1, estPoint1, std1
)
estPoint2 = treatedLabelAvg - controlLabelAvg - ate
std2 = std1
zValue2 = estPoint2 / std2
(pvalue2, lowerCI2, upperCI2) = self.ComputePvalueAndCI(
zValue2, estPoint2, std2
)
estPoint3 = treatedLabelAvg / controlLabelAvg - 1
std3 = std1
zValue3 = zValue1
pvalue3 = pvalue1
lowerCI3 = estPoint3 - 1.96 * (std1) * estPoint3 / estPoint1
upperCI3 = estPoint3 + 1.96 * (std1) * estPoint3 / estPoint1
estPoint4 = (treatedLabelAvg - ate) / controlLabelAvg - 1
std4 = std2
zValue4 = zValue2
pvalue4 = pvalue2
if TreeID == 1:
estPoint2, estPoint4 = 0, 0
lowerCI4 = 0
upperCI4 = 0
else:
lowerCI4 = estPoint4 - 1.96 * (std2) * estPoint4 / estPoint2
upperCI4 = estPoint4 + 1.96 * (std2) * estPoint4 / estPoint2
self.inferenceSet = list(
[
TreeID,
level,
isLeaf,
whereCond_,
whereCond_new,
ratio,
self.prediction,
treatedCount,
controlCount,
treatedLabelAvg,
controlLabelAvg,
math.sqrt(treatedLabelVar),
math.sqrt(controlLabelVar),
estPoint1,
std1,
pvalue1,
zValue1,
lowerCI1,
upperCI1,
estPoint2,
std2,
pvalue2,
zValue2,
lowerCI2,
upperCI2,
estPoint3,
std3,
pvalue3,
zValue3,
lowerCI3,
upperCI3,
estPoint4,
std4,
pvalue4,
zValue4,
lowerCI4,
upperCI4,
]
)
        except Exception:
            # Degenerate inference: keep the observed counts and averages but
            # zero out the four estimators (p-values set to 1).
            self.inferenceSet = [
                TreeID,
                level,
                isLeaf,
                whereCond_,
                whereCond_new,
                ratio,
                self.prediction,
                treatedCount,
                controlCount,
                treatedLabelAvg,
                controlLabelAvg,
                math.sqrt(treatedLabelVar),
                math.sqrt(controlLabelVar),
            ] + [0, 0, 1, 0, 0, 0] * 4  # TODO: fix zero bug
# define tree structure
class CTDecisionNode:
def __init__(
self,
node_id=0,
nodeType="",
node_order="",
nodePosition="",
count_ratio=0,
controlCount=0,
treatedCount=0,
treatedAvg=0,
controlAvg=0,
splitType="",
gain=0,
impurity=0,
prediction=0,
prediction_new=[],
featureName="",
featureIndex="",
father_split_feature="",
father_split_feature_Categories=[],
pvalues=[],
qvalues=[],
children=None,
whereCond="",
whereCond_new=[],
):
self.node_id = node_id
self.nodeType = nodeType
self.node_order = node_order
self.nodePosition = nodePosition
self.count_ratio = count_ratio
self.controlCount = controlCount
self.treatedCount = treatedCount
self.treatedAvg = treatedAvg
self.controlAvg = controlAvg
self.splitType = splitType
self.gain = gain
self.impurity = impurity
self.prediction = prediction
self.prediction_new = prediction_new
self.featureName = featureName
self.featureIndex = featureIndex
self.father_split_feature = father_split_feature
self.father_split_feature_Categories = father_split_feature_Categories
self.pvalues = pvalues
self.qvalues = qvalues
self.whereCond = whereCond
self.whereCond_new = whereCond_new
self.children = children
def get_dict(self):
        child_dic = (
            []
            if self.children is None
            else (self.children[0].get_dict(), self.children[1].get_dict())
        )
return {
"node_id": self.node_id,
"nodeType": self.nodeType,
"father_split_feature_Categories": self.father_split_feature_Categories,
"count_ratio": self.count_ratio,
"controlCount": self.controlCount,
"treatedCount": self.treatedCount,
"controlAvg": self.controlAvg,
"treatedAvg": self.treatedAvg,
"father_split_feature": self.father_split_feature,
"pvalues": self.pvalues,
"tau_i": self.prediction,
# "node_order": self.node_order,
# "nodePosition":self.nodePosition,
# "tau_i_new":self.prediction_new,
# "splitType": self.splitType,
# "gain": self.gain,
# "impurity": self.impurity,
# "featureName": self.featureName,
# "featureIndex": self.featureIndex,
# "qvalues":self.qvalues,
# "whereCond":self.whereCond,
# "whereCond_new":self.whereCond_new,
"children": child_dic,
}
# Convert a fitted CausalTreeclass into CTDecisionNode dicts for tree plotting
class CTRegressionTree:
def __init__(self, tree=0, total_count=1):
self.tree = tree
self.total_count = total_count
def get_decision_rules(self, node_order="root"):
tree = self.tree
node_id = tree.getTreeID()
node_type = tree.get_node_type()
node_order = node_order
        # Sort the parent split categories for stable display.
        x = tree.father_split_feature_Categories
        x.sort()
if node_type == "internal":
gain = tree.maxImpurGain
feature_name = tree.splitFeat
feature_index = tree.splitIndex
split_type = "categorical"
node_impurity = tree.impurity
# Categories = tree.featvalues_dict[feature_name]
# left_Categories = list(tree.splitIndex.split(','))
# right_Categories = []
# for i in Categories:
# if i not in left_Categories:
# right_Categories.append(i)
left = CTRegressionTree(tree=tree.leftNode, total_count=self.total_count)
right = CTRegressionTree(tree=tree.rightNode, total_count=self.total_count)
children = (
left.get_decision_rules("left"),
right.get_decision_rules("right"),
)
else:
gain = None
feature_name = None
feature_index = None
split_type = None
node_impurity = None
# left_Categories = None
# right_Categories = None
children = None
ctDecisionNode = CTDecisionNode(
node_id=node_id,
nodeType=node_type,
node_order=node_order,
count_ratio=round(tree.inferenceSet["ratio"] * 100, 2),
treatedCount=tree.inferenceSet["treatedCount"],
controlCount=tree.inferenceSet["controlCount"],
treatedAvg=round(tree.inferenceSet["treatedAvg"], 4),
controlAvg=round(tree.inferenceSet["controlAvg"], 4),
nodePosition=tree.nodePosition,
prediction=round(tree.prediction, 4),
prediction_new=[
float(tree.inferenceSet["estPoint" + str(i)]) for i in range(1, 5)
],
splitType=split_type,
gain=gain,
impurity=node_impurity,
featureName=feature_name,
featureIndex=feature_index,
father_split_feature=tree.father_split_feature,
father_split_feature_Categories=tree.father_split_feature_Categories,
pvalues=[
round(float(tree.inferenceSet["pvalue" + str(i)]), 2)
for i in range(1, 5)
],
qvalues=[
round(float(tree.inferenceSet["qvalue" + str(i)]), 2)
for i in range(1, 5)
],
children=children,
whereCond=tree.whereCond,
whereCond_new=tree.whereCond_new,
)
return ctDecisionNode
def SelectSchema(schema):
return schema
def FeatNames(schema):
if schema == "":
schemaArray = []
elif "+" in schema:
schemaArray = schema.split("+")
else:
schemaArray = [schema]
return schemaArray
def FilterSchema(schemaArray):
tmp = [i + " is not null" for i in schemaArray]
return " and ".join(tmp)
def auto_wrap_text(text, max_line_length):
wrapped_text = textwrap.fill(text, max_line_length)
return wrapped_text
def check_table(table):
sql_instance = create_sql_instance()
x = sql_instance.sql(f"select count(*) as cnt from {table} ")
if "Code: 60" in x:
print(x)
raise ValueError
elif int(x["cnt"][0]) == 0:
print("There's no data in the table")
raise ValueError
else:
return 1
# ClickHouse column types treated as numeric by the checks below.
CLICKHOUSE_NUMERIC_TYPES = (
    "UInt8", "UInt16", "UInt32", "UInt64", "UInt128", "UInt256",
    "Int8", "Int16", "Int32", "Int64", "Int128", "Int256",
    "Float32", "Float64",
)


def check_numeric_type(table, col):
    sql_instance = create_sql_instance()
    x = sql_instance.sql(f"desc {table}")
    cols_type = dict(zip(x["name"], x["type"]))
    if cols_type[col] not in CLICKHOUSE_NUMERIC_TYPES:
        print(f"The type of {col} is not numeric")
    return 1
def check_columns(table, cols, cols_nume):
    sql_instance = create_sql_instance()
    x = sql_instance.sql(f"desc {table} ")
    cols_type = dict(zip(x["name"], x["type"]))
    col_list = list(cols_type.keys())
    # check that every requested column exists
    other_variables = set(cols) - set(col_list)
    if len(other_variables) != 0:
        print(f"variable {other_variables} cannot be found in the table {table}")
        raise ValueError
    # check that the numeric columns really are numeric
    for col in cols_nume:
        if cols_type[col] not in CLICKHOUSE_NUMERIC_TYPES:
            print(f"The type of {col} is not numeric")
            raise ValueError
class CausalTree:
"""
This class implements a Causal Tree for uplift/HTE analysis.
Parameters
----------
depth : int
The maximum depth of the tree.
threshold : float
The minimum sample ratio for a leaf.
bin_num : int
The number of bins for the need_cut_X to cut.
Example
-------
.. code-block:: python
import fast_causal_inference
from fast_causal_inference.dataframe.uplift import *
Y='y'
T='treatment'
table = 'test_data_small'
X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
needcut_X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
df = fast_causal_inference.readClickHouse(table)
df_train, df_test = df.split(0.5)
hte = CausalTree(depth = 3,min_sample_ratio_leaf=0.001)
hte.fit(Y,T,X,needcut_X,df_train)
treeplot = hte.treeplot() # causal tree plot
    treeplot.render('digraph.gv', view=False) # the full tree image can be viewed in (and downloaded from) the digraph.gv.pdf file
print(hte.feature_importance)
# Output:
# featName importance
# 1 x2_buckets 1.015128e+06
# 0 x1_buckets 2.181346e+05
# 3 x4_buckets 1.023273e+05
# 5 x_long_tail1_buckets 5.677131e+04
# 2 x3_buckets 2.537835e+04
# 6 x_long_tail2_buckets 2.536951e+04
# 4 x5_buckets 7.259992e+03
df_train_pred = hte.effect(df=df_train,keep_col='*')
df_test_pred = hte.effect(df=df_test,keep_col='*')
lift_train = get_lift_gain("effect", Y, T, df_train_pred,discrete_treatment=True, K=100)
lift_test = get_lift_gain("effect", Y, T, df_test_pred,discrete_treatment=True, K=100)
print(lift_train,lift_test)
hte_plot([lift_train,lift_test],labels=['train','test'])
# auuc: 0.6624369283393814
# auuc: 0.6532554148698826
# ratio lift gain ate ramdom_gain
# 0 0.009990 2.164241 0.021621 1.0 0.009990
# 1 0.019980 2.131245 0.042582 1.0 0.019980
# 2 0.029970 2.056440 0.061632 1.0 0.029970
# 3 0.039960 2.177768 0.087024 1.0 0.039960
# 4 0.049950 2.175329 0.108658 1.0 0.049950
# .. ... ... ... ... ...
# 95 0.959241 1.015223 0.973843 1.0 0.959241
# 96 0.969431 1.010023 0.979147 1.0 0.969431
# 97 0.979620 1.006843 0.986324 1.0 0.979620
# 98 0.989810 1.003508 0.993283 1.0 0.989810
# 99 1.000000 1.000000 1.000000 1.0 1.000000
# [100 rows x 5 columns] ratio lift gain ate ramdom_gain
# 0 0.009810 1.948220 0.019112 1.0 0.009810
# 1 0.019620 2.221654 0.043588 1.0 0.019620
# 2 0.029429 2.419752 0.071212 1.0 0.029429
# 3 0.039239 2.288460 0.089797 1.0 0.039239
# 4 0.049049 2.343432 0.114943 1.0 0.049049
# .. ... ... ... ... ...
# 95 0.959960 1.014897 0.974260 1.0 0.959960
# 96 0.969970 1.011624 0.981245 1.0 0.969970
# 97 0.979980 1.009358 0.989150 1.0 0.979980
# 98 0.989990 1.006340 0.996267 1.0 0.989990
# 99 1.000000 1.000000 1.000000 1.0 1.000000
# [100 rows x 5 columns]
"""
def __init__(self, depth=3, min_sample_ratio_leaf=0.001, bin_num=10):
self.depth = depth
self.threshold = min_sample_ratio_leaf
self.bin_num = bin_num
self.Y = ""
self.T = ""
self.X = ""
self.needcut_X = ""
self.table = ""
self.tree_structure = []
self.result_df = []
self.__sql_instance = SqlGateWayConn.create_default_conn()
self.__cutbinstring = ""
self.feature_importance = pd.DataFrame([])
def __params_input_check(self):
sql_instance = self.__sql_instance
if self.Y == "":
print("missing Y. You should check out the input.")
raise ValueError
if self.T == "":
print("missing T. You should check out the input.")
raise ValueError
if self.X == "":
print("missing X. You should check out the input.")
raise ValueError
if self.table == "":
print("missing table. You should check out the input.")
raise ValueError
else:
sql_instance.sql(f"select {self.Y},{self.T},{self.X} from {self.table}")
def __table_variables_check(self):
sql_instance = self.__sql_instance
table = self.variables["table"]
Y = self.variables["Y"]
T = self.variables["T"]
x_names = self.variables["x_names"]
cut_x_names = self.variables["cut_x_names"]
variables = Y + T + x_names + cut_x_names
check_table(table=table)
check_columns(table=table, cols=variables, cols_nume=Y + T + cut_x_names)
res = sql_instance.sql(
f"select {T[0]} as T from {table} group by {T[0]} limit 10"
)["T"].tolist()
T_value = set([float(i) for i in res])
if T_value != {0, 1}:
print("The value of T must be either 0 or 1 ")
raise ValueError
def fit(self, Y, T, X, needcut_X, df):
sql_instance = self.__sql_instance
table = df_2_table(df)
table_new = f"{table}_{int(time.time())}_new"
self.Y = Y
self.T = T
self.table = table
self.X = X
self.needcut_X = needcut_X
depth = self.depth
bin_num = self.bin_num
self.__params_input_check()
x_names = list(set(X))
cut_x_names = list(set(needcut_X))
no_cut_x_names = list(set(x_names) - set(cut_x_names))
cut_x_names_new = [i + "_buckets" for i in cut_x_names]
x_names = no_cut_x_names + cut_x_names
x_names_new = no_cut_x_names + cut_x_names_new
featNames = x_names_new
self.variables = {
"T": [T],
"Y": [Y],
"x_names": x_names,
"cut_x_names": cut_x_names,
"table": table,
}
print("****STEP1. Table check.")
self.__table_variables_check()
# get bins for cut_x_names
print("****STEP2. Bucket the continuous variables(cut_x_names).")
bins_dict = {}
quantiles = ",".join(
[str(i) for i in list(np.linspace(0, 1, bin_num + 1)[1:-1])]
)
if len(cut_x_names_new) != 0:
string = ",".join(
[f"quantiles({quantiles})({i}) as {i}" for i in cut_x_names]
)
result = sql_instance.sql(f"""select {string} from {table}""")
bins_dict = {}
for i in range(len(cut_x_names)):
col = cut_x_names[i]
x = result[col][0]
bins = x.replace("[", "").replace("]", "").split(",")
if len(bins) == 0:
bins = [0]
bins = list(np.sort(list(set([float(x) for x in bins]))))
bins_dict[col] = bins
strings = []
for i in bins_dict:
string = f"CutBins({i},{bins_dict[i]},False) as {i}_buckets"
strings.append(string)
cutbinstring = ",".join(strings) + ","
else:
cutbinstring = ""
self.bins_dict = bins_dict
for i in bins_dict:
bins_dict[i] = [-float("inf")] + bins_dict[i] + [float("inf")]
if len(no_cut_x_names) != 0:
no_cut_x_names_string = (
",".join([f"{i} as {i}" for i in no_cut_x_names]) + ","
)
else:
no_cut_x_names_string = ""
self.__cutbinstring = cutbinstring
self.__no_cut_x_names_string = no_cut_x_names_string
print("****STEP3. Create new table for training causaltree: ", table_new, ".")
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_new,
sql_statement=f"""
{Y} as Y, {T} as treatment,{cutbinstring}{no_cut_x_names_string} if(rand()/pow(2,32)<0.5,0,1) as if_test
""",
sql_table_name=table,
primary_column="if_test",
is_force_materialize=True,
)
# check if empty data for new table
allcnt = int(
sql_instance.sql(
f"select count(*) as cnt from {table_new} where {FilterSchema(x_names_new)}"
)["cnt"][0]
)
if allcnt == 0:
print("Sample size is 0, check for null values")
raise ValueError
res = sql_instance.sql(
f"select treatment from {table_new} group by treatment limit 10"
)["treatment"].tolist()
treatments = set([float(i) for i in res])
if treatments != {0, 1}:
print("The value of T can only be 0 or 1")
raise ValueError
# compute ate before build tree
train = sql_instance.sql(
f"SELECT if(treatment=1,1,-1) as z, sum(Y) as sum,count(*) as cnt FROM {table_new} where if_test=0 group by if(treatment=1,1,-1)"
)
train = pd.DataFrame(train, columns=["z", "sum", "cnt"])
train = train[["z", "sum", "cnt"]].astype(float)
test = sql_instance.sql(
f"SELECT if(treatment=1,1,-1) as z, sum(Y) as sum,count(*) as cnt FROM {table_new} where if_test=1 group by if(treatment=1,1,-1)"
)
test = pd.DataFrame(test, columns=["z", "sum", "cnt"])
test = test[["z", "sum", "cnt"]].astype(float)
dat_cnt = train["cnt"].sum()
estData_cnt = test["cnt"].sum()
data_all = pd.merge(train, test, on="z")
ate = (
data_all["z"]
* (data_all["sum_x"] + data_all["sum_y"])
/ (data_all["cnt_x"] + data_all["cnt_y"])
).sum()
estData_ate = ((test["z"] * (test["sum"])) / (test["cnt"])).sum()
print(
f"\t train data samples: {int(dat_cnt)},predict data samples: {int(estData_cnt)}"
)
# build tree
print("****STEP4. Build tree.")
t1 = time.time()
modelTree = CausalTreeclass(
dfName=table_new,
threshold=self.threshold,
maxDepth=depth,
whereCond="",
nodePosition="L",
impurity=0,
father_split_feature="",
father_split_feature_Categories=[],
whereCond_new=[],
nodesSet=[],
)
modelTree.get_global_values(dat_cnt, depth, featNames)
print(
"================================== start buildTree -- maxDepth: {maxDepth}, nodePosition: {nodePosition}==================================".format(
maxDepth=depth, nodePosition="root"
)
)
modelTree.buildTree()
        print(
            "============================================== build Tree Successfully====================================================="
        )
result_list = []
print("****STEP5. Estimate CATE.")
nodesSet = modelTree.nodesSet
splitpoint_pdf_all = pd.DataFrame(
[], columns=["featName", "featValue", "splitpoint", "ImpurityGain"]
)
for l in nodesSet:
l.predictNode(ate=estData_ate, cnt=estData_cnt) # TODO
result = l.inferenceSet
splitpoint_pdf_all = pd.concat(
[splitpoint_pdf_all, l.splitpoint_pdf], axis=0
)
result_list.append(result)
self.feature_importance = splitpoint_pdf_all.groupby(
by=["featName"], as_index=False
)["ImpurityGain"].sum()
self.feature_importance = self.feature_importance.sort_values(
by=["ImpurityGain"], ascending=False
)
self.feature_importance.columns = ["featName", "importance"]
columns = [
"TreeID",
"level",
"isLeaf",
"whereCond",
"whereCond_new",
"ratio",
"prediction",
"treatedCount",
"controlCount",
"treatedAvg",
"controlAvg",
"treatedStd",
"controlStd",
"estPoint1",
"std1",
"pvalue1",
"zValue1",
"lowerCI1",
"upperCI1",
"estPoint2",
"std2",
"pvalue2",
"zValue2",
"lowerCI2",
"upperCI2",
"estPoint3",
"std3",
"pvalue3",
"zValue3",
"lowerCI3",
"upperCI3",
"estPoint4",
"std4",
"pvalue4",
"zValue4",
"lowerCI4",
"upperCI4",
]
result_df = pd.DataFrame(result_list, columns=columns)
        print(
            "============================================== compute Tree Successfully====================================================="
        )
result_df["qvalue1"] = fdrcorrection(result_df["pvalue1"])[1]
result_df["qvalue2"] = fdrcorrection(result_df["pvalue2"])[1]
result_df["qvalue3"] = fdrcorrection(result_df["pvalue3"])[1]
result_df["qvalue4"] = fdrcorrection(result_df["pvalue4"])[1]
self.result_df = result_df
        # For continuous variables, map bucket indices back to their cut-bin
        # intervals, for example:
        # {"x_continuous_1_buckets":[4]} >>> {"x_continuous_1_buckets":'[89.62761587688087,95.04490061748047)'}
if result_df.shape[0] > 1:
Category_value_dicts = {}
for i in x_names_new:
                try:
                    values_ = list(
                        modelTree.allsplitpoint_pdf[
                            modelTree.allsplitpoint_pdf["featName"] == i
                        ]["featValue"]
                    )
                    values_ = list(np.array(values_))
                    values_.sort()
                except Exception:
                    print(f"failed to collect split values for feature {i}")
                    raise
                if i in cut_x_names_new:
                    cut_bins_dict = {}
                    bins = bins_dict[i.split("_buckets")[0]]
                    intervals = []
                    for j in range(len(bins) - 1):
                        intervals.append(
                            "["
                            + str(round(bins[j], 4))
                            + ","
                            + str(round(bins[j + 1], 4))
                            + ")"
                        )
                    for value in values_:
                        cut_bins_dict[value] = intervals[int(value) - 1]  # todo
                    Category_value_dicts[i] = cut_bins_dict
else:
Category_value_dicts[i] = dict(zip(values_, values_))
self.Category_value_dicts = Category_value_dicts
whereCond_new_list = []
for i in range(1, len(nodesSet)):
whereCond_new_ = {}
for whereCond_new in result_df.loc[i, "whereCond_new"]:
key = list(whereCond_new.keys())[0]
value = list(whereCond_new.values())[0]
value.sort()
Category_value_dict = Category_value_dicts[key]
whereCond_new_[key] = [Category_value_dict[j] for j in value]
whereCond_new_list.append(whereCond_new_)
result_df["whereCond_new"] = [""] + whereCond_new_list
nodesSet[0].inferenceSet = dict(result_df.loc[0, :])
for i in range(1, len(nodesSet)):
nodesSet[i].inferenceSet = dict(
result_df.loc[i, :]
) # inferenceSet: list >> dict
nodesSet[i].whereCond_new = result_df.loc[i, "whereCond_new"]
Category_value_dict = Category_value_dicts[nodesSet[i].father_split_feature]
father_split_feature_Categories = nodesSet[
i
].father_split_feature_Categories.copy()
nodesSet[i].father_split_feature_Categories = [
Category_value_dict[j] for j in father_split_feature_Categories
]
# tree_structure
ctRegressionTree = CTRegressionTree(
tree=modelTree, total_count=test["cnt"].sum()
) # estData total cnt
tree_structure = ctRegressionTree.get_decision_rules("root")
self.tree_structure = tree_structure
self.result_df = result_df
self.estimate = list(self.result_df["prediction"])
self.estimate_stderr = list(self.result_df["std1"])
self.pvalue = list(self.result_df["pvalue1"])
self.estimate_interval = np.array((self.result_df[["lowerCI1", "upperCI1"]]))
ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=table_new)
self.modelTree = modelTree
    def __add_nodes_edges(self, tree, dot=None):
        if dot is None:
            dot = Digraph()
            dot.node_attr.update(width="1", fontsize="9")
node_id = int(tree["node_id"])
node_id = f"Node : {node_id}"
ate = f'CATE : {tree["tau_i"]} (p_value:{tree["pvalues"][0]})'
sample = f'sample size: control {tree["controlCount"]}; treatment {tree["treatedCount"]}'
sample_ratio = f'sample_ratio:{tree["count_ratio"]}%'
mean = (
f'mean: control {tree["controlAvg"]}; treatment {tree["treatedAvg"]}'
)
dot.node(
str(tree["node_id"]),
"\n".join([node_id, sample_ratio, ate, sample, mean]),
)
for child in tree["children"]:
if child:
node_id = int(child["node_id"])
node_id = f"Node : {node_id}"
ate = f'CATE : {child["tau_i"]} (p_value:{child["pvalues"][0]})'
sample = f'sample size: control {child["controlCount"]}; treatment {child["treatedCount"]}'
sample_ratio = f'sample_ratio:{child["count_ratio"]}%'
mean = f'mean: control {child["controlAvg"]}; treatment {child["treatedAvg"]}'
split_criterion = f'split_criterion: {child["father_split_feature"]} in {child["father_split_feature_Categories"]}'
dot.node(
str(child["node_id"]),
"\n".join(
[
node_id,
auto_wrap_text(split_criterion, 60),
sample_ratio,
ate,
sample,
mean,
]
),
)
dot.edge(str(tree["node_id"]), str(child["node_id"]))
self.__add_nodes_edges(child, dot)
return dot
def treeplot(self):
tree_structure = self.tree_structure.get_dict()
dot = self.__add_nodes_edges(tree_structure)
return dot
def hte_plot(self):
# toc curve
result_df = self.result_df
toc_data_df = result_df[result_df["isLeaf"] == True][
["prediction", "treatedCount", "controlCount", "treatedAvg", "controlAvg"]
].sort_values(by="prediction", ascending=False)
toc_data_df.reset_index(drop=True, inplace=True)
toc_data_df["cnt"] = toc_data_df["treatedCount"] + toc_data_df["controlCount"]
toc_data_df["treatedsum"] = (
toc_data_df["treatedCount"] * toc_data_df["treatedAvg"]
)
toc_data_df["controlsum"] = (
toc_data_df["controlCount"] * toc_data_df["controlAvg"]
)
cnt = toc_data_df["cnt"].sum()
toc_data_df["treatedcumsum"] = toc_data_df["treatedsum"].cumsum()
toc_data_df["controlcumsum"] = toc_data_df["controlsum"].cumsum()
toc_data_df["treatedcumcnt"] = toc_data_df["treatedCount"].cumsum()
toc_data_df["controlcumcnt"] = toc_data_df["controlCount"].cumsum()
toc_data_df["ratio"] = toc_data_df["cnt"] / cnt
toc_data_df["ratio_sum"] = toc_data_df["ratio"].cumsum()
toc_data_df["toc"] = (
toc_data_df["treatedcumsum"] / toc_data_df["treatedcumcnt"]
- toc_data_df["controlcumsum"] / toc_data_df["controlcumcnt"]
)
toc_data_df["qini"] = (
toc_data_df["treatedcumsum"]
- toc_data_df["controlcumsum"]
/ toc_data_df["controlcumcnt"]
* toc_data_df["treatedcumcnt"]
)
toc_data_df["order"] = toc_data_df.index + 1
toc_data_df["x"] = toc_data_df["ratio_sum"]
ate = list(toc_data_df["toc"])[-1]
toc_data_df["toc1"] = ate
toc_data_df["qini1"] = toc_data_df["x"] * list(toc_data_df["qini"])[-1]
toc_data = toc_data_df[["x", "toc", "toc1", "qini", "qini1"]]
toc_data.columns = [
"ratio",
"toc_tree",
"toc_random",
"qini_tree",
"qini_random",
]
toc_data = toc_data.sort_values(by=["ratio"])
result = toc_data
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(12, 4.8))
ax1.plot(result["ratio"], result["toc_tree"], label="CausalTree")
ax2.plot(
[0] + list(result["ratio"]),
[0] + list(result["qini_tree"]),
label="CausalTree",
)
ax1.plot(result["ratio"], result["toc_random"], label="Random Model")
ax2.plot(
[0] + list(result["ratio"]),
[0] + list(result["qini_random"]),
label="Random Model",
)
ax1.set_title("Cumulative Lift Curve")
ax1.legend()
ax2.set_title("Cumulative Gain Curve")
ax2.legend()
fig.suptitle("CausalTree Lift and Gain Curves")
plt.show()
    def effect(self, df, keep_col="*"):
"""
Calculate the individual treatment effect.
Parameters
----------
df : DataFrame
The input data.
keep_col : str, optional
The columns to keep. Defaults to '*'.
Returns
-------
DataFrame
The result data.
"""
table_input = df_2_table(df)
cutbinstring = self.__cutbinstring
table_tmp = f"{table_input}_{int(time.time())}_foreffect_2_clickhouse"
table_tmp1 = table_tmp + "1"
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_tmp,
sql_statement=f"""
*,{cutbinstring}1 as index
""",
sql_table_name=table_input,
primary_column="index",
is_force_materialize=True,
)
leaf_effect = np.array(
self.result_df[self.result_df["isLeaf"] == True][
["whereCond", "prediction"]
]
)
string = " ".join([f"when True {x[0]} then {x[1]} " for x in leaf_effect])
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_tmp1,
sql_statement=f"""
{keep_col},
case
{string}
else 4294967295
end as effect
""",
sql_table_name=table_tmp,
primary_column="effect",
is_force_materialize=True,
)
ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=table_tmp)
return readClickHouse(table_tmp1)
def save_model(model, file):
    # Serialize the model instance and save it to a local file.
    # `file` should be a '.pkl' path, e.g. save_model(model, 'file_name.pkl').
    with open(file, "wb") as f:
        pickle.dump(model, f)
    print(f"model{model} is stored at '{file}'")
def load_model(file):
    # Load and deserialize a model instance from a local file.
with open(file, "rb") as f:
model = pickle.load(f)
return model
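# Usage sketch (assumes a CausalTree instance `hte` fitted as in the class
# docstring above):
#   save_model(hte, "causal_tree.pkl")
#   hte = load_model("causal_tree.pkl")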
def save_model2ch(hte):
cutbin_string = hte._CausalTree__cutbinstring[:-1]
leaf_effect = np.array(
hte.result_df[hte.result_df["isLeaf"] == True][["whereCond", "prediction"]]
)
effect_string = " ".join([f"when True {x[0]} then {x[1]} " for x in leaf_effect])
model_table = f"tmp_{int(time.time())}"
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=model_table,
sql_statement=f"""
'{cutbin_string}' as cutbin_string, '{effect_string}' as effect_string
""",
sql_table_name=hte.table,
sql_limit=1,
is_force_materialize=True,
)
return readClickHouse(model_table)
class CausalForest:
"""
This class implements the Causal Forest method for causal inference.
Parameters
----------
depth : int, default=7
The maximum depth of the tree.
min_node_size : int, default=-1
The minimum node size.
mtry : int, default=3
The number of variables randomly sampled as candidates at each split.
num_trees : int, default=10
The number of trees to grow in the forest.
sample_fraction : float, default=0.7
The fraction of observations to consider when fitting the forest.
weight_index : str, default=''
The weight index.
honesty : bool, default=False
Whether to use honesty when fitting the forest.
honesty_fraction : float, default=0.5
The fraction of observations to use for determining splits if honesty is used.
quantile_num : int, default=50
The number of quantiles.
Methods
-------
fit(Y, T, X, df):
Fit the Causal Forest model to the input data.
effect(input_df=None, X=[]):
Estimate the causal effect using the fitted model.
Example
-------
.. code-block:: python
import fast_causal_inference
from fast_causal_inference.dataframe.uplift import *
Y='y'
T='treatment'
table = 'test_data_small'
X = ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
df = fast_causal_inference.readClickHouse(table)
df_train, df_test = df.split(0.5)
from fast_causal_inference.dataframe.uplift import CausalForest
model = CausalForest(depth=7, min_node_size=-1, mtry=3, num_trees=10, sample_fraction=0.7)
model.fit(Y, T, X, df_train)
"""
def __init__(
self,
depth=7,
min_node_size=-1,
mtry=3,
num_trees=10,
sample_fraction=0.7,
weight_index="",
honesty=False,
honesty_fraction=0.5,
quantile_num=50,
):
self.depth = depth
self.min_node_size = min_node_size
self.mtry = mtry
self.num_trees = num_trees
self.sample_fraction = sample_fraction
        self.weight_index = weight_index
        self.honesty = 1 if honesty else 0
        self.honesty_fraction = honesty_fraction
        # Clamp the quantile count to [1, 100].
        self.quantile_num = max(1, min(100, quantile_num))
    def fit(self, Y, T, X, df):
"""
Fit the Causal Forest model to the input data.
Parameters
----------
Y : str
The outcome variable.
T : str
The treatment variable.
X : list
The numeric covariates. Strings are not supported. ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
df : DataFrame
The input dataframe.
Returns
-------
None
"""
self.table = df.getTableName()
self.origin_table = df.getTableName()
self.Y = Y
self.T = T
self.X = X
self.len_x = len(X)
        self.ts = int(time.time() * 1000)
# create model df.getTableName()
self.model_table = "model_" + self.table + str(self.ts)
sql_instance = SqlGateWayConn.create_default_conn()
count = sql_instance.sql("select count() as cnt from " + self.table)
if isinstance(count, str):
print(count)
return
self.table_count = count["cnt"][0]
self.mtry = min(30, self.mtry)
self.num_trees = min(200, self.num_trees)
calc_min_node_size = int(max(int(self.table_count) / 128, 1))
self.min_node_size = max(self.min_node_size, calc_min_node_size)
# insert into {self.model_table}
self.config = f""" select '{{"max_centroids":1024,"max_unmerged":2048,"honesty":{self.honesty},"honesty_fraction":{self.honesty_fraction}, "quantile_size":{self.quantile_num}, "weight_index":2, "outcome_index":0, "treatment_index":1, "min_node_size":{self.min_node_size}, "sample_fraction":{self.sample_fraction}, "mtry":{self.mtry}, "num_trees":{self.num_trees}}}' as model, {self.ts} as ver"""
        ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=self.model_table)
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=self.model_table,
sql_statement=self.config,
is_sql_complete=True,
sql_table_name=self.table,
primary_column="ver",
is_force_materialize=False,
)
sql_instance.sql(self.config)
if self.weight_index == "":
self.weight_index = "1 / " + str(self.table_count)
self.xs = ",".join(X)
self.init_sql = f"""
insert into {self.model_table} (model, ver)
WITH
(select max(ver) from {self.model_table}) as ver0,
(
SELECT model
FROM {self.model_table}
WHERE ver = ver0 limit 1
) AS model
SELECT CausalForest(model)({Y}, {T}, {self.weight_index}, {self.xs}), ver0 + 1
FROM {self.table}
"""
res = sql_instance.sql(self.init_sql)
self.train_sql = f"""
insert into {self.model_table} (model, ver)
WITH
(select max(ver) from {self.model_table}) as ver0,
(
SELECT model
FROM {self.model_table}
WHERE ver = ver0 limit 1
) AS model,
(
SELECT CausalForest(model)({Y}, {T}, {self.weight_index}, {self.xs})
FROM {self.table}
) as calcnumerdenom,
(
SELECT CausalForest(calcnumerdenom)({Y}, {T}, {self.weight_index}, {self.xs})
FROM {self.table}
) as split_pre
SELECT CausalForest(split_pre)({Y}, {T}, {self.weight_index}, {self.xs}), ver0 + 1
FROM {self.table}
"""
        for i in range(self.depth):
            res = sql_instance.execute(self.train_sql)
            print("depth " + str(i + 1) + " train over")
            if isinstance(res, str) and res.find("train over") != -1:
                print("---------- training finished ----------")
                break
    def effect(self, df=None, X=[]):
"""
Estimate the causal effect using the fitted model.
Parameters
----------
df : DataFrame, default=None
The input dataframe for which to estimate the causal effect. If None, use the dataframe from the fit method.
X : list, default=[]
The covariates to use when estimating the causal effect. ['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2']
Returns
-------
DataFrame
The output dataframe with the estimated causal effect.
Example
-------
.. code-block:: python
df_test_effect_cf = model.effect(df=df_test, X=['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2'])
df_train_effect_cf = model.effect(df=df_train, X=['x1', 'x2', 'x3', 'x4', 'x5', 'x_long_tail1', 'x_long_tail2'])
lift_train = get_lift_gain("effect", Y, T, df_test_effect_cf,discrete_treatment=True, K=100)
lift_test = get_lift_gain("effect", Y, T, df_train_effect_cf,discrete_treatment=True, K=100)
print(lift_train,lift_test)
hte_plot([lift_train,lift_test],labels=['train','test'])
"""
        if X != []:
            len_x = len(X)
            if len_x != self.len_x:
                print("The number of x is not equal to the number of x in the model")
                return
            self.xs = ",".join(X)
        if df is not None:
self.effect_table = df.getTableName()
else:
self.effect_table = self.table
self.output_table = DataFrame.createTableName()
self.predict_sql = f"""
WITH
(
SELECT max(ver)
FROM {self.model_table}
) AS ver0,
(
SELECT model
FROM {self.model_table}
WHERE ver = ver0 limit 1
) AS pure,
(SELECT CausalForestPredict(pure)({self.Y}, 0, {self.weight_index}, {self.xs}) FROM {self.origin_table}) as model,
(SELECT CausalForestPredictState(model)(number) FROM numbers(0)) as predict_model
select *, evalMLMethod(predict_model, 0, {self.weight_index}, {self.xs}) as effect
FROM {self.effect_table}
"""
ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=self.output_table)
ClickHouseUtils.clickhouse_drop_view(
clickhouse_view_name=self.output_table + "_local"
)
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=self.output_table,
sql_statement=self.predict_sql,
is_sql_complete=True,
sql_table_name=self.model_table,
primary_column="effect",
is_use_local=False,
)
print("succ")
return readClickHouse(self.output_table)
class LinearDML:
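    """
    Double Machine Learning with a linear final stage, executed in ClickHouse
    via the linearDML() SQL function.  model_y / model_t name the nuisance
    models (e.g. 'ols'), cv is the number of cross-fitting folds, and an
    optional treatment_featurizer supplies (featurizer, effect placeholder,
    marginal-effect placeholder) expressions.
    """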
def __init__(
self,
model_y="ols",
model_t="ols",
fit_cate_intercept=True,
discrete_treatment=True,
categories=[0, 1],
cv=3,
treatment_featurizer="",
):
self.model_y = model_y
self.model_t = model_t
self.treatment_featurizer = ""
self.effect_ph = ""
self.marginal_effect_ph = ""
self.cv = cv
if treatment_featurizer != "":
self.treatment_featurizer = treatment_featurizer[0]
self.effect_ph = treatment_featurizer[1]
self.marginal_effect_ph = treatment_featurizer[2]
self.sql_instance = SqlGateWayConn.create_default_conn()
def fit(self, df, Y, T, X, W=""):
self.table = df.getTableName()
self.Y = Y
self.T = T
self.X = X
self.dml_sql = self.get_dml_sql(
self.table,
Y,
T,
X,
W,
self.model_y,
self.model_t,
self.cv,
self.treatment_featurizer,
)
self.forward_sql = self.sql_instance.sql(
sql=self.dml_sql, is_calcite_parse=True
)
self.result = self.sql_instance.sql(self.dml_sql)
if isinstance(self.result, str) and self.result.find("error") != -1:
self.success = False
self.ols = self.result
return
self.ols = Ols(self.result["final_model"][0])
        if isinstance(self.ols, Ols):
self.success = True
else:
self.success = False
def get_dml_sql(
self, table, Y, T, X, W, model_y, model_t, cv, treatment_featurizer
):
sql = "select linearDML("
sql += Y + "," + T + "," + X + ","
if W.strip() != "":
sql += W + ","
sql += (
"model_y='"
+ model_y
+ "',"
+ "model_t='"
+ model_t
+ "',"
+ "cv="
+ str(cv)
)
sql += " ) from " + table
sql = sql.replace("ols", "Ols")
return sql
def __str__(self):
return str(self.ols)
def summary(self):
if not self.success:
return str(self.ols)
return self.ols.get_dml_summary()
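    # Rewrite the forwarded DML SQL so that the final-stage Ols aggregate is
    # stored as a reusable model state (OlsState, or OlsIntervalState when
    # confidence intervals are requested) that evalMLMethod can later consume.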
def exchange_dml_sql(self, sql, use_interval=False):
pos = sql.find("final_model")
if pos == -1:
raise Exception("Logical Error: final_model not found in sql")
sql = sql[0 : pos + len("final_model")]
pos = sql.rfind("Ols")
if pos == -1:
raise Exception("Logical Error: Ols not found in sql")
if use_interval:
sql = sql[0:pos] + "OlsIntervalState" + sql[pos + len("Ols") :]
else:
sql = sql[0:pos] + "OlsState" + sql[pos + len("Ols") :]
return sql
def effect(self, df, X="", T0=0, T1=1):
if not self.success:
return str(self.ols)
table_predict = df.getTableName()
if table_predict == "":
table_predict = self.table
table_output = DataFrame.createTableName()
if not X:
X = self.X
sql = self.exchange_dml_sql(self.forward_sql) + "\n"
X = X.replace("+", ",")
# Warn when the covariate count exceeds the backend limit of 30.
if X.count(",") >= 30:
print("The number of x exceeds the limit 30")
sql += (
"select "
+ X
+ ", evalMLMethod(final_model, "
+ X
+ ", "
+ str(T1 - T0)
+ ") as predict from "
+ table_predict
)
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_output,
sql_statement=sql,
primary_column="predict",
is_force_materialize=True,
is_sql_complete=True,
sql_table_name=self.table,
is_use_local=True,
)
return readClickHouse(table_output)
def ate(self, X="", T0=0, T1=1):
if not self.success:
return str(self.ols)
if not X:
X = self.X
sql = self.exchange_dml_sql(self.forward_sql) + "\n"
X = X.replace("+", ",")
sql += (
"select avg(evalMLMethod(final_model, "
+ X
+ ", "
+ str(T1 - T0)
+ ")) from "
+ self.table
)
sql += " limit 100"
t = self.sql_instance.sql(sql)
return t
def effect_interval(self, df, X="", T0=0, T1=1, alpha=0.05):
table_output = DataFrame.createTableName()
if not self.success:
return str(self.ols)
if not X:
X = self.X
sql = self.exchange_dml_sql(self.forward_sql, use_interval=True) + "\n"
X = X.replace("+", ",")
X1 = X.split(",")
X1 = [x + " as " + x for x in X1]
sql += (
"select "
+ ", ".join(X1)
+ ", evalMLMethod(final_model,'confidence',"
+ str(1 - alpha)
+ ", "
+ X
+ ", "
+ str(T1 - T0)
+ ") as predict from "
+ df.getTableName()
)
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_output,
sql_statement=sql,
primary_column="predict",
is_force_materialize=True,
is_sql_complete=True,
sql_table_name=self.table,
is_use_local=True,
)
return readClickHouse(table_output)
def ate_interval(self, X="", T0=0, T1=1, alpha=0.05):
if not self.success:
return str(self.ols)
if not X:
X = self.X
sql = self.exchange_dml_sql(self.forward_sql, use_interval=True) + "\n"
X = X.replace("+", ",")
Xs = X.split(",")
for i in range(len(Xs)):
Xs[i] = "avg(" + Xs[i] + ")"
X = ",".join(Xs)
sql += (
"select evalMLMethod(final_model,'confidence',"
+ str(1 - alpha)
+ ", "
+ X
+ ", "
+ str(T1 - T0)
+ ") from "
+ self.table
)
sql += " limit 100"
t = self.sql_instance.sql(sql)
if str(t).find("DB::Exception") != -1:
return t
s = "mean_point\tci_mean_lower\tci_mean_upper\t\n"
t = t.iloc[0, 0]
t = t[1:-1]
t = t.split(",")
for i in range(len(t)):
s += str(round(float(t[i]), 10)) + "\t"
return s
def get_sql(self, X):
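"""
Build the three marginal-effect queries (per-feature constant marginal
effects, their row-level sum, and its average). Each `@PH<i>` placeholder
from the treatment featurizer is set to 1 in one evalMLMethod call and 0
in the others; the differences against the all-zero baseline give the
marginal effect of each treatment feature.
"""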
sql = self.exchange_dml_sql(self.forward_sql) + "\n"
x_with_space = X.split("+")
x_with_space.append("1")
model_arguments = ""
for x in x_with_space:
for y in self.treatment_featurizer:
model_arguments += x + "*" + y + ","
model_arguments = "(False)(" + self.Y + "," + model_arguments[0:-1] + ")"
last_olsstate_index = sql.rfind("OlsState")
last_from_index = sql.rfind("FROM")
sql = (
sql[: last_olsstate_index + len("OlsState")]
+ model_arguments
+ " "
+ sql[last_from_index:]
)
tmp_eval = "evalMLMethod(final_model"
for x in X.split("+"):
for y in self.marginal_effect_ph:
tmp_eval += "," + str(x) + "*" + str(y)
for y in self.marginal_effect_ph:
tmp_eval += "," + str(y)
tmp_eval += ")"
evals = []
for i in range(0, len(self.marginal_effect_ph) + 1):
evals.append(tmp_eval)
for i in range(1, len(self.marginal_effect_ph) + 1):
for j in range(0, len(evals)):
if i != j:
evals[j] = evals[j].replace("@PH" + str(i), "0")
else:
evals[j] = evals[j].replace("@PH" + str(i), "1")
sql_const = sql + " select "
sql_effect = sql_const
sql_ate = sql_const + " avg( "
for i in range(1, len(self.marginal_effect_ph) + 1):
sql_const += evals[i] + " - " + evals[0] + " as predict" + str(i) + ","
sql_effect += evals[i] + " - " + evals[0] + " +"
sql_ate += evals[i] + " - " + evals[0] + " +"
sql_const = sql_const[:-1] + " from " + self.table
sql_effect = sql_effect[:-1] + " as predict from " + self.table
sql_ate = sql_ate[:-1] + ") as predict from " + self.table
return [sql_const, sql_effect, sql_ate]
def const_marginal_effect(self, X="", table_output=""):
if not self.success:
return str(self.ols)
if not X:
X = self.X
if self.marginal_effect_ph == "":
return "Error: treatment featurizer is empty!"
sql = self.get_sql(X)[0]
if table_output == "":
sql += " limit 100"
if table_output != "":
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_output,
sql_statement=sql,
primary_column="predict1",
is_force_materialize=True,
is_sql_complete=True,
sql_table_name=self.table,
)
return
t = self.sql_instance.sql(sql)
return t
def marginal_effect(self, X="", table_output=""):
if not self.success:
return str(self.ols)
if not X:
X = self.X
if not self.marginal_effect_ph:
return "Error: treatment featurizer is empty!"
sql = self.get_sql(X)[1]
if table_output == "":
sql += " limit 100"
if table_output != "":
if sql.count("+") >= 30 or sql.count(",") >= 30:
print("The number of x exceeds the limit 40")
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_output,
sql_statement=sql,
primary_column="predict",
is_force_materialize=True,
is_sql_complete=True,
sql_table_name=self.table,
)
# Mirror const_marginal_effect and marginal_ate: stop after materializing.
return
t = self.sql_instance.sql(sql)
return t
def marginal_ate(self, X="", table_output=""):
if not self.success:
return str(self.ols)
if not X:
X = self.X
if not self.marginal_effect_ph:
return "Error: treatment featurizer is empty!"
sql = self.get_sql(X)[2]
if not table_output:
sql += " limit 100"
if table_output:
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_output,
sql_statement=sql,
primary_column="predict",
is_force_materialize=True,
is_sql_complete=True,
sql_table_name=self.table,
)
return
t = self.sql_instance.sql(sql)
return t
def load_chmodel_predict(df_model, df, keep_col="*"):
table_input = df_2_table(df)
model_table = df_2_table(df_model)
sql_instance = SqlGateWayConn.create_default_conn()
tmp = sql_instance.sql(
f"select cutbin_string,effect_string from {model_table} limit 1"
)
cutbinstring = tmp["cutbin_string"][0]
effect_string = tmp["effect_string"][0]
table_tmp = f"{table_input}_{int(time.time())}_foreffect_2_clickhouse"
table_output = f"tmp_{int(time.time())}"
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_tmp,
sql_statement=f"""
*,{cutbinstring},1 as index
""",
sql_table_name=table_input,
is_force_materialize=True,
)
ClickHouseUtils.clickhouse_create_view(
clickhouse_view_name=table_output,
sql_statement=f"""
{keep_col},
case
{effect_string}
else 4294967295
end as effect
""",
sql_table_name=table_tmp,
is_force_materialize=True,
)
ClickHouseUtils.clickhouse_drop_view(clickhouse_view_name=table_tmp)
return table_2_df(table_output)