From 8e3fe2d7355c09a3dde09bcbf63971ff3bfe169d Mon Sep 17 00:00:00 2001
From: heqiaozhi
Date: Mon, 10 Dec 2018 18:57:57 +0800
Subject: [PATCH] add skip op

---
 paddle/fluid/framework/async_executor.cc         |  8 ++++++--
 paddle/fluid/framework/executor_thread_worker.cc | 15 ++++++++++-----
 paddle/fluid/framework/executor_thread_worker.h  |  2 ++
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
index 7685883dd..f96ff436d 100644
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -95,8 +95,12 @@ void AsyncExecutor::InitParamConfig() {
     }
   }
   _param_config.slot_dim = _param_config.fea_dim - 2; //TODO
-  _param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().pull_dense_per_batch());
-  _param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
+  _param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
+  _param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
+
+  for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) {
+    _param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t));
+  }
   //sparse
   for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
     auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc
index e0ee9c11c..d8320b422 100644
--- a/paddle/fluid/framework/executor_thread_worker.cc
+++ b/paddle/fluid/framework/executor_thread_worker.cc
@@ -340,16 +340,21 @@ void AsyncExecutorThreadWorker::SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {
 
 void AsyncExecutorThreadWorker::TrainOneNetwork() {
   PrepareParams();
 
   for (auto& op : ops_) {
     if (op->Type().find("sgd") != std::string::npos) {
       continue;
     }
-    if (op->Type().find("lookup_table") != std::string::npos ||
-        op->Type().find("lookup_table_grad") != std::string::npos) {
-      continue;
+    bool need_skip = false;
+    for (auto t = 0u; t < _param_config->skip_op.size(); ++t) {
+      if (op->Type().find(_param_config->skip_op[t]) != std::string::npos) {
+        need_skip = true;
+        break;
+      }
+    }
+    if (!need_skip) {
+      op->Run(*thread_scope_, place_);
     }
-    op->Run(*thread_scope_, place_);
   }
   UpdateParams();
 }
diff --git a/paddle/fluid/framework/executor_thread_worker.h b/paddle/fluid/framework/executor_thread_worker.h
index 4e3255a59..b3ee9dfae 100644
--- a/paddle/fluid/framework/executor_thread_worker.h
+++ b/paddle/fluid/framework/executor_thread_worker.h
@@ -39,6 +39,8 @@ struct AsyncWorkerParamConfig {
   int fea_dim;
   int32_t tmp_push_dense_wait_times;
   int32_t tmp_push_sparse_wait_times;
+
+  std::vector<std::string> skip_op;
 
   std::map<uint64_t, std::vector<std::string>> dense_variable_name;
   std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
-- 
GitLab