Skip to content

DataManager

BaseDataManager

Bases: ABC

Base data manager for loading and saving data.

Source code in utu/eval/data/data_manager.py (lines 16–37)
class BaseDataManager(abc.ABC):
    """Base data manager for loading and saving data.

    Concrete subclasses decide where samples come from and where they are persisted,
    and are expected to populate ``data`` (e.g. from ``load``).
    """

    # Samples currently held by the manager; not assigned here, subclasses set it.
    data: list[EvaluationSample]

    def __init__(self, config: EvalConfig) -> None:
        # Evaluation config; subclasses read dataset / exp_id settings from it.
        self.config = config

    @abc.abstractmethod
    def load(self) -> list[EvaluationSample]:
        """Load the dataset and return its samples."""
        raise NotImplementedError

    @abc.abstractmethod
    def save(self, **kwargs) -> None:
        """Save the dataset.

        The persistence target and accepted arguments are subclass-specific.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_samples(self, stage: Literal["init", "rollout", "judged"] | None = None) -> list[EvaluationSample]:
        """Get samples of the specified stage from the dataset.

        Args:
            stage: Stage to filter on, or ``None``. How ``None`` is interpreted is up to
                the subclass.
        """
        raise NotImplementedError

load abstractmethod

load() -> list[EvaluationSample]

Load the dataset.

Source code in utu/eval/data/data_manager.py (lines 24–27)
@abc.abstractmethod
def load(self) -> list[EvaluationSample]:
    """Read the dataset and return the samples it contains.

    Concrete data managers must override this.
    """
    raise NotImplementedError()

save abstractmethod

save(**kwargs) -> None

Save the dataset.

Source code in utu/eval/data/data_manager.py (lines 29–32)
@abc.abstractmethod
def save(self, **kwargs) -> None:
    """Persist the dataset.

    Where the data goes and which keyword arguments are accepted is left to
    concrete data managers.
    """
    raise NotImplementedError()

get_samples abstractmethod

get_samples(
    stage: Literal["init", "rollout", "judged"] = None,
) -> list[EvaluationSample]

Get samples of the specified stage from the dataset.

Source code in utu/eval/data/data_manager.py (lines 34–37)
@abc.abstractmethod
def get_samples(self, stage: Literal["init", "rollout", "judged"] | None = None) -> list[EvaluationSample]:
    """Get samples of the specified stage from the dataset.

    Args:
        stage: Stage to filter on, or ``None``. How ``None`` is interpreted is up to
            the concrete data manager.
    """
    raise NotImplementedError

FileDataManager

Bases: BaseDataManager

File data manager for loading and saving data.

Source code in utu/eval/data/data_manager.py (lines 40–79)
class FileDataManager(BaseDataManager):
    """File data manager: reads samples from a ``.jsonl`` file and writes them back out.

    Each non-blank line of the input file is one JSON object.
    """

    def load(self) -> list[EvaluationSample]:
        """Load raw data from the specified dataset.

        Every non-blank line of the resolved ``.jsonl`` file is parsed with ``json.loads``
        and turned into an ``EvaluationSample`` tagged with ``self.config.exp_id``.
        Missing fields fall back to the defaults below.

        Returns:
            list[EvaluationSample]: The samples in file order (also stored in ``self.data``).
        """
        data_path = self._get_data_path()
        samples = []
        with open(data_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing empty line at EOF); json.loads("")
                    # would raise JSONDecodeError.
                    continue
                data = json.loads(line)
                sample = EvaluationSample(
                    source=data.get("source", self.config.data.dataset),
                    raw_question=data.get(self.config.data.question_field, ""),
                    level=data.get("level", 0),  # if applicable
                    correct_answer=data.get(self.config.data.gt_field, ""),
                    file_name=data.get("file name", ""),  # for GAIA
                    exp_id=self.config.exp_id,  # add exp_id
                )
                samples.append(sample)
        self.data = samples
        return samples

    def _get_data_path(self) -> pathlib.Path:
        """Resolve the ``.jsonl`` file to read.

        When ``config.data.type == "single"`` and ``config.data.dataset`` is a key of
        ``BUILTIN_BENCHMARKS``, that entry's ``data_path`` is used. Otherwise
        ``config.data.dataset`` itself is treated as a file path.

        Raises:
            AssertionError: If the file does not exist or does not end with ``.jsonl``.
        """
        if self.config.data.type == "single" and self.config.data.dataset in BUILTIN_BENCHMARKS:
            data_path = pathlib.Path(BUILTIN_BENCHMARKS[self.config.data.dataset]["data_path"])
        else:
            data_path = pathlib.Path(self.config.data.dataset)
        # NOTE(review): these checks use ``assert`` and are stripped under ``python -O``.
        assert data_path.exists(), f"Data file {data_path} does not exist."
        assert str(data_path).endswith(".jsonl"), f"Only support .jsonl files, but got {data_path}."
        return data_path

    def get_samples(self, stage: Literal["init", "rollout", "judged"] | None = None) -> list[EvaluationSample]:
        """Return the loaded samples, optionally filtered by stage.

        Args:
            stage: Stage to keep. When falsy (the default ``None``), all loaded samples
                are returned. This matches ``DBDataManager.get_samples``.

        Returns:
            list[EvaluationSample]: A new list; ``self.data`` itself is not returned.
        """
        if not stage:
            return list(self.data)
        return [d for d in self.data if d.stage == stage]

    def save(self, ofn: str) -> None:
        """Write ``self.data`` to ``ofn`` as JSON Lines, one ``sample.as_dict()`` per line."""
        with open(ofn, "w", encoding="utf-8") as f:
            for sample in self.data:
                f.write(json.dumps(sample.as_dict()) + "\n")

load

load() -> list[EvaluationSample]

Load raw data from the specified dataset.

Source code in utu/eval/data/data_manager.py (lines 43–62)
def load(self) -> list[EvaluationSample]:
    """Load raw data from the specified dataset.

    Every non-blank line of the resolved ``.jsonl`` file is parsed with ``json.loads``
    and turned into an ``EvaluationSample`` tagged with ``self.config.exp_id``.

    Returns:
        list[EvaluationSample]: The samples in file order (also stored in ``self.data``).
    """
    data_path = self._get_data_path()
    samples = []
    with open(data_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing empty line at EOF); json.loads("")
                # would raise JSONDecodeError.
                continue
            data = json.loads(line)
            sample = EvaluationSample(
                source=data.get("source", self.config.data.dataset),
                raw_question=data.get(self.config.data.question_field, ""),
                level=data.get("level", 0),  # if applicable
                correct_answer=data.get(self.config.data.gt_field, ""),
                file_name=data.get("file name", ""),  # for GAIA
                exp_id=self.config.exp_id,  # add exp_id
            )
            samples.append(sample)
    self.data = samples
    return samples

DBDataManager

Bases: FileDataManager

Database data manager for loading and saving data.

Source code in utu/eval/data/data_manager.py (lines 82–162)
class DBDataManager(FileDataManager):
    """Database data manager for loading and saving data.

    Source questions are read from ``DatasetSample`` rows. Per-experiment state is
    stored as ``EvaluationSample`` rows tagged with ``config.exp_id``.
    """

    def __init__(self, config: EvalConfig) -> None:
        super().__init__(config)

    def load(self) -> list[EvaluationSample]:
        """Load this experiment's samples, creating them from the dataset if needed.

        If rows with ``config.exp_id`` already exist, they are fetched and returned.
        Otherwise every ``DatasetSample`` of ``config.data.dataset`` is copied into a new
        ``EvaluationSample`` tagged with ``exp_id``, and the new samples are saved to the db.

        Returns:
            list[EvaluationSample]: The experiment's samples; also stored in ``self.data``.
        """
        if self._check_exp_id():
            logger.warning(f"exp_id {self.config.exp_id} already exists in db")
            # Keep self.data populated on this path too, as the fresh-load path does.
            self.data = self.get_samples()
            return self.data

        with SQLModelUtils.create_session() as session:
            datapoints = session.exec(
                select(DatasetSample).where(DatasetSample.dataset == self.config.data.dataset)
            ).all()
            logger.info(f"Loaded {len(datapoints)} samples from {self.config.data.dataset}.")
            samples = [
                EvaluationSample(
                    dataset=dp.dataset,
                    dataset_index=dp.index,
                    source=dp.source,
                    raw_question=dp.question,
                    level=dp.level,
                    correct_answer=dp.answer,
                    file_name=dp.file_name,
                    meta=dp.meta,
                    exp_id=self.config.exp_id,  # add exp_id
                )
                for dp in datapoints
            ]

            self.data = samples
            self.save(self.data)  # save to db
            return self.data

    def get_samples(
        self, stage: Literal["init", "rollout", "judged"] | None = None, limit: int | None = None
    ) -> list[EvaluationSample]:
        """Get samples of the current exp_id, optionally filtered by stage.

        Args:
            stage: Stage to filter on. When falsy (the default ``None``), samples of every
                stage are returned.
            limit: Maximum number of rows to return; ``None`` means no limit.

        Returns:
            list[EvaluationSample]: Matching samples ordered by ``dataset_index``.
        """
        conditions = [EvaluationSample.exp_id == self.config.exp_id]
        if stage:
            conditions.append(EvaluationSample.stage == stage)
        with SQLModelUtils.create_session() as session:
            samples = session.exec(
                select(EvaluationSample)
                .where(*conditions)
                .order_by(EvaluationSample.dataset_index)
                .limit(limit)
            ).all()
            return samples

    def save(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
        """Update or add sample(s) to db, committing them in one session."""
        batch = samples if isinstance(samples, list) else [samples]
        with SQLModelUtils.create_session() as session:
            session.add_all(batch)
            session.commit()

    def delete_samples(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
        """Delete sample(s) from db, committing the deletions in one session."""
        batch = samples if isinstance(samples, list) else [samples]
        with SQLModelUtils.create_session() as session:
            for sample in batch:
                session.delete(sample)
            session.commit()

    def _check_exp_id(self) -> bool:
        """Return True if any ``EvaluationSample`` row already has ``config.exp_id``."""
        with SQLModelUtils.create_session() as session:
            has_exp_id = session.exec(
                select(EvaluationSample).where(EvaluationSample.exp_id == self.config.exp_id)
            ).first()
        return has_exp_id is not None

get_samples

get_samples(
    stage: Literal["init", "rollout", "judged"] = None,
    limit: int = None,
) -> list[EvaluationSample]

Get samples belonging to the current exp_id, optionally filtered by the specified stage.

Source code in utu/eval/data/data_manager.py (lines 117–131)
def get_samples(
    self, stage: Literal["init", "rollout", "judged"] | None = None, limit: int | None = None
) -> list[EvaluationSample]:
    """Get samples of the current exp_id, optionally filtered by stage.

    Args:
        stage: Stage to filter on. When falsy (the default ``None``), samples of every
            stage are returned.
        limit: Maximum number of rows to return; ``None`` means no limit.

    Returns:
        list[EvaluationSample]: Matching samples ordered by ``dataset_index``.
    """
    conditions = [EvaluationSample.exp_id == self.config.exp_id]
    if stage:
        conditions.append(EvaluationSample.stage == stage)
    with SQLModelUtils.create_session() as session:
        samples = session.exec(
            select(EvaluationSample)
            .where(*conditions)
            .order_by(EvaluationSample.dataset_index)
            .limit(limit)
        ).all()
        return samples

save

save(
    samples: list[EvaluationSample] | EvaluationSample,
) -> None

Update or add sample(s) to db.

Source code in utu/eval/data/data_manager.py (lines 133–142)
def save(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
    """Insert or update one sample or a list of samples in the db.

    A single sample is wrapped in a one-element list so both shapes share the same
    session/commit path.
    """
    to_persist = samples if isinstance(samples, list) else [samples]
    with SQLModelUtils.create_session() as session:
        session.add_all(to_persist)
        session.commit()

delete_samples

delete_samples(
    samples: list[EvaluationSample] | EvaluationSample,
) -> None

Delete sample(s) from db.

Source code in utu/eval/data/data_manager.py (lines 144–154)
def delete_samples(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
    """Remove one sample or a list of samples from the db.

    All deletions are issued in one session and committed together.
    """
    doomed = samples if isinstance(samples, list) else [samples]
    with SQLModelUtils.create_session() as session:
        for item in doomed:
            session.delete(item)
        session.commit()