提交 7dc7990d 编写于 作者: X xiefangqi

add NumWorkers validate

上级 d00f7d8f
......@@ -309,6 +309,19 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
#if !defined(_WIN32) && !defined(_WIN64)
#ifndef ENABLE_ANDROID
int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
if (cpu_count < 0 || cpu_count > INT32_MAX) {
MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
return nullptr;
}
if (num_workers < 1 || num_workers > cpu_count) {
MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
return nullptr;
}
#endif
#endif
num_workers_ = num_workers;
return shared_from_this();
}
......@@ -336,7 +349,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
/// can be set to default, which corresponds to 0/total_words separately
/// \param[in] top_k Number of words to be built into vocab. top_k most frequent words are
// taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
/// taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
/// \param[in] special_tokens A list of strings, each one is a special token
/// \param[in] special_first Whether special_tokens will be prepended/appended to vocab, If special_tokens
/// is specified and special_first is set to default, special_tokens will be prepended
......
......@@ -555,7 +555,7 @@ def check_map(method):
callbacks], _ = \
parse_user_args(method, *args, **kwargs)
nreq_param_columns = ['input_columns', 'output_columns']
nreq_param_columns = ['input_columns', 'output_columns', 'columns_order']
if columns_order is not None:
type_check(columns_order, (list,), "columns_order")
......@@ -571,7 +571,7 @@ def check_map(method):
else:
type_check(callbacks, (callback.DSCallback,), "callbacks")
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, columns_order]):
if param is not None:
check_columns(param, param_name)
if callbacks is not None:
......
......@@ -950,3 +950,25 @@ TEST_F(MindDataTestPipeline, TestZipSuccess2) {
// Manually terminate the pipeline
iter->Stop();
}
#if !defined(_WIN32) && !defined(_WIN64)
#ifndef ENABLE_ANDROID
TEST_F(MindDataTestPipeline, TestNumWorkersValidate) {
// Testing the static zip() function
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNumWorkersValidate.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 9));
EXPECT_NE(ds, nullptr);
// test if set num_workers=-1
std::shared_ptr<Dataset> ds1 = ds->SetNumWorkers(-1);
EXPECT_EQ(ds1, nullptr);
// test if set num_workers>cpu_count
std::shared_ptr<Dataset> ds2 = ds->SetNumWorkers(UINT32_MAX);
EXPECT_EQ(ds2, nullptr);
}
#endif
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册