__all__ = [
"readClickHouse",
"readStarRocks",
"readTdw",
"readSparkDf",
"readCsv",
"DataFrame",
]
import random
from typing import Tuple
import pandas
import requests
import time
from fast_causal_inference.dataframe.format import outPutFormat
from fast_causal_inference.dataframe import ais_dataframe_pb2 as DfPb
import re
import base64
from google.protobuf import json_format
from google.protobuf.json_format import Parse
import copy
from fast_causal_inference.dataframe.df_base import (
DfColumnNode,
OlapEngineType,
DfContext,
DfColumnInternalNode,
DfColumnLeafNode,
)
from fast_causal_inference.dataframe.functions import DfFnColWrapper
from fast_causal_inference.dataframe import functions as AisF
from fast_causal_inference.util.data_transformer import (
clickhouse_2_csv,
clickhouse_2_dataframe,
dataframe_2_clickhouse,
dataframe_2_starrocks,
csv_2_clickhouse,
csv_2_starrocks,
starrocks_2_dataframe,
)
from fast_causal_inference.dataframe import regression as AisRegressionF
from fast_causal_inference.dataframe import statistics as AisStatisticsF
from fast_causal_inference.util import get_user, ClickHouseUtils
from fast_causal_inference.common import get_context
import fast_causal_inference.lib.tools as AisTools
"""
syntax = "proto3";
enum ColumnType {
Unknown = 0;
String = 1;
Int = 2;
Float = 3;
Bool = 4;
Date = 5;
DateTime = 6;
Time = 7;
UUID = 8;
Array = 9;
}
enum TaskType {
TASK_TYPE_FILL_SCHEMA = 0;
}
message Column {
string name = 1;
string alias = 2;
ColumnType type = 3;
}
message Limit {
int64 limit = 1;
int64 offset = 2;
}
enum SourceType {
ClickHouse = 0;
}
message ClickHouseSource {
string table_name = 1;
string database = 2;
}
message Source {
SourceType type = 1;
ClickHouseSource clickhouse = 2;
}
message Order {
Column column = 1;
bool desc = 2;
}
message DataFrame {
repeated Column columns = 1;
repeated string filters = 2;
repeated Column group_by = 3;
repeated Order order_by = 4;
Limit limit = 5;
Source source = 6;
};
message DataFrameRequest {
DataFrame df = 1;
TaskType task_type = 2;
string rtx = 3;
int64 device_id = 4;
}
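// Illustrative sketch only (hand-written, not produced by the service): a query
// such as `SELECT x1 AS a FROM db.tbl WHERE x1 > 0 ORDER BY x1 LIMIT 10` would
// roughly populate the DataFrame message as
//   columns:  [{name: "x1", alias: "a", type: Float}]
//   filters:  ["x1 > 0"]
//   order_by: [{column: {name: "x1"}, desc: false}]
//   limit:    {limit: 10}
//   source:   {type: ClickHouse, clickhouse: {table_name: "tbl", database: "db"}}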
"""
def getSuperName(column):
if isinstance(column, DfPb.Column):
return column.alias if column.alias else column.name
else:
raise Exception("type error")
class DataFrame:
"""
This class is used to create a DataFrame object.
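
    A minimal usage sketch (the table name is illustrative and assumed to exist
    in the configured ClickHouse database):

    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readClickHouse('test_data_small')
    >>> df.select('x1', 'x2').where('x1 > 0').head(5).show()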
"""
def __init__(self, olap_engine=OlapEngineType.CLICKHOUSE):
datasource = dict()
PROJECT_CONF = get_context().project_conf
for cell in PROJECT_CONF["datasource"]:
datasource[cell["device_id"]] = cell
self.device_id = None
for device in datasource:
if datasource[device].get(str(olap_engine) + "_database") is not None:
self.device_id = device
self.database = datasource[device][str(olap_engine) + "_database"]
break
if self.device_id is None:
raise Exception(f"Unable to get any device of engine({olap_engine}).")
self.data = DfPb.DataFrame()
if olap_engine == OlapEngineType.STARROCKS:
            self.data.source.type = DfPb.SourceType.StarRocks
self.url = (
PROJECT_CONF["sqlgateway"]["url"]
+ PROJECT_CONF["sqlgateway"]["dataframe_path"]
)
self.rtx = get_user()
self._ctx = DfContext(engine=olap_engine)
self._select_list = []
self._name_dict = {}
self._need_agg = False
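    # Serialization helpers: the DataFrame state lives in a protobuf message
    # (see the proto sketch in the module-level string above) and is shipped to
    # the SQL gateway either as base64-encoded bytes or as JSON.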
def serialize(self):
return self.data.SerializeToString()
def serializeBase64(self):
return base64.b64encode(self.serialize()).decode()
def serializeJson(self):
return json_format.MessageToJson(self.data)
def deserialize(self, data):
self.data.ParseFromString(data)
    def toPandas(self):
"""
This function is used to convert the result of the dataframe to pandas.DataFrame
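
        A short sketch (assumes `df` was created via readClickHouse):

        >>> pdf = df.toPandas()
        >>> print(pdf)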
"""
self._finalize()
self.data.result = ""
self.task_type = DfPb.TaskType.EXECUTE
self.execute()
res = outPutFormat(list(eval(self.data.result)))
return res
def _is_column(self, column, name):
if name is None:
return False
if isinstance(column, DfColumnNode):
return column.alias == name or column.sql(ctx=self._ctx) == name
raise Exception(f"type({type(column)}) is not DfColumn")
def _get_col_name(self, column):
if isinstance(column, DfColumnLeafNode):
return column.alias if column.alias else column.sql(ctx=self._ctx)
if isinstance(column, DfColumnInternalNode):
return column.alias
raise Exception(
f"type({type(column)}) is not DfColumnLeafNode|DfColumnInternalNode"
)
    def printSchema(self):
"""
This function is used to print the schema of the dataframe
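
        A short usage sketch; the output is a two-column table of column name and type:

        >>> df.printSchema()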
"""
pd = pandas.DataFrame(columns=["name", "type"])
for column in self._select_list:
super_name = self._get_col_name(column)
column_type = column.type
            # if the column type is empty, default to Float64
            if column_type is None or column_type == "":
                column_type = "Float64"
            # TODO: if the type is empty, it should be filled in beforehand
            pd.loc[len(pd)] = [super_name, column_type]
print(pd)
def __getitem__(self, key):
for column in self._select_list:
if column.alias == key or column.sql(self._ctx) == key:
return AisF.col(column)
raise Exception("column %s not found" % key)
def debug(self):
self._finalize()
print(self.data.__str__())
all_names = dict()
for name in self._name_dict:
all_names[name] = self._name_dict[name].sql(self._ctx)
print(all_names)
def _finalize(self):
del self.data.columns[:]
for col in self.data.group_by:
self.data.columns.append(col)
for col in self._select_list:
self.data.columns.append(
DfPb.Column(name=col.sql(ctx=self._ctx), alias=col.alias)
)
def __str__(self):
self._finalize()
res = self.toPandas()
if res.shape[0] == 1 and res.shape[1] == 1:
            return str(res.values[0][0])
return res.__str__()
def __repr__(self):
return self.__str__()
    def first(self):
"""
This function is used to get the first row of the dataframe
>>> df.first()
"""
new_df = copy.deepcopy(self)
new_df.data.limit.limit = 1
return new_df
    def head(self, n):
"""
This function is used to get the first n rows of the dataframe
>>> df.head(3)
"""
new_df = copy.deepcopy(self)
new_df.data.limit.limit = n
return new_df
    def take(self, n):
"""
This function is used to get the first n rows of the dataframe
        >>> df.take(3)
"""
new_df = copy.deepcopy(self)
new_df.data.limit.limit = n
return new_df
    def show(self):
"""
Prints the DataFrame, equivalent to print(dataframe).
>>> df.head(3).show()
"""
print(self.__str__())
    # Add validation logic
    # Aliases are used here and need to be mapped back to the underlying names;
    # the same applies to groupBy and orderBy below.
    def where(self, filter):
"""
Filters rows using the given condition.
>>> df.where("column1 > 1").show()
"""
new_df = copy.deepcopy(self)
if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", filter) != None:
check_exist = False
for column in new_df._select_list:
super_name = self._get_col_name(column)
if super_name in filter:
check_exist = True
break
if not check_exist:
raise Exception("the column of filter is not exist" % filter)
new_df.data.filters.append(filter)
return new_df
    def filter(self, filter):
"""
Alias for the 'where' function. Filters rows using the given condition.
>>> df.filter("column1 > 1").show()
"""
return self.where(filter)
@staticmethod
def _expand_args(*args):
expand_args = []
for arg in args:
if isinstance(arg, str):
expand_args.append(arg)
elif isinstance(arg, list):
for sub_arg in arg:
expand_args.append(sub_arg)
return expand_args
    def select(self, *args):
"""
Selects specified columns from the DataFrame and returns a new DataFrame.
>>> new_df = df.select('column1', 'column2')
>>> new_df = df.select(['column1', 'column2'])
"""
args = DataFrame._expand_args(*args)
new_df = copy.deepcopy(self)
new_df.checkColumn(*args)
new_select_list = []
for arg in args:
new_select_list.append(new_df._expand_expr(arg))
new_df._select_list = new_select_list
return new_df
    def drop(self, *args):
"""
Drops specified columns from the DataFrame and returns a new DataFrame.
>>> new_df = df.drop('column1', 'column2')
>>> new_df = df.drop(['column1', 'column2'])
"""
args = DataFrame._expand_args(*args)
new_df = copy.deepcopy(self)
new_df.checkColumn(*args)
for i in range(len(new_df._select_list) - 1, -1, -1):
name = self._get_col_name(new_df._select_list[i])
is_exist = False
for arg in args:
if name == arg:
is_exist = True
break
if is_exist:
del new_df._select_list[i]
return new_df
    def withColumn(self, new_column, func):
"""
This function adds a new column to the DataFrame.
Example
--------
.. code-block:: python
import fast_causal_inference.dataframe.functions as Fn
import fast_causal_inference.dataframe.statistics as S
df1 = df.select('x1','x2','x3', 'numerator')
# Method 1: Select columns through df's index: df['col']
df1 = df1.withColumn('new_col', Fn.sqrt(df1['numerator']))
df1.show()
# Method 2: Select columns directly through string: Fn.col('new_col')
df1 = df1.withColumn('new_col2', Fn.pow('new_col', 2))
df1.show()
df1 = df1.withColumn('new_col3', Fn.col('new_col') * Fn.col('new_col'))
df1.show()
# Add constant
df2 = df1.withColumn('c1', Fn.lit(1))
df2 = df2.withColumn('c2', Fn.lit('1'))
df2.show()
# Nesting
df2 = df1.withColumn('c1', Fn.pow(Fn.sqrt(Fn.sqrt(Fn.col('x1'))), 2))
df2.show()
# +-*/% operations
df2 = df1.withColumn('c1', 22 + df1['x1'] / 2 + 2 / df1['x2'] * df1['x3'] % 2 - (df1['x2']))
df2.show()
df2 = df1.withColumn('c1', Fn.col('x1 + x2 * x3 + x3'))
df2.show()
# if
df2 = df1.withColumn('cc1', 'if(x1 > 0, 1, -1)')
df2.show()
df2 = df1.withColumn('cc1', Fn.If('x1 > 0',1,-1))
df2.show()
"""
new_df = copy.deepcopy(self)
if isinstance(func, str):
func = new_df._add_space(func)
for alias in new_df._name_dict:
if alias is None or alias == "":
continue
func = func.replace(
" " + alias + " ",
" " + new_df._name_dict[alias].sql(new_df._ctx) + " ",
)
func = func.replace(" ", "")
new_df._select_list.append(DfColumnLeafNode(func, new_column))
new_df._name_dict[new_column] = new_df._select_list[-1]
new_df._name_dict[
new_df._select_list[-1].sql(new_df._ctx)
] = new_df._select_list[-1]
elif isinstance(func, DfFnColWrapper):
new_df = new_df._apply_func(func.alias(new_column), keep_old_cols=True)
return new_df
@classmethod
def _add_space(cls, text):
for ch in [",", "+", "-", "*", "/", ")", "(", "=", ">", "<", "!"]:
text = text.replace(ch, " " + ch + " ")
return " " + text + " "
def _expand_expr(self, expr):
if not isinstance(expr, str):
raise Exception(f"Logical Error: expr(`{expr}`) is expected to be str.")
expr = self._add_space(expr)
origin_expr = expr
for alias in self._name_dict:
if alias is None or alias == "":
continue
expr = expr.replace(
" " + alias + " ", " " + self._name_dict[alias].sql(self._ctx) + " "
)
expr = expr.replace(" ", "")
origin_expr = origin_expr.replace(" ", "")
alias = origin_expr if origin_expr != expr else None
self._name_dict[origin_expr] = DfColumnLeafNode(expr, alias=alias)
self._name_dict[expr] = DfColumnLeafNode(expr, alias=alias)
return DfColumnLeafNode(expr, alias=alias)
    def withColumnRenamed(self, old_name, new_name):
"""
Returns a new DataFrame by renaming an existing column.
>>> df.withColumnRenamed("column1", "new_column1").show()
"""
new_df = copy.deepcopy(self)
new_df.checkColumn(old_name)
for column in new_df._select_list:
if column.alias == old_name or self._get_col_name(column) == old_name:
column.alias = new_name
return new_df
    def describe(self, cols="*"):
"""
Returns the summary statistics for the columns in the DataFrame.
>>> df.describe()
>>> df.describe(['x1','x2'])
# count avg std min quantile_0.25 quantile_0.5 quantile_0.75 quantile_0.90 quantile_0.99 max
# x1 10000.0 -0.018434 0.987606 -3.740101 -0.68929 -0.028665 0.654210 1.274144 2.321097 3.80166
# x2 10000.0 0.021976 1.986209 -8.893264 -1.28461 0.015400 1.357618 2.583523 4.725829 7.19662
"""
self = self.materializedView()
table_name = self.getTableName()
return AisTools.describe(table_name, cols)
    def sample(self, fraction):
"""
This function samples a fraction of rows without replacement from the DataFrame.
>>> df1 = df.sample(1000)
>>> df1.count().show()
>>> df2 = df.sample(0.5)
>>> df2.count().show()
"""
new_df = self.materializedView(is_temp=True, is_physical_table=True)
row_num = int(new_df.count().__str__())
if fraction <= 1 and fraction >= 0:
pass
elif fraction > 1:
if fraction > row_num:
raise Exception("fraction should be less than row number")
fraction = fraction / row_num
else:
raise Exception("fraction should be in [0, 1] or > 1")
        temp_df = new_df.where(f"rand() / pow(2, 32) < {fraction}")
return temp_df.materializedView(is_physical_table=True)
    def split(self, test_size=0.5):
"""
This function splits the DataFrame into two DataFrames.
>>> df_train, df_test = df.split(0.5)
>>> print(df_train.count())
>>> print(df_test.count())
"""
new_df = self.materializedView(is_temp=True)
new_table_name = AisTools.data_split(new_df.getTableName(), test_size)
return readClickHouse(new_table_name[0]), readClickHouse(new_table_name[1])
    def orderBy(self, *args):
"""
Orders the DataFrame by the specified columns and returns a new DataFrame.
>>> import fast_causal_inference.dataframe.functions as Fn
>>> new_df = df.orderBy('column1', Fn.desc('column2'))
>>> new_df = df.orderBy(['column1', 'column2'])
"""
args = DataFrame._expand_args(*args)
# clear self.data.order_by
new_df = copy.deepcopy(self)
del new_df.data.order_by[:]
for arg in args:
if isinstance(arg, str):
order = DfPb.Order()
order.column.name = arg
new_df.data.order_by.append(order)
elif isinstance(arg, DfPb.Order):
new_df.data.order_by.append(arg)
return new_df
    def groupBy(self, *args):
"""
        Groups the DataFrame by the specified columns and returns a new DataFrame.
        Raises an exception if the DataFrame is already in the `need_agg` state.
>>> new_df = df.groupBy('column1', 'column2')
>>> new_df = df.groupBy(['column1', 'column2'])
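
        A fuller sketch combining groupBy with agg (column names are illustrative):

        >>> df.groupBy('treatment').agg({'numerator': 'avg', 'denominator': 'sum'}).show()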
"""
args = DataFrame._expand_args(*args)
if self._need_agg:
raise Exception(
"Dataframe is already in `need_agg` state, cannot apply group by."
)
new_df = self._transform_to_nested()
new_df._need_agg = True
for arg in args:
if arg not in self._name_dict:
raise Exception(f"Unable to find column named `{arg}`.")
column = DfPb.Column()
column.name = arg
new_df.data.group_by.append(column)
return new_df
def agg(self, *args):
        # Supports the following two syntaxes:
# df.agg(avg("A"), sum("B").alias).show()
# df.agg({"A": "avg", "B": "sum"}).show()
new_df = copy.deepcopy(self)
if len(args) == 0:
raise Exception("Nothing to agg.")
if len(args) == 1 and isinstance(args[0], dict):
args = list(map(lambda key: getattr(AisF, args[0][key])(key), args[0]))
if not all(map(lambda arg: isinstance(arg, DfFnColWrapper), args)):
raise Exception(f"{args} is not List[DfFnColWrapper].")
new_df = new_df._apply_func(*args)
return new_df
# aggregate functions
def sum(self, column):
return self._apply_func(AisF.sum(column))
def avg(self, column):
return self._apply_func(AisF.avg(column))
def count(self, *, expr="*"):
return self._apply_func(AisF.count(expr=expr))
def max(self, column):
return self._apply_func(AisF.max(column))
def min(self, column):
return self._apply_func(AisF.min(column))
def any(self, column):
return self._apply_func(AisF.any(column))
def stddevPop(self, column):
return self._apply_func(AisF.stddevPop(column))
def stddevSamp(self, column):
return self._apply_func(AisF.stddevSamp(column))
def varPop(self, column):
return self._apply_func(AisF.varPop(column))
def varSamp(self, column):
return self._apply_func(AisF.varSamp(column))
def corr(self, x, y):
return self._apply_func(AisF.corr(x, y))
def covarPop(self, x, y):
return self._apply_func(AisF.covarPop(x, y))
def covarSamp(self, x, y):
return self._apply_func(AisF.covarSamp(x, y))
def anyLast(self, x, y):
return self._apply_func(AisF.anyLast(x, y))
def anyMin(self, x, y):
return self._apply_func(AisF.anyMin(x, y))
def anyMax(self, x, y):
return self._apply_func(AisF.anyMax(x, y))
def kolmogorov_smirnov_test(self, x, y):
return self._apply_func(AisStatisticsF.kolmogorov_smirnov_test(x, y))
def student_ttest(self, x, y):
return self._apply_func(AisStatisticsF.student_ttest(x, y))
def welch_ttest(self, x, y):
return self._apply_func(AisStatisticsF.welch_ttest(x, y))
def mean_z_test(
self,
sample_data,
sample_index,
population_variance_x,
population_variance_y,
confidence_level,
):
return self._apply_func(
AisStatisticsF.mean_z_test(
sample_data,
sample_index,
population_variance_x,
population_variance_y,
confidence_level,
)
)
def quantile(self, x, level=None):
if level is None:
raise Exception("param `level` is not set")
return self._apply_func(AisF.quantile(x, level=level))
# all in sql functions
def delta_method(self, column, std="True"):
return self._apply_func(AisStatisticsF.delta_method(expr=column, std=std))
def ttest_1samp(self, Y, alternative="two-sided", mu=0, X=""):
return self._apply_func(AisStatisticsF.ttest_1samp(Y, alternative, mu, X))
    def ttest_2samp(self, Y, index, alternative="two-sided", X="", pse=""):
        return self._apply_func(
            AisStatisticsF.ttest_2samp(Y, index, alternative=alternative, X=X, pse=pse)
)
def xexpt_ttest_2samp(
self,
numerator,
denominator,
index,
uin,
metric_type="avg",
group_buckets="[1:1]",
alpha=0.05,
MDE=0.005,
power=0.8,
X="",
):
return self._apply_func(
AisStatisticsF.xexpt_ttest_2samp(
numerator,
denominator,
index,
uin,
metric_type,
group_buckets,
alpha,
MDE,
power,
X,
)
)
def mann_whitney_utest(
self,
sample_data,
sample_index,
alternative="two-sided",
continuity_correction=1,
):
return self._apply_func(
AisStatisticsF.mann_whitney_utest(
sample_data, sample_index, alternative, continuity_correction
)
)
def srm(self, x, groupby, ratio="[1,1]"):
return self._apply_func(AisStatisticsF.srm(x, groupby, ratio))
def ols(self, column, use_bias=True):
return self._apply_func(AisRegressionF.ols(column, use_bias))
def wls(self, column, weight, use_bias=True):
return self._apply_func(AisRegressionF.wls(column, weight, use_bias))
def stochastic_linear_regression(
self, expr, learning_rate=0.00001, l1=0.1, batch_size=15, method="SGD"
):
return self._apply_func(
AisRegressionF.stochastic_linear_regression(
expr, learning_rate, l1, batch_size, method
)
)
def stochastic_logistic_regression(
self, expr, learning_rate=0.00001, l1=0.1, batch_size=15, method="SGD"
):
return self._apply_func(
AisRegressionF.stochastic_logistic_regression(
expr, learning_rate, l1, batch_size, method
)
)
def matrix_multiplication(self, *args, std=False, invert=False):
return self._apply_func(
AisStatisticsF.matrix_multiplication(*args, std=std, invert=invert)
)
def did(self, Y, treatment, time, *X):
return self._apply_func(AisRegressionF.did(Y, treatment, time, *X))
def iv_regression(self, formula):
return self._apply_func(AisRegressionF.iv_regression(formula))
def boot_strap(self, func, sample_num, bs_num):
return self._apply_func(AisStatisticsF.boot_strap(func, sample_num, bs_num))
def permutation(self, func, permutation_num, *col):
return self._apply_func(AisStatisticsF.permutation(func, permutation_num, *col))
def checkColumn(self, *args, throw_exception=True):
for arg in args:
if arg not in self._name_dict:
if throw_exception == True:
raise Exception("column %s not found" % arg)
else:
return False
return True
    # Prefer the alias when present
def getAliasOrName(self):
alias_or_name = []
for column in self._select_list:
alias_or_name.append(self._get_col_name(column))
return alias_or_name
def _find_column(self, name):
if self._name_dict.get(name) is not None:
# print(f"find column `{name}`: {self._name_dict.get(name).sql(self._ctx)}")
return self._name_dict.get(name)
for column in self._select_list:
col = column.find_column(name)
if col is not None:
# print(f"find column `{name}`: {col.sql(self._ctx)}")
return col
# print(f"cannot find column `{name}`, using self.")
return DfColumnLeafNode(name)
def _unwrap(self, fn_wrapper: DfFnColWrapper):
if not isinstance(fn_wrapper, DfFnColWrapper):
raise Exception(f"func({type(DfFnColWrapper)}) should be DfFnColWrapper!")
fn, params, cols = fn_wrapper.fn, fn_wrapper.params, fn_wrapper.columns
alias = fn.alias
columns = []
for col in cols:
if isinstance(col, int) or isinstance(col, float) or isinstance(col, bool):
from_col = DfColumnLeafNode(str(col))
elif isinstance(col, str):
from_col = copy.deepcopy(self._find_column(col))
self._name_dict[col] = from_col
elif isinstance(col, DfColumnNode):
from_col = copy.deepcopy(col)
if col.alias is not None:
self._name_dict[col.alias] = from_col
elif isinstance(col, DfFnColWrapper):
from_col = self._unwrap(col)
if from_col.alias is not None:
self._name_dict[from_col.alias] = from_col
else:
raise Exception(
f"type of col({type(col)}) is neither str nor DfColumn."
)
if not isinstance(from_col, DfColumnNode):
# we can only apply function on DfColumn but no others
raise Exception(f"Logical Error: {type(from_col)} is not DfColumn.")
self._name_dict[from_col.sql(self._ctx)] = from_col
columns.append(from_col)
new_col = DfColumnInternalNode(fn, params, columns, alias)
if alias:
self._name_dict[alias] = new_col
self._name_dict[new_col.sql(self._ctx)] = new_col
return new_col
def _apply_func(self, *fn_wrappers: Tuple[DfFnColWrapper], keep_old_cols=False):
if any(map(lambda fn_wrapper: fn_wrapper.has_agg_func(), fn_wrappers)):
if not self._need_agg:
new_df = self._transform_to_nested()
else:
new_df = copy.deepcopy(self)
new_df._need_agg = False
else:
new_df = copy.deepcopy(self)
if not keep_old_cols:
new_df._select_list = []
for fn_wrapper in fn_wrappers:
if not isinstance(fn_wrapper, DfFnColWrapper):
raise Exception("fn_wrapper should be a DfFnColWrapper object!")
new_col: DfColumnNode = new_df._unwrap(fn_wrapper)
new_df._select_list.append(new_col)
return new_df
def _set_cte(self, cte):
self.data.cte = cte
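    # Wrap the current query as a nested subquery (the database is set to the
    # sentinel "Nested") so that further operations, e.g. aggregation after a
    # groupBy, are applied on top of the already-built SQL.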
def _transform_to_nested(self):
new_df = copy.deepcopy(self)
subquery = new_df.getExecutedSql()
del new_df.data.columns[:]
del new_df.data.filters[:]
del new_df.data.group_by[:]
del new_df.data.order_by[:]
new_df.data.limit.limit = 0
del new_df._select_list[:]
new_df.data.cte = ""
for col_name in new_df._name_dict:
new_df._name_dict[col_name] = DfColumnLeafNode(col_name)
if new_df.data.source.type == DfPb.SourceType.ClickHouse:
new_df.data.source.clickhouse.table_name = subquery
new_df.data.source.clickhouse.database = "Nested"
elif new_df.data.source.type == DfPb.SourceType.StarRocks:
new_df.data.source.starrocks.table_name = subquery
new_df.data.source.starrocks.database = "Nested"
else:
raise Exception("not support source type")
return new_df
    def union(self, df):
"""
This function is used to union two DataFrames.
The two DataFrames must have the same number of columns, and the columns must have the same names and order.
>>> df1 = df1.union(df2)
"""
        if len(self._select_list) != len(df._select_list):
            raise Exception("the lengths of the select lists do not match")
        for i in range(len(self._select_list)):
            if self._select_list[i].alias != df._select_list[i].alias:
                raise Exception("the column names of the select lists do not match")
            if self._select_list[i].alias is None and self._select_list[i].sql(
                self._ctx
            ) != df._select_list[i].sql(df._ctx):
                raise Exception("the column names of the select lists do not match")
new_df = copy.deepcopy(self)
subquery1 = self.getExecutedSql()
subquery2 = df.getExecutedSql()
if new_df.data.source.type == DfPb.SourceType.ClickHouse:
new_df.data.source.clickhouse.table_name = f"({subquery1}) union all ({subquery2})"
new_df.data.source.clickhouse.database = "Nested"
elif new_df.data.source.type == DfPb.SourceType.StarRocks:
new_df.data.source.starrocks.table_name = f"({subquery1}) union all ({subquery2})"
new_df.data.source.starrocks.database = "Nested"
new_df._select_list = []
for col in self._select_list:
new_df._select_list.append(new_df._expand_expr(self._get_col_name(col)))
return new_df
# parse executed sql, remove limit
def getExecutedSql(self):
self.__str__()
sql = self.data.execute_sql
# print(sql)
sql = re.sub(r"limit\s+\d+\s*$", "", sql, flags=re.IGNORECASE)
return sql
    # is_temp=True: the temporary table is automatically dropped at 2 a.m. every day
def materializedView(
self,
is_physical_table=False,
is_distributed_create=True,
is_temp=False,
table_name=None,
):
materialized_sql = self.getExecutedSql()
if table_name is not None:
new_df_name = table_name
else:
new_df_name = DataFrame.createTableName(is_temp)
ClickHouseUtils.clickhouse_create_view_v2(
table_name=new_df_name,
select_statement=materialized_sql,
origin_table_name=self.getTableName(),
is_physical_table=is_physical_table,
is_distributed_create=is_distributed_create,
)
return readClickHouse(new_df_name)
def createDfRequest(self):
df_req = DfPb.DataFrameRequest()
df_req.df.CopyFrom(self.data)
df_req.task_type = self.task_type
df_req.rtx = self.rtx
df_req.device_id = self.device_id
df_req.database = self.database
return df_req
def createDataFrameRequestBase64(self):
df_req = self.createDfRequest()
return base64.b64encode(df_req.SerializeToString()).decode()
def createDfRequestJson(self):
df_req = self.createDfRequest()
return json_format.MessageToJson(df_req)
def createCurlRequest(self):
return (
'curl -H "Content-Type: application/json" -X POST -d \''
+ self.createDfRequestJson()
+ "' "
+ self.url
+ "/json"
)
def execute(self, retry_times=1):
logger = get_context().logger
while retry_times > 0:
try:
json_body = self.createDfRequestJson()
logger.debug("url= " + self.url + ",data= " + json_body)
resp = requests.post(
self.url,
data=json_body.encode("utf-8"),
headers={
"Content-Type": "application/json",
"Accept": "application/json",
},
)
logger.debug("response=" + resp.text)
df_resp = Parse(resp.text, DfPb.DataFrameResponse())
if df_resp.status == DfPb.RetStatus.FAIL:
print("Error: ")
print(df_resp.msg)
elif df_resp.status == DfPb.RetStatus.SUCC:
self.data = df_resp.df
if not self._select_list:
self._select_list = [
DfColumnLeafNode(
column_name=col.name, alias=col.alias, type=col.type
)
for col in df_resp.df.columns
]
for col in self._select_list:
self._name_dict[col.sql(self._ctx)] = col
return
except Exception as e:
time.sleep(1)
retry_times -= 1
if retry_times == 0:
raise e
@staticmethod
def createTableName(is_temp=False):
table_name = "df_table_"
if is_temp:
table_name += "temp_"
table_name += time.strftime("%Y%m%d%H%M%S", time.localtime()) + "_"
table_name += str(random.randint(100000, 999999))
return table_name
def fill_column_info(self):
self.task_type = DfPb.TaskType.FILL_SCHEMA
self.execute()
    # If there are nested subqueries, this returns the innermost table name
def getTableName(self):
return self.data.source.clickhouse.table_name
    def toCsv(self, csv_file_abs_path):
"""
Convert the data from ClickHouse table to a CSV file.
>>> df.toCsv("/path/to/output.csv")
"""
clickhouse_table_name = self.data.source.clickhouse.table_name
clickhouse_2_csv(clickhouse_table_name, csv_file_abs_path)
    def toTdw(
self,
tdw_database,
tdw_table,
tdw_user=None,
tdw_passward=None,
group="tl",
is_drop_table=False,
overwrite=True,
priPart=None,
):
"""
ClickHouse table >> TDW-thive table.
Parameters
----------
:tdw_database (str): The name of the TDW database.
:tdw_table (str): The name of the TDW table.
:tdw_user (str, optional): The username for TDW. Default is None.
:tdw_passward (str, optional): The password for TDW. Default is None.
:group (str, optional): The group for TDW. Default is 'tl'.
:is_drop_table (bool, optional): Whether to drop the existing TDW table. Default is False.
:overwrite (bool, optional): Whether to overwrite the existing data in the TDW table. Default is True.
:priPart (str, optional): The primary partition for the TDW table. Default is None.
Example
----------
>>> df.toTdw("tdw_database", "tdw_table", group='tl')
>>> df.toTdw("tdw_database", "tdw_table", group='tl',priPart=['p_20220222'])
"""
view = self.materializedView()
from fast_causal_inference.util.clickhouse_utils import ClickHouseUtils
ClickHouseUtils.clickhouse_2_tdw_v2(
view.getTableName(),
tdw_database,
tdw_table,
tdw_user,
tdw_passward,
group,
is_drop_table,
overwrite,
priPart,
)
    def toSparkDf(self):
"""
ClickHouse table >> spark dataframe.
Example
----------
>>> import fast_causal_inference
>>> fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
        >>> spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
>>> df = fast_causal_inference.readClickHouse('test_data_small')
>>> spark_df = df.toSparkDf()
"""
session = getSparkSession()
if self._ctx.engine() == OlapEngineType.CLICKHOUSE:
return clickhouse_2_dataframe(
session, self.data.source.clickhouse.table_name
)
elif self._ctx.engine() == OlapEngineType.STARROCKS:
return starrocks_2_dataframe(session, self.data.source.starrocks.table_name)
else:
raise Exception(f"Olap engine `{self._ctx.engine()}` not supported.")
    def toClickHouse(self, clickhouse_table_name):
"""
ClickHouse table >> ClickHouse table.
Example
----------
>>> df.toClickHouse("new_table")
"""
self.materializedView(is_physical_table=True, table_name=clickhouse_table_name)
def getSparkSession():
from fast_causal_inference import get_context
    if get_context().spark_session is None:
        print("Spark Session is None, please initialize it first")
return
return get_context().spark_session
# Data import / export helpers
def readClickHouse(table_name):
"""
Read data from a ClickHouse table into a DataFrame.
>>> import fast_causal_inference
>>> df = fast_causal_inference.readClickHouse("test_data_small")
"""
df = DataFrame()
df.data.source.clickhouse.table_name = table_name
df.data.source.clickhouse.database = df.database
df.fill_column_info()
return df
def readStarRocks(table_name):
"""
Read data from a StarRocks table into a DataFrame.
>>> import fast_causal_inference
>>> df = fast_causal_inference.readStarRocks("test_data_small")
"""
df = DataFrame(olap_engine=OlapEngineType.STARROCKS)
df.data.source.starrocks.table_name = table_name
df.data.source.starrocks.database = df.database
df.fill_column_info()
return df
def readTdw(
db,
table,
group="tl",
tdw_user=None,
tdw_passward=None,
priParts=None,
str_replace="-1",
numeric_replace=0,
olap="clickhouse",
):
"""
Read data from a TDW-thive table into a DataFrame.
Parameters
----------
:param db: The name of the TDW database.
:param table: The name of the TDW table.
:param group: The group for TDW. Default is 'tl'.
    :param tdw_user: The username for TDW. Default is None.
    :param tdw_passward: The password for TDW. Default is None.
:param str_replace: The replacement for string NA values. Default is "-1".
:param numeric_replace: The replacement for numeric NA values. Default is 0.
Example
-------
::
        # thive - Method 1: TDW thive > ClickHouse
        import fast_causal_inference
        fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
        spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        df = fast_causal_inference.readTdw("db_name", "table_name", group="tl")  # read a regular table
        df = fast_causal_inference.readTdw("db_name", "table_name", group="tl", priParts=['p_20220222'])  # read a partitioned table
        # thive - Method 2: TDW thive > Spark > ClickHouse
        import fast_causal_inference
        fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
        spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        from pytoolkit import TDWSQLProvider
        tdw = TDWSQLProvider(spark, group="tl", db="db_name")
        spark_df = tdw.table(tblName="allinsql_test_data_small")  # read a regular table
        spark_df = tdw.table(tblName="allinsql_test_data_small", priParts=['p_20220222'])  # read a partitioned table
spark_df.count()
df_ch = fast_causal_inference.readSparkDf(spark_df)
"""
session = get_context().spark_session
from pytoolkit import TDWSQLProvider
tdw = TDWSQLProvider(
session, group=group, db=db, user=tdw_user, passwd=tdw_passward
)
df_new = tdw.table(tblName=table, priParts=priParts)
df_new = AisTools.preprocess_na(df_new, str_replace, numeric_replace)
df = DataFrame()
table_name = DataFrame.createTableName()
if olap.lower() == "clickhouse":
dataframe_2_clickhouse(dataframe=df_new, clickhouse_table_name=table_name)
return readClickHouse(table_name)
elif olap.lower() == "starrocks":
dataframe_2_starrocks(dataframe=df_new, starrocks_table_name=table_name)
return readStarRocks(table_name)
else:
raise Exception(f"Olap engine `{olap}` not supported.")
def readSparkDf(dataframe, olap="clickhouse"):
"""
Read data from a Spark DataFrame into a DataFrame.
>>> import fast_causal_inference
>>> fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
    >>> spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
>>> df = fast_causal_inference.readSparkDf(spark_df)
"""
df = DataFrame()
table_name = DataFrame.createTableName()
if olap.lower() == "clickhouse":
dataframe_2_clickhouse(dataframe=dataframe, clickhouse_table_name=table_name)
return readClickHouse(table_name)
elif olap.lower() == "starrocks":
dataframe_2_starrocks(dataframe=dataframe, starrocks_table_name=table_name)
return readStarRocks(table_name)
else:
raise Exception(f"Olap engine `{olap}` not supported.")
def readCsv(csv_file_abs_path, olap="clickhouse"):
"""
Read data from a CSV file into a DataFrame.
>>> import fast_causal_inference
>>> df = fast_causal_inference.readCsv("/path/to/file.csv")
"""
df = DataFrame()
table_name = DataFrame.createTableName()
if olap.lower() == "clickhouse":
csv_2_clickhouse(
csv_file_abs_path=csv_file_abs_path, clickhouse_table_name=table_name
)
return readClickHouse(table_name)
elif olap.lower() == "starrocks":
csv_2_starrocks(
csv_file_abs_path=csv_file_abs_path, starrocks_table_name=table_name
)
return readStarRocks(table_name)
else:
raise Exception(f"Olap engine `{olap}` not supported.")