From 8d1dae46ac0b526b247c0c0ad835a2737924f941 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Wed, 17 Jun 2020 23:28:01 +0800 Subject: [PATCH] Throw error when load config failed --- mindspore/ccsrc/dataset/api/python_bindings.cc | 2 +- mindspore/ccsrc/dataset/core/config_manager.cc | 9 +++------ tests/ut/data/dataset/declient_filter.cfg | 3 --- tests/ut/python/dataset/test_filterop.py | 5 ----- 4 files changed, 4 insertions(+), 15 deletions(-) delete mode 100644 tests/ut/data/dataset/declient_filter.cfg diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 6ed24e13a..c0add5a89 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -276,7 +276,7 @@ void bindTensor(py::module *m) { .def("get_op_connector_size", &ConfigManager::op_connector_size) .def("get_seed", &ConfigManager::seed) .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) - .def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); }); + .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); (void)py::class_>(*m, "Tensor", py::buffer_protocol()) .def(py::init([](py::array arr) { diff --git a/mindspore/ccsrc/dataset/core/config_manager.cc b/mindspore/ccsrc/dataset/core/config_manager.cc index 8732a4c4b..a489b4a4c 100644 --- a/mindspore/ccsrc/dataset/core/config_manager.cc +++ b/mindspore/ccsrc/dataset/core/config_manager.cc @@ -48,7 +48,7 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { Status ConfigManager::LoadFile(const std::string &settingsFile) { Status rc; if (!Path(settingsFile).Exists()) { - RETURN_STATUS_UNEXPECTED("File is not found"); + RETURN_STATUS_UNEXPECTED("File is not found."); } // Some settings are mandatory, others are not (with default). If a setting // is optional it will set a default value if the config is missing from the file. @@ -59,14 +59,11 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) { rc = FromJson(js); } catch (const nlohmann::json::type_error &e) { std::ostringstream ss; - ss << "Client settings failed to load:\n" << e.what(); + ss << "Client file failed to load:\n" << e.what(); std::string err_msg = ss.str(); RETURN_STATUS_UNEXPECTED(err_msg); } catch (const std::exception &err) { - std::ostringstream ss; - ss << "Client settings failed to load:\n" << err.what(); - std::string err_msg = ss.str(); - RETURN_STATUS_UNEXPECTED(err_msg); + RETURN_STATUS_UNEXPECTED("Client file failed to load."); } return rc; } diff --git a/tests/ut/data/dataset/declient_filter.cfg b/tests/ut/data/dataset/declient_filter.cfg deleted file mode 100644 index 89e1199f5..000000000 --- a/tests/ut/data/dataset/declient_filter.cfg +++ /dev/null @@ -1,3 +0,0 @@ -{ - "rowsPerBuffer": 10, -} diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py index 25fde151e..015d58037 100644 --- a/tests/ut/python/dataset/test_filterop.py +++ b/tests/ut/python/dataset/test_filterop.py @@ -390,7 +390,6 @@ def filter_func_Partial_0(col1, col2, col3, col4): # test with row_data_buffer > 1 def test_filter_by_generator_Partial0(): - ds.config.load('../data/dataset/declient_filter.cfg') dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"]) dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"]) dataset_zip = ds.zip((dataset1, dataset2)) @@ -404,7 +403,6 @@ def test_filter_by_generator_Partial0(): # test with row_data_buffer > 1 def test_filter_by_generator_Partial1(): - ds.config.load('../data/dataset/declient_filter.cfg') dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"]) dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"]) dataset_zip = ds.zip((dataset1, dataset2)) @@ -419,7 +417,6 @@ def test_filter_by_generator_Partial1(): # test with row_data_buffer > 1 def test_filter_by_generator_Partial2(): - ds.config.load('../data/dataset/declient_filter.cfg') dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"]) dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"]) @@ -454,7 +451,6 @@ def generator_big(maxid=20): # test with row_data_buffer > 1 def test_filter_by_generator_Partial(): - ds.config.load('../data/dataset/declient_filter.cfg') dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"]) dataset_s = dataset.shuffle(4) dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) @@ -473,7 +469,6 @@ def filter_func_cifar(col1, col2): # test with cifar10 def test_filte_case_dataset_cifar10(): DATA_DIR_10 = "../data/dataset/testCifar10Data" - ds.config.load('../data/dataset/declient_filter.cfg') dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False) dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1) for item in dataset_f1.create_dict_iterator(): -- GitLab