Source code for dataframe.dataframe

__all__ = [
    "readClickHouse",
    "readStarRocks",
    "readTdw",
    "readSparkDf",
    "readCsv",
    "DataFrame",
]

import random
from typing import Tuple

import pandas
import requests
import time

from fast_causal_inference.dataframe.format import outPutFormat
from fast_causal_inference.dataframe import ais_dataframe_pb2 as DfPb
import re
import base64
from google.protobuf import json_format
from google.protobuf.json_format import Parse
import copy
from fast_causal_inference.dataframe.df_base import (
    DfColumnNode,
    OlapEngineType,
    DfContext,
    DfColumnInternalNode,
    DfColumnLeafNode,
)
from fast_causal_inference.dataframe.functions import DfFnColWrapper
from fast_causal_inference.dataframe import functions as AisF
from fast_causal_inference.util.data_transformer import (
    clickhouse_2_csv,
    clickhouse_2_dataframe,
    dataframe_2_clickhouse,
    dataframe_2_starrocks,
    csv_2_clickhouse,
    csv_2_starrocks,
    starrocks_2_dataframe,
)
from fast_causal_inference.dataframe import regression as AisRegressionF
from fast_causal_inference.dataframe import statistics as AisStatisticsF
from fast_causal_inference.util import get_user, ClickHouseUtils
from fast_causal_inference.common import get_context
import fast_causal_inference.lib.tools as AisTools


"""
syntax = "proto3";

enum ColumnType {
  Unknown = 0;
  String = 1;
  Int = 2;
  Float = 3;
  Bool = 4;
  Date = 5;
  DateTime = 6;
  Time = 7;
  UUID = 8;
  Array = 9;
}

enum TaskType {
  TASK_TYPE_FILL_SCHEMA = 0;
}

message Column {
  string name = 1;
  string alias = 2;
  ColumnType type = 3;
}

message Limit {
  int64 limit = 1;
  int64 offset = 2;
}

enum SourceType {
  ClickHouse = 0;
}

message ClickHouseSource {
  string table_name = 1;
  string database = 2;
}

message Source {
  SourceType type = 1;
  ClickHouseSource clickhouse = 2;
}

message Order {
  Column column = 1;
  bool desc = 2;
}

message DataFrame {
  repeated Column columns = 1;
  repeated string filters = 2;
  repeated Column group_by = 3;
  repeated Order order_by = 4;
  Limit limit = 5;
  Source source = 6;
};

message DataFrameRequest {
  DataFrame df = 1;
  TaskType task_type = 2;
  string rtx = 3;
  int64 device_id = 4;
}

"""


def getSuperName(column):
    if isinstance(column, DfPb.Column):
        return column.alias if column.alias else column.name
    else:
        raise Exception("type error")


class DataFrame:
    """
    This class is used to create a DataFrame object.
    """

    def __init__(self, olap_engine=OlapEngineType.CLICKHOUSE):
        datasource = dict()
        PROJECT_CONF = get_context().project_conf
        for cell in PROJECT_CONF["datasource"]:
            datasource[cell["device_id"]] = cell
        self.device_id = None
        for device in datasource:
            if datasource[device].get(str(olap_engine) + "_database") is not None:
                self.device_id = device
                self.database = datasource[device][str(olap_engine) + "_database"]
                break
        if self.device_id is None:
            raise Exception(f"Unable to get any device of engine({olap_engine}).")
        self.data = DfPb.DataFrame()
        if olap_engine == OlapEngineType.STARROCKS:
            self.data.source.type = DfPb.SourceType.StarRocks
        self.url = (
            PROJECT_CONF["sqlgateway"]["url"]
            + PROJECT_CONF["sqlgateway"]["dataframe_path"]
        )
        self.rtx = get_user()
        self._ctx = DfContext(engine=olap_engine)
        self._select_list = []
        self._name_dict = {}
        self._need_agg = False

    def serialize(self):
        return self.data.SerializeToString()

    def serializeBase64(self):
        return base64.b64encode(self.serialize()).decode()

    def serializeJson(self):
        return json_format.MessageToJson(self.data)

    def deserialize(self, data):
        self.data.ParseFromString(data)

    def toPandas(self):
        """
        This function is used to convert the result of the dataframe to pandas.DataFrame
        """
        self._finalize()
        self.data.result = ""
        self.task_type = DfPb.TaskType.EXECUTE
        self.execute()
        res = outPutFormat(list(eval(self.data.result)))
        return res

    def _is_column(self, column, name):
        if name is None:
            return False
        if isinstance(column, DfColumnNode):
            return column.alias == name or column.sql(ctx=self._ctx) == name
        raise Exception(f"type({type(column)}) is not DfColumn")

    def _get_col_name(self, column):
        if isinstance(column, DfColumnLeafNode):
            return column.alias if column.alias else column.sql(ctx=self._ctx)
        if isinstance(column, DfColumnInternalNode):
            return column.alias
        raise Exception(
            f"type({type(column)}) is not DfColumnLeafNode|DfColumnInternalNode"
        )

    def printSchema(self):
        """
        This function is used to print the schema of the dataframe
        """
        rows = []
        for column in self._select_list:
            super_name = self._get_col_name(column)
            column_type = column.type
            # if the column type is empty, default to Float64
            # TODO: if the type is empty, it should be filled in beforehand
            if column_type is None or column_type == "":
                column_type = "Float64"
            rows.append({"name": super_name, "type": column_type})
        pd = pandas.DataFrame(rows, columns=["name", "type"])
        print(pd)

    def __getitem__(self, key):
        for column in self._select_list:
            if column.alias == key or column.sql(self._ctx) == key:
                return AisF.col(column)
        raise Exception("column %s not found" % key)

    def debug(self):
        self._finalize()
        print(self.data.__str__())
        all_names = dict()
        for name in self._name_dict:
            all_names[name] = self._name_dict[name].sql(self._ctx)
        print(all_names)

    def _finalize(self):
        del self.data.columns[:]
        for col in self.data.group_by:
            self.data.columns.append(col)
        for col in self._select_list:
            self.data.columns.append(
                DfPb.Column(name=col.sql(ctx=self._ctx), alias=col.alias)
            )

    def __str__(self):
        self._finalize()
        res = self.toPandas()
        if res.shape[0] == 1 and res.shape[1] == 1:
            return str(res.values[0][0])
        return res.__str__()

    def __repr__(self):
        return self.__str__()

    def first(self):
        """
        This function is used to get the first row of the dataframe

        >>> df.first()
        """
        new_df = copy.deepcopy(self)
        new_df.data.limit.limit = 1
        return new_df

    def head(self, n):
        """
        This function is used to get the first n rows of the dataframe

        >>> df.head(3)
        """
        new_df = copy.deepcopy(self)
        new_df.data.limit.limit = n
        return new_df

    def take(self, n):
        """
        This function is used to get the first n rows of the dataframe

        >>> df.take(3)
        """
        new_df = copy.deepcopy(self)
        new_df.data.limit.limit = n
        return new_df

    def show(self):
        """
        Prints the DataFrame, equivalent to print(dataframe).

        >>> df.head(3).show()
        """
        print(self.__str__())

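    # Note on evaluation: first/head/take only set data.limit.limit on a deep copy;
    # nothing is sent to the SQL gateway until an action such as show(), toPandas()
    # or __str__() triggers execute(). For example, df.where("x1 > 0").head(10).show()
    # builds the full query first and runs it exactly once.
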
    # TODO: add validation logic here.
    # The alias names are used here and need to be mapped back to the real columns;
    # the same applies to groupBy and orderBy below.
    def where(self, filter):
        """
        Filters rows using the given condition.

        >>> df.where("column1 > 1").show()
        """
        new_df = copy.deepcopy(self)
        if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", filter) is not None:
            check_exist = False
            for column in new_df._select_list:
                super_name = self._get_col_name(column)
                if super_name in filter:
                    check_exist = True
                    break
            if not check_exist:
                raise Exception("the column of filter `%s` does not exist" % filter)
        new_df.data.filters.append(filter)
        return new_df

    def filter(self, filter):
        """
        Alias for the 'where' function. Filters rows using the given condition.

        >>> df.filter("column1 > 1").show()
        """
        return self.where(filter)

    @staticmethod
    def _expand_args(*args):
        expand_args = []
        for arg in args:
            if isinstance(arg, str):
                expand_args.append(arg)
            elif isinstance(arg, list):
                for sub_arg in arg:
                    expand_args.append(sub_arg)
        return expand_args

    def select(self, *args):
        """
        Selects specified columns from the DataFrame and returns a new DataFrame.

        >>> new_df = df.select('column1', 'column2')
        >>> new_df = df.select(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        new_df = copy.deepcopy(self)
        new_df.checkColumn(*args)
        new_select_list = []
        for arg in args:
            new_select_list.append(new_df._expand_expr(arg))
        new_df._select_list = new_select_list
        return new_df

    def drop(self, *args):
        """
        Drops specified columns from the DataFrame and returns a new DataFrame.

        >>> new_df = df.drop('column1', 'column2')
        >>> new_df = df.drop(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        new_df = copy.deepcopy(self)
        new_df.checkColumn(*args)
        for i in range(len(new_df._select_list) - 1, -1, -1):
            name = self._get_col_name(new_df._select_list[i])
            is_exist = False
            for arg in args:
                if name == arg:
                    is_exist = True
                    break
            if is_exist:
                del new_df._select_list[i]
        return new_df

    def withColumn(self, new_column, func):
        """
        This function adds a new column to the DataFrame.

        Example
        --------

        .. code-block:: python

            import fast_causal_inference.dataframe.functions as Fn
            import fast_causal_inference.dataframe.statistics as S

            df1 = df.select('x1', 'x2', 'x3', 'numerator')

            # Method 1: Select columns through df's index: df['col']
            df1 = df1.withColumn('new_col', Fn.sqrt(df1['numerator']))
            df1.show()

            # Method 2: Select columns directly through string: Fn.col('new_col')
            df1 = df1.withColumn('new_col2', Fn.pow('new_col', 2))
            df1.show()
            df1 = df1.withColumn('new_col3', Fn.col('new_col') * Fn.col('new_col'))
            df1.show()

            # Add constant
            df2 = df1.withColumn('c1', Fn.lit(1))
            df2 = df2.withColumn('c2', Fn.lit('1'))
            df2.show()

            # Nesting
            df2 = df1.withColumn('c1', Fn.pow(Fn.sqrt(Fn.sqrt(Fn.col('x1'))), 2))
            df2.show()

            # +-*/% operations
            df2 = df1.withColumn('c1', 22 + df1['x1'] / 2 + 2 / df1['x2'] * df1['x3'] % 2 - (df1['x2']))
            df2.show()
            df2 = df1.withColumn('c1', Fn.col('x1 + x2 * x3 + x3'))
            df2.show()

            # if
            df2 = df1.withColumn('cc1', 'if(x1 > 0, 1, -1)')
            df2.show()
            df2 = df1.withColumn('cc1', Fn.If('x1 > 0', 1, -1))
            df2.show()

        """
        new_df = copy.deepcopy(self)
        if isinstance(func, str):
            func = new_df._add_space(func)
            for alias in new_df._name_dict:
                if alias is None or alias == "":
                    continue
                func = func.replace(
                    " " + alias + " ",
                    " " + new_df._name_dict[alias].sql(new_df._ctx) + " ",
                )
            func = func.replace(" ", "")
            new_df._select_list.append(DfColumnLeafNode(func, new_column))
            new_df._name_dict[new_column] = new_df._select_list[-1]
            new_df._name_dict[
                new_df._select_list[-1].sql(new_df._ctx)
            ] = new_df._select_list[-1]
        elif isinstance(func, DfFnColWrapper):
            new_df = new_df._apply_func(func.alias(new_column), keep_old_cols=True)
        return new_df

    @classmethod
    def _add_space(cls, text):
        for ch in [",", "+", "-", "*", "/", ")", "(", "=", ">", "<", "!"]:
            text = text.replace(ch, " " + ch + " ")
        return " " + text + " "

    def _expand_expr(self, expr):
        if not isinstance(expr, str):
            raise Exception(f"Logical Error: expr(`{expr}`) is expected to be str.")
        expr = self._add_space(expr)
        origin_expr = expr
        for alias in self._name_dict:
            if alias is None or alias == "":
                continue
            expr = expr.replace(
                " " + alias + " ", " " + self._name_dict[alias].sql(self._ctx) + " "
            )
        expr = expr.replace(" ", "")
        origin_expr = origin_expr.replace(" ", "")
        alias = origin_expr if origin_expr != expr else None
        self._name_dict[origin_expr] = DfColumnLeafNode(expr, alias=alias)
        self._name_dict[expr] = DfColumnLeafNode(expr, alias=alias)
        return DfColumnLeafNode(expr, alias=alias)

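    # Worked illustration of the expansion above (assuming a column previously
    # registered as `new_col` whose generated SQL is sqrt(numerator), as in the
    # withColumn docstring): expanding the string expression "new_col * 2" goes roughly
    #     _add_space("new_col * 2")          -> " new_col  *  2 "
    #     replace " new_col " with its SQL   -> " sqrt(numerator)  *  2 "
    #     strip all spaces                   -> "sqrt(numerator)*2"
    # The space-stripped original text "new_col*2" is kept as the alias, so the
    # user-facing name survives while the generated SQL uses the expanded expression.
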
    def withColumnRenamed(self, old_name, new_name):
        """
        Returns a new DataFrame by renaming an existing column.

        >>> df.withColumnRenamed("column1", "new_column1").show()
        """
        new_df = copy.deepcopy(self)
        new_df.checkColumn(old_name)
        for column in new_df._select_list:
            if column.alias == old_name or self._get_col_name(column) == old_name:
                column.alias = new_name
        return new_df

    def describe(self, cols="*"):
        """
        Returns the summary statistics for the columns in the DataFrame.

        >>> df.describe()
        >>> df.describe(['x1', 'x2'])
        #      count       avg       std       min  quantile_0.25  quantile_0.5  quantile_0.75  quantile_0.90  quantile_0.99      max
        # x1 10000.0 -0.018434  0.987606 -3.740101       -0.68929     -0.028665       0.654210       1.274144       2.321097  3.80166
        # x2 10000.0  0.021976  1.986209 -8.893264       -1.28461      0.015400       1.357618       2.583523       4.725829  7.19662
        """
        self = self.materializedView()
        table_name = self.getTableName()
        return AisTools.describe(table_name, cols)

    def sample(self, fraction):
        """
        This function samples a fraction of rows without replacement from the DataFrame.

        >>> df1 = df.sample(1000)
        >>> df1.count().show()
        >>> df2 = df.sample(0.5)
        >>> df2.count().show()
        """
        new_df = self.materializedView(is_temp=True, is_physical_table=True)
        row_num = int(new_df.count().__str__())
        if fraction <= 1 and fraction >= 0:
            pass
        elif fraction > 1:
            if fraction > row_num:
                raise Exception("fraction should be less than row number")
            # an absolute row count was passed in; convert it to a fraction
            fraction = fraction / row_num
        else:
            raise Exception("fraction should be in [0, 1] or > 1")
        new_df_name = DataFrame.createTableName()
        temp_df = new_df
        # rand() returns a uniform UInt32, so rand() / pow(2, 32) is uniform in [0, 1)
        temp_df = temp_df.where(f"rand() / pow(2, 32) < {fraction}")
        return temp_df.materializedView(is_physical_table=True)

    def split(self, test_size=0.5):
        """
        This function splits the DataFrame into two DataFrames.

        >>> df_train, df_test = df.split(0.5)
        >>> print(df_train.count())
        >>> print(df_test.count())
        """
        new_df = self.materializedView(is_temp=True)
        new_table_name = AisTools.data_split(new_df.getTableName(), test_size)
        return readClickHouse(new_table_name[0]), readClickHouse(new_table_name[1])

    def orderBy(self, *args):
        """
        Orders the DataFrame by the specified columns and returns a new DataFrame.

        >>> import fast_causal_inference.dataframe.functions as Fn
        >>> new_df = df.orderBy('column1', Fn.desc('column2'))
        >>> new_df = df.orderBy(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        new_df = copy.deepcopy(self)
        # clear self.data.order_by
        del new_df.data.order_by[:]
        for arg in args:
            if isinstance(arg, str):
                order = DfPb.Order()
                order.column.name = arg
                new_df.data.order_by.append(order)
            elif isinstance(arg, DfPb.Order):
                new_df.data.order_by.append(arg)
        return new_df

    def groupBy(self, *args):
        """
        Groups the DataFrame by the specified columns and returns a new DataFrame.
        Raises an exception if the DataFrame is already in the `need_agg` state.

        >>> new_df = df.groupBy('column1', 'column2')
        >>> new_df = df.groupBy(['column1', 'column2'])
        """
        args = DataFrame._expand_args(*args)
        if self._need_agg:
            raise Exception(
                "Dataframe is already in `need_agg` state, cannot apply group by."
            )
        new_df = self._transform_to_nested()
        new_df._need_agg = True
        for arg in args:
            if arg not in self._name_dict:
                raise Exception(f"Unable to find column named `{arg}`.")
            column = DfPb.Column()
            column.name = arg
            new_df.data.group_by.append(column)
        return new_df

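    # Typical grouped-aggregation flow (illustrative; 'treatment', 'numerator' and
    # 'denominator' are placeholder column names): groupBy() switches the DataFrame
    # into the `need_agg` state and the following agg() call resolves it, e.g.
    #     df.groupBy('treatment').agg({'numerator': 'avg', 'denominator': 'sum'}).show()
    # or, equivalently, with column wrappers from dataframe.functions (AisF here):
    #     df.groupBy('treatment').agg(AisF.avg('numerator'), AisF.sum('denominator')).show()
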
    def agg(self, *args):
        # Two calling conventions are supported:
        #   df.agg(avg("A"), sum("B").alias).show()
        #   df.agg({"A": "avg", "B": "sum"}).show()
        new_df = copy.deepcopy(self)
        if len(args) == 0:
            raise Exception("Nothing to agg.")
        if len(args) == 1 and isinstance(args[0], dict):
            args = list(map(lambda key: getattr(AisF, args[0][key])(key), args[0]))
        if not all(map(lambda arg: isinstance(arg, DfFnColWrapper), args)):
            raise Exception(f"{args} is not List[DfFnColWrapper].")
        new_df = new_df._apply_func(*args)
        return new_df

    # aggregate functions
    def sum(self, column):
        return self._apply_func(AisF.sum(column))

    def avg(self, column):
        return self._apply_func(AisF.avg(column))

    def count(self, *, expr="*"):
        return self._apply_func(AisF.count(expr=expr))

    def max(self, column):
        return self._apply_func(AisF.max(column))

    def min(self, column):
        return self._apply_func(AisF.min(column))

    def any(self, column):
        return self._apply_func(AisF.any(column))

    def stddevPop(self, column):
        return self._apply_func(AisF.stddevPop(column))

    def stddevSamp(self, column):
        return self._apply_func(AisF.stddevSamp(column))

    def varPop(self, column):
        return self._apply_func(AisF.varPop(column))

    def varSamp(self, column):
        return self._apply_func(AisF.varSamp(column))

    def corr(self, x, y):
        return self._apply_func(AisF.corr(x, y))

    def covarPop(self, x, y):
        return self._apply_func(AisF.covarPop(x, y))

    def covarSamp(self, x, y):
        return self._apply_func(AisF.covarSamp(x, y))

    def anyLast(self, x, y):
        return self._apply_func(AisF.anyLast(x, y))

    def anyMin(self, x, y):
        return self._apply_func(AisF.anyMin(x, y))

    def anyMax(self, x, y):
        return self._apply_func(AisF.anyMax(x, y))

    def kolmogorov_smirnov_test(self, x, y):
        return self._apply_func(AisStatisticsF.kolmogorov_smirnov_test(x, y))

    def student_ttest(self, x, y):
        return self._apply_func(AisStatisticsF.student_ttest(x, y))

    def welch_ttest(self, x, y):
        return self._apply_func(AisStatisticsF.welch_ttest(x, y))

    def mean_z_test(
        self,
        sample_data,
        sample_index,
        population_variance_x,
        population_variance_y,
        confidence_level,
    ):
        return self._apply_func(
            AisStatisticsF.mean_z_test(
                sample_data,
                sample_index,
                population_variance_x,
                population_variance_y,
                confidence_level,
            )
        )

    def quantile(self, x, level=None):
        if level is None:
            raise Exception("param `level` is not set")
        return self._apply_func(AisF.quantile(x, level=level))

    # all-in-sql functions
    def delta_method(self, column, std="True"):
        return self._apply_func(AisStatisticsF.delta_method(expr=column, std=std))

    def ttest_1samp(self, Y, alternative="two-sided", mu=0, X=""):
        return self._apply_func(AisStatisticsF.ttest_1samp(Y, alternative, mu, X))

    def ttest_2samp(self, Y, index, alternative="two-sided", X="", pse=""):
        return self._apply_func(
            AisStatisticsF.ttest_2samp(Y, index, alternative=alternative, X=X, pse=pse)
        )

    def xexpt_ttest_2samp(
        self,
        numerator,
        denominator,
        index,
        uin,
        metric_type="avg",
        group_buckets="[1:1]",
        alpha=0.05,
        MDE=0.005,
        power=0.8,
        X="",
    ):
        return self._apply_func(
            AisStatisticsF.xexpt_ttest_2samp(
                numerator,
                denominator,
                index,
                uin,
                metric_type,
                group_buckets,
                alpha,
                MDE,
                power,
                X,
            )
        )

    def mann_whitney_utest(
        self,
        sample_data,
        sample_index,
        alternative="two-sided",
        continuity_correction=1,
    ):
        return self._apply_func(
            AisStatisticsF.mann_whitney_utest(
                sample_data, sample_index, alternative, continuity_correction
            )
        )

    def srm(self, x, groupby, ratio="[1,1]"):
        return self._apply_func(AisStatisticsF.srm(x, groupby, ratio))

    def ols(self, column, use_bias=True):
        return self._apply_func(AisRegressionF.ols(column, use_bias))

    def wls(self, column, weight, use_bias=True):
        return self._apply_func(AisRegressionF.wls(column, weight, use_bias))

    def stochastic_linear_regression(
        self, expr, learning_rate=0.00001, l1=0.1, batch_size=15, method="SGD"
    ):
        return self._apply_func(
            AisRegressionF.stochastic_linear_regression(
                expr, learning_rate, l1, batch_size, method
            )
        )

    def stochastic_logistic_regression(
        self, expr, learning_rate=0.00001, l1=0.1, batch_size=15, method="SGD"
    ):
        return self._apply_func(
            AisRegressionF.stochastic_logistic_regression(
                expr, learning_rate, l1, batch_size, method
            )
        )

    def matrix_multiplication(self, *args, std=False, invert=False):
        return self._apply_func(
            AisStatisticsF.matrix_multiplication(*args, std=std, invert=invert)
        )

    def did(self, Y, treatment, time, *X):
        return self._apply_func(AisRegressionF.did(Y, treatment, time, *X))

    def iv_regression(self, formula):
        return self._apply_func(AisRegressionF.iv_regression(formula))

    def boot_strap(self, func, sample_num, bs_num):
        return self._apply_func(AisStatisticsF.boot_strap(func, sample_num, bs_num))

    def permutation(self, func, permutation_num, *col):
        return self._apply_func(
            AisStatisticsF.permutation(func, permutation_num, *col)
        )

    def checkColumn(self, *args, throw_exception=True):
        for arg in args:
            if arg not in self._name_dict:
                if throw_exception:
                    raise Exception("column %s not found" % arg)
                else:
                    return False
        return True

    # prefer the alias if one is set
    def getAliasOrName(self):
        alias_or_name = []
        for column in self._select_list:
            alias_or_name.append(self._get_col_name(column))
        return alias_or_name

    def _find_column(self, name):
        if self._name_dict.get(name) is not None:
            # print(f"find column `{name}`: {self._name_dict.get(name).sql(self._ctx)}")
            return self._name_dict.get(name)
        for column in self._select_list:
            col = column.find_column(name)
            if col is not None:
                # print(f"find column `{name}`: {col.sql(self._ctx)}")
                return col
        # print(f"cannot find column `{name}`, using self.")
        return DfColumnLeafNode(name)

    def _unwrap(self, fn_wrapper: DfFnColWrapper):
        if not isinstance(fn_wrapper, DfFnColWrapper):
            raise Exception(f"func({type(fn_wrapper)}) should be DfFnColWrapper!")
        fn, params, cols = fn_wrapper.fn, fn_wrapper.params, fn_wrapper.columns
        alias = fn.alias
        columns = []
        for col in cols:
            if isinstance(col, int) or isinstance(col, float) or isinstance(col, bool):
                from_col = DfColumnLeafNode(str(col))
            elif isinstance(col, str):
                from_col = copy.deepcopy(self._find_column(col))
                self._name_dict[col] = from_col
            elif isinstance(col, DfColumnNode):
                from_col = copy.deepcopy(col)
                if col.alias is not None:
                    self._name_dict[col.alias] = from_col
            elif isinstance(col, DfFnColWrapper):
                from_col = self._unwrap(col)
                if from_col.alias is not None:
                    self._name_dict[from_col.alias] = from_col
            else:
                raise Exception(
                    f"type of col({type(col)}) is neither str nor DfColumn."
                )
            if not isinstance(from_col, DfColumnNode):
                # we can only apply function on DfColumn but no others
                raise Exception(f"Logical Error: {type(from_col)} is not DfColumn.")
            self._name_dict[from_col.sql(self._ctx)] = from_col
            columns.append(from_col)
        new_col = DfColumnInternalNode(fn, params, columns, alias)
        if alias:
            self._name_dict[alias] = new_col
        self._name_dict[new_col.sql(self._ctx)] = new_col
        return new_col

    def _apply_func(self, *fn_wrappers: Tuple[DfFnColWrapper], keep_old_cols=False):
        if any(map(lambda fn_wrapper: fn_wrapper.has_agg_func(), fn_wrappers)):
            if not self._need_agg:
                new_df = self._transform_to_nested()
            else:
                new_df = copy.deepcopy(self)
            new_df._need_agg = False
        else:
            new_df = copy.deepcopy(self)
        if not keep_old_cols:
            new_df._select_list = []
        for fn_wrapper in fn_wrappers:
            if not isinstance(fn_wrapper, DfFnColWrapper):
                raise Exception("fn_wrapper should be a DfFnColWrapper object!")
            new_col: DfColumnNode = new_df._unwrap(fn_wrapper)
            new_df._select_list.append(new_col)
        return new_df

    def _set_cte(self, cte):
        self.data.cte = cte

    def _transform_to_nested(self):
        new_df = copy.deepcopy(self)
        subquery = new_df.getExecutedSql()
        del new_df.data.columns[:]
        del new_df.data.filters[:]
        del new_df.data.group_by[:]
        del new_df.data.order_by[:]
        new_df.data.limit.limit = 0
        del new_df._select_list[:]
        new_df.data.cte = ""
        for col_name in new_df._name_dict:
            new_df._name_dict[col_name] = DfColumnLeafNode(col_name)
        if new_df.data.source.type == DfPb.SourceType.ClickHouse:
            new_df.data.source.clickhouse.table_name = subquery
            new_df.data.source.clickhouse.database = "Nested"
        elif new_df.data.source.type == DfPb.SourceType.StarRocks:
            new_df.data.source.starrocks.table_name = subquery
            new_df.data.source.starrocks.database = "Nested"
        else:
            raise Exception("not support source type")
        return new_df

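    # How nesting works, in short: _transform_to_nested wraps the already-built query
    # as the new data source -- the current SQL becomes source.<engine>.table_name and
    # the database is set to the sentinel "Nested" -- so the next round of
    # select/filter/aggregate is generated as an outer query around the previous one.
    # This is what lets groupBy/_apply_func stack aggregations on top of arbitrary
    # upstream transformations.
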
    def union(self, df):
        """
        This function is used to union two DataFrames.
        The two DataFrames must have the same number of columns, and the columns must
        have the same names and order.

        >>> df1 = df1.union(df2)
        """
        if len(self._select_list) != len(df._select_list):
            raise Exception("the length of the select lists does not match")
        for i in range(len(self._select_list)):
            if self._select_list[i].alias != df._select_list[i].alias:
                raise Exception("the column names of the select lists do not match")
            if self._select_list[i].alias is None and self._select_list[i].sql(
                self._ctx
            ) != df._select_list[i].sql(df._ctx):
                raise Exception("the column names of the select lists do not match")
        new_df = copy.deepcopy(self)
        subquery1 = self.getExecutedSql()
        subquery2 = df.getExecutedSql()
        if new_df.data.source.type == DfPb.SourceType.ClickHouse:
            new_df.data.source.clickhouse.table_name = (
                f"({subquery1}) union all ({subquery2})"
            )
            new_df.data.source.clickhouse.database = "Nested"
        elif new_df.data.source.type == DfPb.SourceType.StarRocks:
            new_df.data.source.starrocks.table_name = (
                f"({subquery1}) union all ({subquery2})"
            )
            new_df.data.source.starrocks.database = "Nested"
        new_df._select_list = []
        for col in self._select_list:
            new_df._select_list.append(new_df._expand_expr(self._get_col_name(col)))
        return new_df

    # parse the executed sql and strip the trailing limit clause
    def getExecutedSql(self):
        self.__str__()
        sql = self.data.execute_sql
        # print(sql)
        sql = re.sub(r"limit\s+\d+\s*$", "", sql, flags=re.IGNORECASE)
        return sql

    # is_temp = True: the table is dropped automatically at 2 a.m. every day
    def materializedView(
        self,
        is_physical_table=False,
        is_distributed_create=True,
        is_temp=False,
        table_name=None,
    ):
        materialized_sql = self.getExecutedSql()
        if table_name is not None:
            new_df_name = table_name
        else:
            new_df_name = DataFrame.createTableName(is_temp)
        ClickHouseUtils.clickhouse_create_view_v2(
            table_name=new_df_name,
            select_statement=materialized_sql,
            origin_table_name=self.getTableName(),
            is_physical_table=is_physical_table,
            is_distributed_create=is_distributed_create,
        )
        return readClickHouse(new_df_name)

    def createDfRequest(self):
        df_req = DfPb.DataFrameRequest()
        df_req.df.CopyFrom(self.data)
        df_req.task_type = self.task_type
        df_req.rtx = self.rtx
        df_req.device_id = self.device_id
        df_req.database = self.database
        return df_req

    def createDataFrameRequestBase64(self):
        df_req = self.createDfRequest()
        return base64.b64encode(df_req.SerializeToString()).decode()

    def createDfRequestJson(self):
        df_req = self.createDfRequest()
        return json_format.MessageToJson(df_req)

    def createCurlRequest(self):
        return (
            'curl -H "Content-Type: application/json" -X POST -d \''
            + self.createDfRequestJson()
            + "' "
            + self.url
            + "/json"
        )

    def execute(self, retry_times=1):
        logger = get_context().logger
        while retry_times > 0:
            try:
                json_body = self.createDfRequestJson()
                logger.debug("url= " + self.url + ",data= " + json_body)
                resp = requests.post(
                    self.url,
                    data=json_body.encode("utf-8"),
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                    },
                )
                logger.debug("response=" + resp.text)
                df_resp = Parse(resp.text, DfPb.DataFrameResponse())
                if df_resp.status == DfPb.RetStatus.FAIL:
                    print("Error: ")
                    print(df_resp.msg)
                elif df_resp.status == DfPb.RetStatus.SUCC:
                    self.data = df_resp.df
                    if not self._select_list:
                        self._select_list = [
                            DfColumnLeafNode(
                                column_name=col.name, alias=col.alias, type=col.type
                            )
                            for col in df_resp.df.columns
                        ]
                    for col in self._select_list:
                        self._name_dict[col.sql(self._ctx)] = col
                return
            except Exception as e:
                time.sleep(1)
                retry_times -= 1
                if retry_times == 0:
                    raise e

    @staticmethod
    def createTableName(is_temp=False):
        table_name = "df_table_"
        if is_temp:
            table_name += "temp_"
        table_name += time.strftime("%Y%m%d%H%M%S", time.localtime()) + "_"
        table_name += str(random.randint(100000, 999999))
        return table_name

    def fill_column_info(self):
        self.task_type = DfPb.TaskType.FILL_SCHEMA
        self.execute()

    # if there are nested subqueries, return the innermost table name
    def getTableName(self):
        return self.data.source.clickhouse.table_name

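    # Request/response flow of execute(), in short: the DataFrame proto is serialized
    # to JSON (createDfRequestJson), POSTed to the SQL gateway, and the returned
    # DataFrameResponse.df replaces self.data, carrying the generated SQL
    # (data.execute_sql, used by getExecutedSql) and, for EXECUTE tasks, the result
    # payload consumed by toPandas(). createCurlRequest() reproduces the same call as
    # a shell command, which is handy for debugging.
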
    def toCsv(self, csv_file_abs_path):
        """
        Convert the data from ClickHouse table to a CSV file.

        >>> df.toCsv("/path/to/output.csv")
        """
        clickhouse_table_name = self.data.source.clickhouse.table_name
        clickhouse_2_csv(clickhouse_table_name, csv_file_abs_path)

    def toTdw(
        self,
        tdw_database,
        tdw_table,
        tdw_user=None,
        tdw_passward=None,
        group="tl",
        is_drop_table=False,
        overwrite=True,
        priPart=None,
    ):
        """
        ClickHouse table >> TDW-thive table.

        Parameters
        ----------

        :tdw_database (str): The name of the TDW database.
        :tdw_table (str): The name of the TDW table.
        :tdw_user (str, optional): The username for TDW. Default is None.
        :tdw_passward (str, optional): The password for TDW. Default is None.
        :group (str, optional): The group for TDW. Default is 'tl'.
        :is_drop_table (bool, optional): Whether to drop the existing TDW table. Default is False.
        :overwrite (bool, optional): Whether to overwrite the existing data in the TDW table. Default is True.
        :priPart (str, optional): The primary partition for the TDW table. Default is None.

        Example
        ----------

        >>> df.toTdw("tdw_database", "tdw_table", group='tl')
        >>> df.toTdw("tdw_database", "tdw_table", group='tl', priPart=['p_20220222'])
        """
        view = self.materializedView()
        from fast_causal_inference.util.clickhouse_utils import ClickHouseUtils

        ClickHouseUtils.clickhouse_2_tdw_v2(
            view.getTableName(),
            tdw_database,
            tdw_table,
            tdw_user,
            tdw_passward,
            group,
            is_drop_table,
            overwrite,
            priPart,
        )

    def toSparkDf(self):
        """
        ClickHouse table >> spark dataframe.

        Example
        ----------

        >>> import fast_causal_inference
        >>> fast_causal_inference.set_default(tenant_id="", tenant_secret_key="")
        >>> spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        >>> df = fast_causal_inference.readClickHouse('test_data_small')
        >>> spark_df = df.toSparkDf()
        """
        session = getSparkSession()
        if self._ctx.engine() == OlapEngineType.CLICKHOUSE:
            return clickhouse_2_dataframe(
                session, self.data.source.clickhouse.table_name
            )
        elif self._ctx.engine() == OlapEngineType.STARROCKS:
            return starrocks_2_dataframe(
                session, self.data.source.starrocks.table_name
            )
        else:
            raise Exception(f"Olap engine `{self._ctx.engine()}` not supported.")

    def toClickHouse(self, clickhouse_table_name):
        """
        ClickHouse table >> ClickHouse table.

        Example
        ----------

        >>> df.toClickHouse("new_table")
        """
        self.materializedView(is_physical_table=True, table_name=clickhouse_table_name)


def getSparkSession():
    from fast_causal_inference import get_context

    if get_context().spark_session is None:
        print("Spark Session is None, please init it first")
        return
    return get_context().spark_session


# data import / export (in and out of the warehouse)
def readClickHouse(table_name):
    """
    Read data from a ClickHouse table into a DataFrame.

    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readClickHouse("test_data_small")
    """
    df = DataFrame()
    df.data.source.clickhouse.table_name = table_name
    df.data.source.clickhouse.database = df.database
    df.fill_column_info()
    return df


def readStarRocks(table_name):
    """
    Read data from a StarRocks table into a DataFrame.

    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readStarRocks("test_data_small")
    """
    df = DataFrame(olap_engine=OlapEngineType.STARROCKS)
    df.data.source.starrocks.table_name = table_name
    df.data.source.starrocks.database = df.database
    df.fill_column_info()
    return df


def readTdw(
    db,
    table,
    group="tl",
    tdw_user=None,
    tdw_passward=None,
    priParts=None,
    str_replace="-1",
    numeric_replace=0,
    olap="clickhouse",
):
    """
    Read data from a TDW-thive table into a DataFrame.

    Parameters
    ----------

    :param db: The name of the TDW database.
    :param table: The name of the TDW table.
    :param group: The group for TDW. Default is 'tl'.
    :param tdw_user: The username for TDW. Default is None.
    :param tdw_passward: The password for TDW. Default is None.
    :param str_replace: The replacement for string NA values. Default is "-1".
    :param numeric_replace: The replacement for numeric NA values. Default is 0.

    Example
    -------

    ::

        # thive, option 1: tdw thive > clickhouse
        import fast_causal_inference
        fast_causal_inference.set_default(tenant_id="", tenant_secret_key="")
        spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        df = fast_causal_inference.readTdw("db_name", "table_name", group="tl")  # read a regular table
        df = fast_causal_inference.readTdw("db_name", "table_name", group="tl", priPart=['p_20220222'])  # read a partitioned table

        # thive, option 2: tdw thive > spark > clickhouse
        import fast_causal_inference
        fast_causal_inference.set_default(tenant_id="", tenant_secret_key="")
        spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
        from pytoolkit import TDWSQLProvider
        tdw = TDWSQLProvider(spark, group="tl", db="db_name")
        spark_df = tdw.table(tblName="allinsql_test_data_small")  # read a regular table
        spark_df = tdw.table(tblName="allinsql_test_data_small", priPart=['p_20220222'])  # read a partitioned table
        spark_df.count()
        df_ch = fast_causal_inference.readSparkDf(spark_df)

    """
    session = get_context().spark_session
    from pytoolkit import TDWSQLProvider

    tdw = TDWSQLProvider(
        session, group=group, db=db, user=tdw_user, passwd=tdw_passward
    )
    df_new = tdw.table(tblName=table, priParts=priParts)
    df_new = AisTools.preprocess_na(df_new, str_replace, numeric_replace)
    df = DataFrame()
    table_name = DataFrame.createTableName()
    if olap.lower() == "clickhouse":
        dataframe_2_clickhouse(dataframe=df_new, clickhouse_table_name=table_name)
        return readClickHouse(table_name)
    elif olap.lower() == "starrocks":
        dataframe_2_starrocks(dataframe=df_new, starrocks_table_name=table_name)
        return readStarRocks(table_name)
    else:
        raise Exception(f"Olap engine `{olap}` not supported.")


def readSparkDf(dataframe, olap="clickhouse"):
    """
    Read data from a Spark DataFrame into a DataFrame.

    >>> import fast_causal_inference
    >>> fast_causal_inference.set_default(tenant_id="", tenant_secret_key="")
    >>> spark = fast_causal_inference.set_spark_session(group_id='', gaia_id='')  # for group_id and gaia_id, see the default notebook file "!Spark 资源池.html"
    >>> df = fast_causal_inference.readSparkDf(spark_df)
    """
    df = DataFrame()
    table_name = DataFrame.createTableName()
    if olap.lower() == "clickhouse":
        dataframe_2_clickhouse(dataframe=dataframe, clickhouse_table_name=table_name)
        return readClickHouse(table_name)
    elif olap.lower() == "starrocks":
        dataframe_2_starrocks(dataframe=dataframe, starrocks_table_name=table_name)
        return readStarRocks(table_name)
    else:
        raise Exception(f"Olap engine `{olap}` not supported.")


def readCsv(csv_file_abs_path, olap="clickhouse"):
    """
    Read data from a CSV file into a DataFrame.

    >>> import fast_causal_inference
    >>> df = fast_causal_inference.readCsv("/path/to/file.csv")
    """
    df = DataFrame()
    table_name = DataFrame.createTableName()
    if olap.lower() == "clickhouse":
        csv_2_clickhouse(
            csv_file_abs_path=csv_file_abs_path, clickhouse_table_name=table_name
        )
        return readClickHouse(table_name)
    elif olap.lower() == "starrocks":
        csv_2_starrocks(
            csv_file_abs_path=csv_file_abs_path, starrocks_table_name=table_name
        )
        return readStarRocks(table_name)
    else:
        raise Exception(f"Olap engine `{olap}` not supported.")
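

# Minimal end-to-end sketch (illustrative only; assumes a reachable SQL gateway
# configured in project_conf and an existing ClickHouse table named `test_data_small`,
# with placeholder column names x1, x2, treatment, numerator -- adjust to your data):
if __name__ == "__main__":
    df = readClickHouse("test_data_small")
    df.printSchema()
    df.where("x1 > 0").select("x1", "x2").head(5).show()
    print(df.groupBy("treatment").agg({"numerator": "avg"}))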