提交 8d1dae46 编写于 作者: M ms_yan

Throw error when load config failed

上级 19e66f06
......@@ -276,7 +276,7 @@ void bindTensor(py::module *m) {
.def("get_op_connector_size", &ConfigManager::op_connector_size)
.def("get_seed", &ConfigManager::seed)
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
.def(py::init([](py::array arr) {
......
......@@ -48,7 +48,7 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
Status ConfigManager::LoadFile(const std::string &settingsFile) {
Status rc;
if (!Path(settingsFile).Exists()) {
RETURN_STATUS_UNEXPECTED("File is not found");
RETURN_STATUS_UNEXPECTED("File is not found.");
}
// Some settings are mandatory, others are not (with default). If a setting
// is optional it will set a default value if the config is missing from the file.
......@@ -59,14 +59,11 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) {
rc = FromJson(js);
} catch (const nlohmann::json::type_error &e) {
std::ostringstream ss;
ss << "Client settings failed to load:\n" << e.what();
ss << "Client file failed to load:\n" << e.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
} catch (const std::exception &err) {
std::ostringstream ss;
ss << "Client settings failed to load:\n" << err.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
RETURN_STATUS_UNEXPECTED("Client file failed to load.");
}
return rc;
}
......
......@@ -390,7 +390,6 @@ def filter_func_Partial_0(col1, col2, col3, col4):
# test with row_data_buffer > 1
def test_filter_by_generator_Partial0():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
dataset_zip = ds.zip((dataset1, dataset2))
......@@ -404,7 +403,6 @@ def test_filter_by_generator_Partial0():
# test with row_data_buffer > 1
def test_filter_by_generator_Partial1():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
dataset_zip = ds.zip((dataset1, dataset2))
......@@ -419,7 +417,6 @@ def test_filter_by_generator_Partial1():
# test with row_data_buffer > 1
def test_filter_by_generator_Partial2():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
......@@ -454,7 +451,6 @@ def generator_big(maxid=20):
# test with row_data_buffer > 1
def test_filter_by_generator_Partial():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"])
dataset_s = dataset.shuffle(4)
dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1)
......@@ -473,7 +469,6 @@ def filter_func_cifar(col1, col2):
# test with cifar10
def test_filte_case_dataset_cifar10():
DATA_DIR_10 = "../data/dataset/testCifar10Data"
ds.config.load('../data/dataset/declient_filter.cfg')
dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
for item in dataset_f1.create_dict_iterator():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册