diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 7685883dd5e5e42a9eb9aa749ecc1e3c322d9413..f96ff436da997e2035f02b4982c5acd770b024e3 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -95,8 +95,12 @@ void AsyncExecutor::InitParamConfig() { } } _param_config.slot_dim = _param_config.fea_dim - 2; //TODO - _param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().pull_dense_per_batch()); - _param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch()); + _param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch()); + _param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch()); + + for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) { + _param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t)); + } //sparse for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) { auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t); diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc index e0ee9c11c909bd5f760cf37fb508a57cb6aad1f2..d8320b422b800d48c4dd562b438ef81b297b4ec5 100644 --- a/paddle/fluid/framework/executor_thread_worker.cc +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -340,16 +340,21 @@ void AsyncExecutorThreadWorker::SetPullDenseThread(std::shared_ptrType().find("sgd") != std::string::npos) { continue; } - if (op->Type().find("lookup_table") != std::string::npos || - op->Type().find("lookup_table_grad") != std::string::npos) { - continue; + bool need_skip = false; + for (auto t = 0u; t < _param_config->skip_op.size(); ++t) { + if (op->Type().find(_param_config->skip_op[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + op->Run(*thread_scope_, place_); } - op->Run(*thread_scope_, place_); } UpdateParams(); } diff --git a/paddle/fluid/framework/executor_thread_worker.h b/paddle/fluid/framework/executor_thread_worker.h index 4e3255a590cce336070b8385c455b6bd18d1a220..b3ee9dfaec9953f11360eba30dc98fafb8076b79 100644 --- a/paddle/fluid/framework/executor_thread_worker.h +++ b/paddle/fluid/framework/executor_thread_worker.h @@ -39,6 +39,8 @@ struct AsyncWorkerParamConfig { int fea_dim; int32_t tmp_push_dense_wait_times; int32_t tmp_push_sparse_wait_times; + + std::vector skip_op; std::map> dense_variable_name; std::map> dense_gradient_variable_name;