__all__ = [
    "readClickHouse",
    "readStarRocks",
    "readTdw",
    "readSparkDf",
    "readCsv",
    "DataFrame",
]
import random
from typing import Tuple
import pandas
import requests
import time
from fast_causal_inference.dataframe.format import outPutFormat
from fast_causal_inference.dataframe import ais_dataframe_pb2 as DfPb
import re
import base64
from google.protobuf import json_format
from google.protobuf.json_format import Parse
import copy
from fast_causal_inference.dataframe.df_base import (
    DfColumnNode,
    OlapEngineType,
    DfContext,
    DfColumnInternalNode,
    DfColumnLeafNode,
)
from fast_causal_inference.dataframe.functions import DfFnColWrapper
from fast_causal_inference.dataframe import functions as AisF
from fast_causal_inference.util.data_transformer import (
    clickhouse_2_csv,
    clickhouse_2_dataframe,
    dataframe_2_clickhouse,
    dataframe_2_starrocks,
    csv_2_clickhouse,
    csv_2_starrocks,
    starrocks_2_dataframe,
)
from fast_causal_inference.dataframe import regression as AisRegressionF
from fast_causal_inference.dataframe import statistics as AisStatisticsF
from fast_causal_inference.util import get_user, ClickHouseUtils
from fast_causal_inference.common import get_context
import fast_causal_inference.lib.tools as AisTools
"""
syntax = "proto3";
enum ColumnType {
  Unknown = 0;
  String = 1;
  Int = 2;
  Float = 3;
  Bool = 4;
  Date = 5;
  DateTime = 6;
  Time = 7;
  UUID = 8;
  Array = 9;
}
enum TaskType {
  TASK_TYPE_FILL_SCHEMA = 0;
}
message Column {
  string name = 1;
  string alias = 2;
  ColumnType type = 3;
}
message Limit {
  int64 limit = 1;
  int64 offset = 2;
}
enum SourceType {
  ClickHouse = 0;
}
message ClickHouseSource {
  string table_name = 1;
  string database = 2;
}
message Source {
  SourceType type = 1;
  ClickHouseSource clickhouse = 2;
}
message Order {
  Column column = 1;
  bool desc = 2;
}
message DataFrame {
  repeated Column columns = 1;
  repeated string filters = 2;
  repeated Column group_by = 3;
  repeated Order order_by = 4;
  Limit limit = 5;
  Source source = 6;
};
message DataFrameRequest {
  DataFrame df = 1;
  TaskType task_type = 2;
  string rtx = 3;
  int64 device_id = 4;
}
"""
def getSuperName(column):
    if isinstance(column, DfPb.Column):
        return column.alias if column.alias else column.name
    else:
        raise Exception("type error")
class DataFrame:
    """
    This class is used to create a DataFrame object.
    """
    def __init__(self, olap_engine=OlapEngineType.CLICKHOUSE):
        datasource = dict()
        PROJECT_CONF = get_context().project_conf
        for cell in PROJECT_CONF["datasource"]:
            datasource[cell["device_id"]] = cell
        self.device_id = None
        for device in datasource:
            if datasource[device].get(str(olap_engine) + "_database") is not None:
                self.device_id = device
                self.database = datasource[device][str(olap_engine) + "_database"]
                break
        if self.device_id is None:
            raise Exception(f"Unable to get any device of engine({olap_engine}).")
        self.data = DfPb.DataFrame()
        if olap_engine == OlapEngineType.STARROCKS:
            self.data.source.type = "StarRocks"
        self.url = (
            PROJECT_CONF["sqlgateway"]["url"]
            + PROJECT_CONF["sqlgateway"]["dataframe_path"]
        )
        self.rtx = get_user()
        self._ctx = DfContext(engine=olap_engine)
        self._select_list = []
        self._name_dict = {}
        self._need_agg = False
    def serialize(self):
        return self.data.SerializeToString()
    def serializeBase64(self):
        return base64.b64encode(self.serialize()).decode()
    def serializeJson(self):
        return json_format.MessageToJson(self.data)
    def deserialize(self, data):
        self.data.ParseFromString(data)
    def toPandas(self):
        """
        This function is used to convert the result of the dataframe to pandas.DataFrame
        """
        self._finalize()
        self.data.result = ""
        self.task_type = DfPb.TaskType.EXECUTE
        self.execute()
        res = outPutFormat(list(eval(self.data.result)))
        return res 
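    # Usage sketch (column names are illustrative): pull the query result into a
    # local pandas.DataFrame.
    #
    #     df = readClickHouse("test_data_small")
    #     pdf = df.select("x1", "x2").toPandas()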
    def _is_column(self, column, name):
        if name is None:
            return False
        if isinstance(column, DfColumnNode):
            return column.alias == name or column.sql(ctx=self._ctx) == name
        raise Exception(f"type({type(column)}) is not DfColumn")
    def _get_col_name(self, column):
        if isinstance(column, DfColumnLeafNode):
            return column.alias if column.alias else column.sql(ctx=self._ctx)
        if isinstance(column, DfColumnInternalNode):
            return column.alias
        raise Exception(
            f"type({type(column)}) is not DfColumnLeafNode|DfColumnInternalNode"
        )
    def printSchema(self):
        """
        This function is used to print the schema of the dataframe
        """
        pd = pandas.DataFrame(columns=["name", "type"])
        for column in self._select_list:
            super_name = self._get_col_name(column)
            column_type = column.type
            # if the column type is empty, default to Float64
            if column_type is None or column_type == "":
                column_type = "Float64"
            # TODO: if the type is empty, it should be filled in beforehand
            pd = pandas.concat(
                [pd, pandas.DataFrame([{"name": super_name, "type": column_type}])],
                ignore_index=True,
            )
        print(pd)
    def __getitem__(self, key):
        for column in self._select_list:
            if column.alias == key or column.sql(self._ctx) == key:
                return AisF.col(column)
        raise Exception("column %s not found" % key)
    def debug(self):
        self._finalize()
        print(self.data.__str__())
        all_names = dict()
        for name in self._name_dict:
            all_names[name] = self._name_dict[name].sql(self._ctx)
        print(all_names)
    def _finalize(self):
        del self.data.columns[:]
        for col in self.data.group_by:
            self.data.columns.append(col)
        for col in self._select_list:
            self.data.columns.append(
                DfPb.Column(name=col.sql(ctx=self._ctx), alias=col.alias)
            )
    def __str__(self):
        self._finalize()
        res = self.toPandas()
        if res.shape[0] == 1 and res.shape[1] == 1:
            return res.values[0][0]
        return res.__str__()
    def __repr__(self):
        return self.__str__()
    def first(self):
        """
        This function is used to get the first row of the dataframe
        >>> df.first()
        """
        new_df = copy.deepcopy(self)
        new_df.data.limit.limit = 1
        return new_df 
    def head(self, n):
        """
        This function is used to get the first n rows of the dataframe
        >>> df.head(3)
        """
        new_df = copy.deepcopy(self)
        new_df.data.limit.limit = n
        return new_df 
    def take(self, n):
        """
        This function is used to get the first n rows of the dataframe
        >>> df.take(3)
        """
        new_df = copy.deepcopy(self)
        new_df.data.limit.limit = n
        return new_df 
    def show(self):
        """
        Prints the DataFrame, equivalent to print(dataframe).
        >>> df.head(3).show()
        """
        print(self.__str__()) 
    # TODO: add validation logic
    # Alias names are used here and must be mapped back to the underlying columns; the same applies to groupBy and orderBy below
    def where(self, filter):
        """
        Filters rows using the given condition.
        >>> df.where("column1 > 1").show()
        """
        new_df = copy.deepcopy(self)
        if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", filter) is not None:
            check_exist = False
            for column in new_df._select_list:
                super_name = self._get_col_name(column)
                if super_name in filter:
                    check_exist = True
                    break
            if not check_exist:
                raise Exception("the column of filter is not exist" % filter)
        new_df.data.filters.append(filter)
        return new_df 
    def filter(self, filter):
        """
        Alias for the 'where' function. Filters rows using the given condition.
        >>> df.filter("column1 > 1").show()
        """
        return self.where(filter) 
    @staticmethod
    def _expand_args(*args):
        expand_args = []
        for arg in args:
            if isinstance(arg, str):
                expand_args.append(arg)
            elif isinstance(arg, list):
                for sub_arg in arg:
                    expand_args.append(sub_arg)
        return expand_args
    def select(self, *args):
        """
        Selects specified columns from the DataFrame and returns a new DataFrame.
        >>> new_df = df.select('column1', 'column2')
        >>> new_df = df.select(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        new_df = copy.deepcopy(self)
        new_df.checkColumn(*args)
        new_select_list = []
        for arg in args:
            new_select_list.append(new_df._expand_expr(arg))
        new_df._select_list = new_select_list
        return new_df 
    def drop(self, *args):
        """
        Drops specified columns from the DataFrame and returns a new DataFrame.
        >>> new_df = df.drop('column1', 'column2')
        >>> new_df = df.drop(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        new_df = copy.deepcopy(self)
        new_df.checkColumn(*args)
        for i in range(len(new_df._select_list) - 1, -1, -1):
            name = self._get_col_name(new_df._select_list[i])
            is_exist = False
            for arg in args:
                if name == arg:
                    is_exist = True
                    break
            if is_exist:
                del new_df._select_list[i]
        return new_df 
    def withColumn(self, new_column, func):
        """
        This function adds a new column to the DataFrame.
        Example
        --------
        .. code-block:: python
            import fast_causal_inference.dataframe.functions as Fn
            import fast_causal_inference.dataframe.statistics as S
            df1 = df.select('x1','x2','x3', 'numerator')
            # Method 1: Select columns through df's index: df['col']
            df1 = df1.withColumn('new_col', Fn.sqrt(df1['numerator']))
            df1.show()
            # Method 2: Select columns directly through string: Fn.col('new_col')
            df1 = df1.withColumn('new_col2', Fn.pow('new_col', 2))
            df1.show()
            df1 = df1.withColumn('new_col3', Fn.col('new_col') * Fn.col('new_col'))
            df1.show()
            # Add constant
            df2 = df1.withColumn('c1', Fn.lit(1))
            df2 = df2.withColumn('c2', Fn.lit('1'))
            df2.show()
            # Nesting
            df2 = df1.withColumn('c1', Fn.pow(Fn.sqrt(Fn.sqrt(Fn.col('x1'))), 2))
            df2.show()
            # +-*/% operations
            df2 = df1.withColumn('c1', 22 + df1['x1'] / 2 + 2 / df1['x2'] * df1['x3'] % 2 - (df1['x2']))
            df2.show()
            df2 = df1.withColumn('c1', Fn.col('x1 + x2 * x3 + x3'))
            df2.show()
            # if
            df2 = df1.withColumn('cc1', 'if(x1 > 0, 1, -1)')
            df2.show()
            df2 = df1.withColumn('cc1', Fn.If('x1 > 0',1,-1))
            df2.show()
        """
        new_df = copy.deepcopy(self)
        if isinstance(func, str):
            func = new_df._add_space(func)
            for alias in new_df._name_dict:
                if alias is None or alias == "":
                    continue
                func = func.replace(
                    " " + alias + " ",
                    " " + new_df._name_dict[alias].sql(new_df._ctx) + " ",
                )
            func = func.replace(" ", "")
            new_df._select_list.append(DfColumnLeafNode(func, new_column))
            new_df._name_dict[new_column] = new_df._select_list[-1]
            new_df._name_dict[
                new_df._select_list[-1].sql(new_df._ctx)
            ] = new_df._select_list[-1]
        elif isinstance(func, DfFnColWrapper):
            new_df = new_df._apply_func(func.alias(new_column), keep_old_cols=True)
        return new_df 
    @classmethod
    def _add_space(cls, text):
        for ch in [",", "+", "-", "*", "/", ")", "(", "=", ">", "<", "!"]:
            text = text.replace(ch, " " + ch + " ")
        return " " + text + " "
    def _expand_expr(self, expr):
        if not isinstance(expr, str):
            raise Exception(f"Logical Error: expr(`{expr}`) is expected to be str.")
        expr = self._add_space(expr)
        origin_expr = expr
        for alias in self._name_dict:
            if alias is None or alias == "":
                continue
            expr = expr.replace(
                " " + alias + " ", " " + self._name_dict[alias].sql(self._ctx) + " "
            )
        expr = expr.replace(" ", "")
        origin_expr = origin_expr.replace(" ", "")
        alias = origin_expr if origin_expr != expr else None
        self._name_dict[origin_expr] = DfColumnLeafNode(expr, alias=alias)
        self._name_dict[expr] = DfColumnLeafNode(expr, alias=alias)
        return DfColumnLeafNode(expr, alias=alias)
    def withColumnRenamed(self, old_name, new_name):
        """
        Returns a new DataFrame by renaming an existing column.
        >>> df.withColumnRenamed("column1", "new_column1").show()
        """
        new_df = copy.deepcopy(self)
        new_df.checkColumn(old_name)
        for column in new_df._select_list:
            if column.alias == old_name or self._get_col_name(column) == old_name:
                column.alias = new_name
        return new_df 
[docs]    def describe(self, cols="*"):
        """
        Returns the summary statistics for the columns in the DataFrame.
        >>> df.describe()
        >>> df.describe(['x1','x2'])
        #       count       avg       std       min  quantile_0.25  quantile_0.5  quantile_0.75  quantile_0.90  quantile_0.99      max
        # x1  10000.0 -0.018434  0.987606 -3.740101       -0.68929     -0.028665       0.654210       1.274144       2.321097  3.80166
        # x2  10000.0  0.021976  1.986209 -8.893264       -1.28461      0.015400       1.357618       2.583523       4.725829  7.19662
        """
        self = self.materializedView()
        table_name = self.getTableName()
        return AisTools.describe(table_name, cols) 
    def sample(self, fraction):
        """
        This function samples a fraction of rows without replacement from the DataFrame.
        >>> df1 = df.sample(1000)
        >>> df1.count().show()
        >>> df2 = df.sample(0.5)
        >>> df2.count().show()
        """
        new_df = self.materializedView(is_temp=True, is_physical_table=True)
        row_num = int(new_df.count().__str__())
        if fraction <= 1 and fraction >= 0:
            pass
        elif fraction > 1:
            if fraction > row_num:
                raise Exception("fraction should be less than row number")
            fraction = fraction / row_num
        else:
            raise Exception("fraction should be in [0, 1] or > 1")
        new_df_name = DataFrame.createTableName()
        temp_df = new_df
        temp_df = temp_df.where(f"rand() / pow(2, 32) < {fraction}")
        return temp_df.materializedView(is_physical_table=True) 
    def split(self, test_size=0.5):
        """
        This function splits the DataFrame into two DataFrames.
        >>> df_train, df_test = df.split(0.5)
        >>> print(df_train.count())
        >>> print(df_test.count())
        """
        new_df = self.materializedView(is_temp=True)
        new_table_name = AisTools.data_split(new_df.getTableName(), test_size)
        return readClickHouse(new_table_name[0]), readClickHouse(new_table_name[1]) 
    def orderBy(self, *args):
        """
        Orders the DataFrame by the specified columns and returns a new DataFrame.
        >>> import fast_causal_inference.dataframe.functions as Fn
        >>> new_df = df.orderBy('column1', Fn.desc('column2'))
        >>> new_df = df.orderBy(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        # clear self.data.order_by
        new_df = copy.deepcopy(self)
        del new_df.data.order_by[:]
        for arg in args:
            if isinstance(arg, str):
                order = DfPb.Order()
                order.column.name = arg
                new_df.data.order_by.append(order)
            elif isinstance(arg, DfPb.Order):
                new_df.data.order_by.append(arg)
        return new_df 
    def groupBy(self, *args):
        """
        Groups the DataFrame by the specified columns and returns a new DataFrame. Raises an exception if the DataFrame is already in the `need_agg` state.
        >>> new_df = df.groupBy('column1', 'column2')
        >>> new_df = df.groupBy(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        if self._need_agg:
            raise Exception(
                "Dataframe is already in `need_agg` state, cannot apply group by."
            )
        new_df = self._transform_to_nested()
        new_df._need_agg = True
        for arg in args:
            if arg not in self._name_dict:
                raise Exception(f"Unable to find column named `{arg}`.")
            column = DfPb.Column()
            column.name = arg
            new_df.data.group_by.append(column)
        return new_df 
    def agg(self, *args):
        # Supports both of the following syntaxes:
        # df.agg(avg("A"), sum("B").alias("sum_b")).show()
        # df.agg({"A": "avg", "B": "sum"}).show()
        new_df = copy.deepcopy(self)
        if len(args) == 0:
            raise Exception("Nothing to agg.")
        if len(args) == 1 and isinstance(args[0], dict):
            args = list(map(lambda key: getattr(AisF, args[0][key])(key), args[0]))
        if not all(map(lambda arg: isinstance(arg, DfFnColWrapper), args)):
            raise Exception(f"{args} is not List[DfFnColWrapper].")
        new_df = new_df._apply_func(*args)
        return new_df
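    # Usage sketch of the two syntaxes above (column names are illustrative):
    #
    #     import fast_causal_inference.dataframe.functions as Fn
    #     df.groupBy("treatment").agg(Fn.avg("numerator"), Fn.sum("denominator")).show()
    #     df.groupBy("treatment").agg({"numerator": "avg", "denominator": "sum"}).show()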
    # aggregate functions
    def sum(self, column):
        return self._apply_func(AisF.sum(column))
    def avg(self, column):
        return self._apply_func(AisF.avg(column))
    def count(self, *, expr="*"):
        return self._apply_func(AisF.count(expr=expr))
    def max(self, column):
        return self._apply_func(AisF.max(column))
    def min(self, column):
        return self._apply_func(AisF.min(column))
    def any(self, column):
        return self._apply_func(AisF.any(column))
    def stddevPop(self, column):
        return self._apply_func(AisF.stddevPop(column))
    def stddevSamp(self, column):
        return self._apply_func(AisF.stddevSamp(column))
    def varPop(self, column):
        return self._apply_func(AisF.varPop(column))
    def varSamp(self, column):
        return self._apply_func(AisF.varSamp(column))
    def corr(self, x, y):
        return self._apply_func(AisF.corr(x, y))
    def covarPop(self, x, y):
        return self._apply_func(AisF.covarPop(x, y))
    def covarSamp(self, x, y):
        return self._apply_func(AisF.covarSamp(x, y))
    def anyLast(self, x, y):
        return self._apply_func(AisF.anyLast(x, y))
    def anyMin(self, x, y):
        return self._apply_func(AisF.anyMin(x, y))
    def anyMax(self, x, y):
        return self._apply_func(AisF.anyMax(x, y))
    def kolmogorov_smirnov_test(self, x, y):
        return self._apply_func(AisStatisticsF.kolmogorov_smirnov_test(x, y))
    def student_ttest(self, x, y):
        return self._apply_func(AisStatisticsF.student_ttest(x, y))
    def welch_ttest(self, x, y):
        return self._apply_func(AisStatisticsF.welch_ttest(x, y))
    def mean_z_test(
        self,
        sample_data,
        sample_index,
        population_variance_x,
        population_variance_y,
        confidence_level,
    ):
        return self._apply_func(
            AisStatisticsF.mean_z_test(
                sample_data,
                sample_index,
                population_variance_x,
                population_variance_y,
                confidence_level,
            )
        )
    def quantile(self, x, level=None):
        if level is None:
            raise Exception("param `level` is not set")
        return self._apply_func(AisF.quantile(x, level=level))
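    # Usage sketch for the aggregate shortcuts above (column names are illustrative):
    #
    #     df.avg("x1").show()
    #     df.groupBy("treatment").sum("numerator").show()
    #     df.quantile("x1", level=0.5).show()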
    # all in sql functions
    def delta_method(self, column, std="True"):
        return self._apply_func(AisStatisticsF.delta_method(expr=column, std=std))
    def ttest_1samp(self, Y, alternative="two-sided", mu=0, X=""):
        return self._apply_func(AisStatisticsF.ttest_1samp(Y, alternative, mu, X))
    def ttest_2samp(self, Y, index, alternative="two-sided", X="", pse=""):
        return self._apply_func(
            AisStatisticsF.ttest_2samp(Y, index, alternative=alternative, X=X, pse=pse)
        )
    def xexpt_ttest_2samp(
        self,
        numerator,
        denominator,
        index,
        uin,
        metric_type="avg",
        group_buckets="[1:1]",
        alpha=0.05,
        MDE=0.005,
        power=0.8,
        X="",
    ):
        return self._apply_func(
            AisStatisticsF.xexpt_ttest_2samp(
                numerator,
                denominator,
                index,
                uin,
                metric_type,
                group_buckets,
                alpha,
                MDE,
                power,
                X,
            )
        )
    def mann_whitney_utest(
        self,
        sample_data,
        sample_index,
        alternative="two-sided",
        continuity_correction=1,
    ):
        return self._apply_func(
            AisStatisticsF.mann_whitney_utest(
                sample_data, sample_index, alternative, continuity_correction
            )
        )
    def srm(self, x, groupby, ratio="[1,1]"):
        return self._apply_func(AisStatisticsF.srm(x, groupby, ratio))
    def ols(self, column, use_bias=True):
        return self._apply_func(AisRegressionF.ols(column, use_bias))
    def wls(self, column, weight, use_bias=True):
        return self._apply_func(AisRegressionF.wls(column, weight, use_bias))
    def stochastic_linear_regression(
        self, expr, learning_rate=0.00001, l1=0.1, batch_size=15, method="SGD"
    ):
        return self._apply_func(
            AisRegressionF.stochastic_linear_regression(
                expr, learning_rate, l1, batch_size, method
            )
        )
    def stochastic_logistic_regression(
        self, expr, learning_rate=0.00001, l1=0.1, batch_size=15, method="SGD"
    ):
        return self._apply_func(
            AisRegressionF.stochastic_logistic_regression(
                expr, learning_rate, l1, batch_size, method
            )
        )
    def matrix_multiplication(self, *args, std=False, invert=False):
        return self._apply_func(
            AisStatisticsF.matrix_multiplication(*args, std=std, invert=invert)
        )
    def did(self, Y, treatment, time, *X):
        return self._apply_func(AisRegressionF.did(Y, treatment, time, *X))
    def iv_regression(self, formula):
        return self._apply_func(AisRegressionF.iv_regression(formula))
    def boot_strap(self, func, sample_num, bs_num):
        return self._apply_func(AisStatisticsF.boot_strap(func, sample_num, bs_num))
    def permutation(self, func, permutation_num, *col):
        return self._apply_func(AisStatisticsF.permutation(func, permutation_num, *col))
    def checkColumn(self, *args, throw_exception=True):
        for arg in args:
            if arg not in self._name_dict:
                if throw_exception:
                    raise Exception("column %s not found" % arg)
                else:
                    return False
        return True
    # Prefer the alias if one is set
    def getAliasOrName(self):
        alias_or_name = []
        for column in self._select_list:
            alias_or_name.append(self._get_col_name(column))
        return alias_or_name
    def _find_column(self, name):
        if self._name_dict.get(name) is not None:
            # print(f"find column `{name}`: {self._name_dict.get(name).sql(self._ctx)}")
            return self._name_dict.get(name)
        for column in self._select_list:
            col = column.find_column(name)
            if col is not None:
                # print(f"find column `{name}`: {col.sql(self._ctx)}")
                return col
        # print(f"cannot find column `{name}`, using self.")
        return DfColumnLeafNode(name)
    def _unwrap(self, fn_wrapper: DfFnColWrapper):
        if not isinstance(fn_wrapper, DfFnColWrapper):
            raise Exception(f"func({type(DfFnColWrapper)}) should be DfFnColWrapper!")
        fn, params, cols = fn_wrapper.fn, fn_wrapper.params, fn_wrapper.columns
        alias = fn.alias
        columns = []
        for col in cols:
            if isinstance(col, int) or isinstance(col, float) or isinstance(col, bool):
                from_col = DfColumnLeafNode(str(col))
            elif isinstance(col, str):
                from_col = copy.deepcopy(self._find_column(col))
                self._name_dict[col] = from_col
            elif isinstance(col, DfColumnNode):
                from_col = copy.deepcopy(col)
                if col.alias is not None:
                    self._name_dict[col.alias] = from_col
            elif isinstance(col, DfFnColWrapper):
                from_col = self._unwrap(col)
                if from_col.alias is not None:
                    self._name_dict[from_col.alias] = from_col
            else:
                raise Exception(
                    f"type of col({type(col)}) is neither str nor DfColumn."
                )
            if not isinstance(from_col, DfColumnNode):
                # we can only apply function on DfColumn but no others
                raise Exception(f"Logical Error: {type(from_col)} is not DfColumn.")
            self._name_dict[from_col.sql(self._ctx)] = from_col
            columns.append(from_col)
        new_col = DfColumnInternalNode(fn, params, columns, alias)
        if alias:
            self._name_dict[alias] = new_col
        self._name_dict[new_col.sql(self._ctx)] = new_col
        return new_col
    def _apply_func(self, *fn_wrappers: Tuple[DfFnColWrapper], keep_old_cols=False):
        if any(map(lambda fn_wrapper: fn_wrapper.has_agg_func(), fn_wrappers)):
            if not self._need_agg:
                new_df = self._transform_to_nested()
            else:
                new_df = copy.deepcopy(self)
            new_df._need_agg = False
        else:
            new_df = copy.deepcopy(self)
        if not keep_old_cols:
            new_df._select_list = []
        for fn_wrapper in fn_wrappers:
            if not isinstance(fn_wrapper, DfFnColWrapper):
                raise Exception("fn_wrapper should be a DfFnColWrapper object!")
            new_col: DfColumnNode = new_df._unwrap(fn_wrapper)
            new_df._select_list.append(new_col)
        return new_df
    def _set_cte(self, cte):
        self.data.cte = cte
    def _transform_to_nested(self):
        new_df = copy.deepcopy(self)
        subquery = new_df.getExecutedSql()
        del new_df.data.columns[:]
        del new_df.data.filters[:]
        del new_df.data.group_by[:]
        del new_df.data.order_by[:]
        new_df.data.limit.limit = 0
        del new_df._select_list[:]
        new_df.data.cte = ""
        for col_name in new_df._name_dict:
            new_df._name_dict[col_name] = DfColumnLeafNode(col_name)
        if new_df.data.source.type == DfPb.SourceType.ClickHouse:
            new_df.data.source.clickhouse.table_name = subquery
            new_df.data.source.clickhouse.database = "Nested"
        elif new_df.data.source.type == DfPb.SourceType.StarRocks:
            new_df.data.source.starrocks.table_name = subquery
            new_df.data.source.starrocks.database = "Nested"
        else:
            raise Exception("not support source type")
        return new_df
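    # Note: _transform_to_nested wraps the SQL generated so far as a subquery
    # (database is set to "Nested"), so later aggregations run on the output of
    # the previous query instead of the raw source table.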
    def union(self, df):
        """
        This function is used to union two DataFrames.
        The two DataFrames must have the same number of columns, and the columns must have the same names and order.
        >>> df1 = df1.union(df2)
        """
        if len(self._select_list) != len(df._select_list):
            raise Exception("the length of select list not match")
        for i in range(len(self._select_list)):
            if self._select_list[i].alias != df._select_list[i].alias:
                raise Exception("the column name of select list not match")
            if self._select_list[i].alias is None and self._select_list[i].sql(
                self._ctx
            ) != df._select_list[i].sql(df._ctx):
                raise Exception("the column name of select list not match")
        new_df = copy.deepcopy(self)
        subquery1 = self.getExecutedSql()
        subquery2 = df.getExecutedSql()
        if new_df.data.source.type == DfPb.SourceType.ClickHouse:
            new_df.data.source.clickhouse.table_name = f"({subquery1}) union all ({subquery2})"
            new_df.data.source.clickhouse.database = "Nested"
        elif new_df.data.source.type == DfPb.SourceType.StarRocks:
            new_df.data.source.starrocks.table_name = f"({subquery1}) union all ({subquery2})"
            new_df.data.source.starrocks.database = "Nested"
        new_df._select_list = []
        for col in self._select_list:
            new_df._select_list.append(new_df._expand_expr(self._get_col_name(col)))
        return new_df 
    # parse executed sql, remove limit
    def getExecutedSql(self):
        self.__str__()
        sql = self.data.execute_sql
        # print(sql)
        sql = re.sub(r"limit\s+\d+\s*$", "", sql, flags=re.IGNORECASE)
        return sql
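    # Usage sketch: getExecutedSql() returns the generated SQL with any trailing
    # LIMIT stripped, so it can be embedded as a subquery or materialized.
    #
    #     sql = readClickHouse("test_data_small").where("x1 > 0").getExecutedSql()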
    # is_temp = True: the table is automatically dropped at 2:00 AM every day
    def materializedView(
        self,
        is_physical_table=False,
        is_distributed_create=True,
        is_temp=False,
        table_name=None,
    ):
        materialized_sql = self.getExecutedSql()
        if table_name is not None:
            new_df_name = table_name
        else:
            new_df_name = DataFrame.createTableName(is_temp)
        ClickHouseUtils.clickhouse_create_view_v2(
            table_name=new_df_name,
            select_statement=materialized_sql,
            origin_table_name=self.getTableName(),
            is_physical_table=is_physical_table,
            is_distributed_create=is_distributed_create,
        )
        return readClickHouse(new_df_name)
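    # Usage sketch: persist the current query as a view or physical table and keep
    # working on it as a new DataFrame (the table name below is illustrative).
    #
    #     df_mv = df.where("x1 > 0").materializedView(is_physical_table=True)
    #     df_named = df.materializedView(table_name="my_snapshot_table")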
    def createDfRequest(self):
        df_req = DfPb.DataFrameRequest()
        df_req.df.CopyFrom(self.data)
        df_req.task_type = self.task_type
        df_req.rtx = self.rtx
        df_req.device_id = self.device_id
        df_req.database = self.database
        return df_req
    def createDataFrameRequestBase64(self):
        df_req = self.createDfRequest()
        return base64.b64encode(df_req.SerializeToString()).decode()
    def createDfRequestJson(self):
        df_req = self.createDfRequest()
        return json_format.MessageToJson(df_req)
    def createCurlRequest(self):
        return (
            'curl  -H "Content-Type: application/json" -X POST -d \''
            + self.createDfRequestJson()
            + "' "
            + self.url
            + "/json"
        )
    def execute(self, retry_times=1):
        logger = get_context().logger
        while retry_times > 0:
            try:
                json_body = self.createDfRequestJson()
                logger.debug("url= " + self.url + ",data= " + json_body)
                resp = requests.post(
                    self.url,
                    data=json_body.encode("utf-8"),
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                    },
                )
                logger.debug("response=" + resp.text)
                df_resp = Parse(resp.text, DfPb.DataFrameResponse())
                if df_resp.status == DfPb.RetStatus.FAIL:
                    print("Error: ")
                    print(df_resp.msg)
                elif df_resp.status == DfPb.RetStatus.SUCC:
                    self.data = df_resp.df
                    if not self._select_list:
                        self._select_list = [
                            DfColumnLeafNode(
                                column_name=col.name, alias=col.alias, type=col.type
                            )
                            for col in df_resp.df.columns
                        ]
                        for col in self._select_list:
                            self._name_dict[col.sql(self._ctx)] = col
                return
            except Exception as e:
                time.sleep(1)
                retry_times -= 1
                if retry_times == 0:
                    raise e
    @staticmethod
    def createTableName(is_temp=False):
        table_name = "df_table_"
        if is_temp:
            table_name += "temp_"
        table_name += time.strftime("%Y%m%d%H%M%S", time.localtime()) + "_"
        table_name += str(random.randint(100000, 999999))
        return table_name
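    # Sketch of the generated name format: "df_table_<timestamp>_<random>",
    # or "df_table_temp_<timestamp>_<random>" when is_temp=True.
    #
    #     tmp_name = DataFrame.createTableName(is_temp=True)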
    def fill_column_info(self):
        self.task_type = DfPb.TaskType.FILL_SCHEMA
        self.execute()
    # If there are nested subqueries, get the innermost table name
    def getTableName(self):
        return self.data.source.clickhouse.table_name
    def toCsv(self, csv_file_abs_path):
        """
        Convert the data from ClickHouse table to a CSV file.
        >>> df.toCsv("/path/to/output.csv")
        """
        clickhouse_table_name = self.data.source.clickhouse.table_name
        clickhouse_2_csv(clickhouse_table_name, csv_file_abs_path) 
    def toTdw(
        self,
        tdw_database,
        tdw_table,
        tdw_user=None,
        tdw_passward=None,
        group="tl",
        is_drop_table=False,
        overwrite=True,
        priPart=None,
    ):
        """
        ClickHouse table >> TDW-thive table.
        Parameters
        ----------
            :tdw_database (str): The name of the TDW database.
            :tdw_table (str): The name of the TDW table.
            :tdw_user (str, optional): The username for TDW. Default is None.
            :tdw_passward (str, optional): The password for TDW. Default is None.
            :group (str, optional): The group for TDW. Default is 'tl'.
            :is_drop_table (bool, optional): Whether to drop the existing TDW table. Default is False.
            :overwrite (bool, optional): Whether to overwrite the existing data in the TDW table. Default is True.
            :priPart (str, optional): The primary partition for the TDW table. Default is None.
        Example
        ----------
        >>> df.toTdw("tdw_database", "tdw_table", group='tl')
        >>> df.toTdw("tdw_database", "tdw_table", group='tl',priPart=['p_20220222'])
        """
        view = self.materializedView()
        from fast_causal_inference.util.clickhouse_utils import ClickHouseUtils
        ClickHouseUtils.clickhouse_2_tdw_v2(
            view.getTableName(),
            tdw_database,
            tdw_table,
            tdw_user,
            tdw_passward,
            group,
            is_drop_table,
            overwrite,
            priPart,
        ) 
    def toSparkDf(self):
        """
        ClickHouse table >> spark dataframe.
        Example
        ----------
        >>> import fast_causal_inference
        >>> fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
        >>> spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='') # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        >>> df = fast_causal_inference.readClickHouse('test_data_small')
        >>> spark_df = df.toSparkDf()
        """
        session = getSparkSession()
        if self._ctx.engine() == OlapEngineType.CLICKHOUSE:
            return clickhouse_2_dataframe(
                session, self.data.source.clickhouse.table_name
            )
        elif self._ctx.engine() == OlapEngineType.STARROCKS:
            return starrocks_2_dataframe(session, self.data.source.starrocks.table_name)
        else:
            raise Exception(f"Olap engine `{self._ctx.engine()}` not supported.") 
    def toClickHouse(self, clickhouse_table_name):
        """
        ClickHouse table >> ClickHouse table.
        Example
        ----------
        >>> df.toClickHouse("new_table")
        """
        self.materializedView(is_physical_table=True, table_name=clickhouse_table_name)
def getSparkSession():
    from fast_causal_inference import get_context
    if get_context().spark_session is None:
        print("Spark session is None, please initialize it first")
        return
    return get_context().spark_session
# Data import and export
def readClickHouse(table_name):
    """
    Read data from a ClickHouse table into a DataFrame.
    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readClickHouse("test_data_small")
    """
    df = DataFrame()
    df.data.source.clickhouse.table_name = table_name
    df.data.source.clickhouse.database = df.database
    df.fill_column_info()
    return df 
def readStarRocks(table_name):
    """
    Read data from a StarRocks table into a DataFrame.
    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readStarRocks("test_data_small")
    """
    df = DataFrame(olap_engine=OlapEngineType.STARROCKS)
    df.data.source.starrocks.table_name = table_name
    df.data.source.starrocks.database = df.database
    df.fill_column_info()
    return df 
def readTdw(
    db,
    table,
    group="tl",
    tdw_user=None,
    tdw_passward=None,
    priParts=None,
    str_replace="-1",
    numeric_replace=0,
    olap="clickhouse",
):
    """
    Read data from a TDW-thive table into a DataFrame.
    Parameters
    ----------
    :param db: The name of the TDW database.
    :param table: The name of the TDW table.
    :param group: The group for TDW. Default is 'tl'.
    :param tdw_user: The username for TDW. Default is None.
    :param tdw_passward: The password for TDW. Default is None.
    :param priParts: The partitions to read, e.g. ['p_20220222']. Default is None.
    :param str_replace: The replacement for string NA values. Default is "-1".
    :param numeric_replace: The replacement for numeric NA values. Default is 0.
    :param olap: The target OLAP engine, "clickhouse" or "starrocks". Default is "clickhouse".
    Example
    -------
    ::
        # thive - method 1: tdw thive > clickhouse
        import fast_causal_inference
        fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
        spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='') # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        df = fast_causal_inference.readTdw("db_name", "table_name",  group="tl") # read a non-partitioned table
        df = fast_causal_inference.readTdw("db_name", "table_name",  group="tl", priParts=['p_20220222']) # read a partitioned table
        # thive - method 2: tdw thive > spark > clickhouse
        import fast_causal_inference
        fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
        spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='') # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        from pytoolkit import TDWSQLProvider
        tdw = TDWSQLProvider(spark, group="tl",db="db_name")
        spark_df = tdw.table(tblName="allinsql_test_data_small") # read a non-partitioned table
        spark_df = tdw.table(tblName="allinsql_test_data_small", priParts=['p_20220222']) # read a partitioned table
        spark_df.count()
        df_ch = fast_causal_inference.readSparkDf(spark_df)
        
    """
    session = get_context().spark_session
    from pytoolkit import TDWSQLProvider
    tdw = TDWSQLProvider(
        session, group=group, db=db, user=tdw_user, passwd=tdw_passward
    )
    df_new = tdw.table(tblName=table, priParts=priParts)
    df_new = AisTools.preprocess_na(df_new, str_replace, numeric_replace)
    df = DataFrame()
    table_name = DataFrame.createTableName()
    if olap.lower() == "clickhouse":
        dataframe_2_clickhouse(dataframe=df_new, clickhouse_table_name=table_name)
        return readClickHouse(table_name)
    elif olap.lower() == "starrocks":
        dataframe_2_starrocks(dataframe=df_new, starrocks_table_name=table_name)
        return readStarRocks(table_name)
    else:
        raise Exception(f"Olap engine `{olap}` not supported.")
def readSparkDf(dataframe, olap="clickhouse"):
    """
    Read data from a Spark DataFrame into a DataFrame.
    >>> import fast_causal_inference
    >>> fast_causal_inference.set_default(tenant_id="",tenant_secret_key="")
    >>> spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='') # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
    >>> df = fast_causal_inference.readSparkDf(spark_df)
    """
    df = DataFrame()
    table_name = DataFrame.createTableName()
    if olap.lower() == "clickhouse":
        dataframe_2_clickhouse(dataframe=dataframe, clickhouse_table_name=table_name)
        return readClickHouse(table_name)
    elif olap.lower() == "starrocks":
        dataframe_2_starrocks(dataframe=dataframe, starrocks_table_name=table_name)
        return readStarRocks(table_name)
    else:
        raise Exception(f"Olap engine `{olap}` not supported.") 
def readCsv(csv_file_abs_path, olap="clickhouse"):
    """
    Read data from a CSV file into a DataFrame.
    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readCsv("/path/to/file.csv")
    """
    df = DataFrame()
    table_name = DataFrame.createTableName()
    if olap.lower() == "clickhouse":
        csv_2_clickhouse(
            csv_file_abs_path=csv_file_abs_path, clickhouse_table_name=table_name
        )
        return readClickHouse(table_name)
    elif olap.lower() == "starrocks":
        csv_2_starrocks(
            csv_file_abs_path=csv_file_abs_path, starrocks_table_name=table_name
        )
        return readStarRocks(table_name)
    else:
        raise Exception(f"Olap engine `{olap}` not supported.")