import fast_causal_inference.lib.tools as ais_tools
from fast_causal_inference.dataframe import readClickHouse
from fast_causal_inference.dataframe.functions import (
    DfFnColWrapper,
    register_fn,
    define_args,
    FnArg,
    DfFunction,
    OlapEngineType,
)
class OneHotEncoder:
    """
    One-hot encode categorical columns of a dataframe.

    The encoding itself is delegated to ``ais_tools.onehot``. This class
    materializes the input dataframe into a temporary table, runs the
    encoding against that table and reads the resulting table back.

    The encoder holds no configuration; the columns to encode are passed to
    :meth:`fit`.

    Methods
    -------
    fit(df, cols):
        One-hot encode the columns ``cols`` of ``df`` and return the encoded
        dataframe.

    Example
    -------
    .. code-block:: python

        import fast_causal_inference
        import fast_causal_inference.dataframe.features as Features
        df = fast_causal_inference.readClickHouse('test_data_small')
        one_hot = Features.OneHotEncoder()
        df_new = one_hot.fit(df, cols=['x_cat1'])
        df_new.printSchema()
    """

    def __init__(self):
        # Stateless: fit() reads no instance attributes.
        pass

    def fit(self, df, cols):
        """
        One-hot encode ``cols`` of ``df``.

        :param df: The input dataframe. It is first materialized with
            ``materializedView(is_temp=True)`` so that the encoding runs
            against a concrete table name.
        :param cols: The columns to be one-hot encoded, e.g. ``['x_cat1']``.
        :type cols: list
        :return: A dataframe read with ``readClickHouse`` from the first
            element of the value returned by ``ais_tools.onehot``.
        """
        new_df = df.materializedView(is_temp=True)
        new_table_names = ais_tools.onehot(new_df.getTableName(), cols)
        # NOTE(review): presumably ais_tools.onehot returns a sequence of
        # table names with the encoded table first -- confirm.
        return readClickHouse(new_table_names[0])
from fast_causal_inference.dataframe.functions import (
    DfFnColWrapper,
    register_fn,
    define_args,
    FnArg,
    DfFunction,
    OlapEngineType,
)
@register_fn(engine=OlapEngineType.CLICKHOUSE, name="cutbins")
@register_fn(engine=OlapEngineType.STARROCKS, name="cutbins")
@define_args(
    FnArg(name="column"),
    FnArg(name="bins"),
    FnArg(name="if_string", default="True"),
)
class CutbinsDfFunction(DfFunction):
    """
    Function descriptor for the ``cutbins`` engine function.

    Registered under the name ``cutbins`` for both the ClickHouse and the
    StarRocks engines. It declares three arguments: ``column``, ``bins`` and
    ``if_string``, whose declared default is the string ``"True"``.
    The class body adds no members of its own beyond what ``DfFunction``
    and the decorators provide.
    """
def cut_bins(column, bins, if_string=True):
    """
    Build a ``cutbins`` column expression that discretizes ``column``.

    :param column: Name of the column to bucketize.
    :param bins: The split points. A ``str`` is passed through unchanged
        (e.g. ``"[1,3]"``). A ``list`` or ``tuple`` is rendered as a
        bracketed, comma-separated literal with no spaces, e.g.
        ``[1, 3]`` -> ``"[1,3]"``.
    :param if_string: Forwarded unchanged as the third ``cutbins`` argument
        (whether bucket labels are emitted as strings).
    :return: A ``DfFnColWrapper`` around ``CutbinsDfFunction`` with the
        arguments ``[column, bins_str, if_string]``.
    :raises ValueError: If ``bins`` is not a str, list or tuple.
    """
    if isinstance(bins, str):
        bins_str = bins
    elif isinstance(bins, (list, tuple)):
        # Any ordered sequence of split points renders the same way.
        bins_str = "[" + ",".join(str(x) for x in bins) + "]"
    else:
        raise ValueError(f"bins({bins}) must be a str, list or tuple")
    return DfFnColWrapper(CutbinsDfFunction(), {}, [column, bins_str, if_string])
class Bucketizer:
    """
    This class is used for bucketizing continuous variables into discrete bins.
    """

    def __init__(self):
        # Stateless: fit() reads no instance attributes.
        pass

    def fit(self, df, inputCols, splitsArray, outputCols=None, if_string=True):
        """
        Apply the bucketizing transformation to the specified columns of ``df``.

        Each input column gets one new column, added with ``df.withColumn``.
        The input columns are processed in order.

        :param df: The input dataframe to be transformed.
        :type df: DataFrame
        :param inputCols: A list of column names in the dataframe to be bucketized.
        :type inputCols: list
        :param splitsArray: A list of lists. The i-th inner list contains the
            split points for bucketizing ``inputCols[i]``.
        :type splitsArray: list
        :param outputCols: A list of output column names, one per input column.
            If ``None`` or empty, ``'_buckets'`` is appended to each input
            column name.
        :type outputCols: list, optional
        :param if_string: A flag indicating whether the bin values should be
            treated as strings. Default is True.
        :type if_string: bool, optional
        :return: The transformed dataframe with bucketized columns.
        :rtype: DataFrame
        :raises ValueError: If ``splitsArray`` or a non-empty ``outputCols``
            has fewer entries than ``inputCols``. Extra trailing entries are
            ignored.

        Example
        -------
        .. code-block:: python

            >>> import fast_causal_inference
            >>> import fast_causal_inference.dataframe.features as Features
            >>> df = fast_causal_inference.readClickHouse('test_data_small')
            >>> bucketizer = Features.Bucketizer()
            >>> df_new = bucketizer.fit(df,['x1','x2'],[[1,3],[0,2]],if_string=True)
            >>> df_new.select('x1','x2','x1_buckets','x2_buckets').head(5).show()
                                x1            x2 x1_buckets x2_buckets
            0  -0.131301907  -3.152383354          1          0
            1  -0.966931088  -0.427920835          1          0
            2   1.257744217  -2.050358546      [1,3)          0
            3  -0.777228042  -2.621604715          1          0
            4  -0.669571385   0.606404768          1      [0,2)
            >>> df_new = bucketizer.fit(df,['x1','x2'],[[1,3],[0,2]],if_string=False)
            >>> df_new.select('x1','x2','x1_buckets','x2_buckets').head(5).show()
                        x1            x2 x1_buckets x2_buckets
            0  -0.131301907  -3.152383354          1          1
            1  -0.966931088  -0.427920835          1          1
            2   1.257744217  -2.050358546          2          1
            3  -0.777228042  -2.621604715          1          1
            4  -0.669571385   0.606404768          1          2
        """
        # `None` replaces the former mutable `[]` default; an explicitly
        # passed empty list still selects the derived names.
        if not outputCols:
            outputCols = [col + "_buckets" for col in inputCols]
        # Validate up front so a short list fails clearly instead of raising
        # an IndexError part-way through adding columns.
        if len(splitsArray) < len(inputCols):
            raise ValueError(
                f"splitsArray has {len(splitsArray)} entries but inputCols has "
                f"{len(inputCols)}; one list of split points is required per input column"
            )
        if len(outputCols) < len(inputCols):
            raise ValueError(
                f"outputCols has {len(outputCols)} entries but inputCols has "
                f"{len(inputCols)}; one output name is required per input column"
            )
        for input_col, splits, output_col in zip(inputCols, splitsArray, outputCols):
            df = df.withColumn(output_col, cut_bins(input_col, splits, if_string))
        return df