未验证 提交 209075a4 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Add consistency inspection of use_var_list and data_generator...

[CPU-PSLIB] Add consistency inspection of use_var_list and data_generator data, test=develop (#34463)
上级 8967a66a
...@@ -255,6 +255,71 @@ class DatasetBase(object): ...@@ -255,6 +255,71 @@ class DatasetBase(object):
def _dynamic_adjust_after_train(self): def _dynamic_adjust_after_train(self):
pass pass
def _check_use_var_with_data_generator(self, var_list, data_generator_class,
                                       test_file):
    """
    Var consistency inspection of use_var_list and data_generator data.

    Every line of ``test_file`` is passed to
    ``data_generator_class.generate_sample`` and each parsed sample it
    yields is compared against ``var_list``: the number of slots must
    match, no slot may be empty, and the Python type of every value must
    agree with the dtype of the variable at the same position.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle
            from dataset_generator import CTRDataset
            dataset = paddle.distributed.fleet.DatasetBase()
            generator_class = CTRDataset()
            dataset._check_use_var_with_data_generator([data, label], generator_class, "data/part-00000")

    Args:
        var_list(list): variable list; each item exposes a ``dtype``.
        data_generator_class(class): data_generator object whose
            ``generate_sample(line)`` returns a callable that yields
            parsed samples, each a sequence of ``(name, values)`` items.
        test_file(str): local test file path

    Raises:
        ValueError: if a parsed sample does not have as many slots as
            ``var_list``, or if a slot holds no values.
        TypeError: if an FP32 variable receives a non-float value, or an
            INT64/INT32 variable receives a non-int value.
    """
    var_len = len(var_list)
    # The context manager guarantees the file is closed even when one of
    # the checks below raises; a trailing f.close() would be skipped.
    with open(test_file, "r") as f:
        for line in f:
            line_iter = data_generator_class.generate_sample(line)
            for user_parsed_line in line_iter():
                data_gen_len = len(user_parsed_line)
                if var_len != data_gen_len:
                    raise ValueError(
                        "var length mismatch error: var_list = %s vs data_generator = %s"
                        % (var_len, data_gen_len))
                for i, ele in enumerate(user_parsed_line):
                    name = ele[0]
                    values = ele[1]
                    # len() rather than truthiness: values may be a
                    # container whose truth value is not well defined.
                    if len(values) == 0:
                        raise ValueError(
                            "var length error: var %s's length in data_generator is 0"
                            % name)
                    dtype = var_list[i].dtype
                    if dtype == core.VarDesc.VarType.FP32 and not all(
                            isinstance(value, float) for value in values):
                        raise TypeError(
                            "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-float value, which is %s \n"
                            "Please check if order of var_list and data_generator are aligned. \n"
                            "Please check if var's type in data_generator is correct."
                            % (name, "float", values))
                    if (dtype == core.VarDesc.VarType.INT64 or
                            dtype == core.VarDesc.VarType.INT32
                        ) and not all(
                            isinstance(value, int) for value in values):
                        raise TypeError(
                            "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-int value, which is %s \n"
                            "Please check if order of var_list and data_generator are aligned. \n"
                            "Please check if var's type in data_generator is correct."
                            % (name, "int", values))
class InMemoryDataset(DatasetBase): class InMemoryDataset(DatasetBase):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册