提交 8e3fe2d7 编写于 作者: H heqiaozhi

add skip op

上级 86e1044a
...@@ -95,8 +95,12 @@ void AsyncExecutor::InitParamConfig() { ...@@ -95,8 +95,12 @@ void AsyncExecutor::InitParamConfig() {
} }
} }
_param_config.slot_dim = _param_config.fea_dim - 2; //TODO _param_config.slot_dim = _param_config.fea_dim - 2; //TODO
_param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().pull_dense_per_batch()); _param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch()); _param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) {
_param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
//sparse //sparse
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) { for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t); auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
......
...@@ -345,12 +345,17 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() { ...@@ -345,12 +345,17 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() {
if (op->Type().find("sgd") != std::string::npos) { if (op->Type().find("sgd") != std::string::npos) {
continue; continue;
} }
if (op->Type().find("lookup_table") != std::string::npos || bool need_skip = false;
op->Type().find("lookup_table_grad") != std::string::npos) { for (auto t = 0u; t < _param_config->skip_op.size(); ++t) {
continue; if (op->Type().find(_param_config->skip_op[t]) != std::string::npos) {
need_skip = true;
break;
}
} }
if (!need_skip) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
}
UpdateParams(); UpdateParams();
} }
......
...@@ -40,6 +40,8 @@ struct AsyncWorkerParamConfig { ...@@ -40,6 +40,8 @@ struct AsyncWorkerParamConfig {
int32_t tmp_push_dense_wait_times; int32_t tmp_push_dense_wait_times;
int32_t tmp_push_sparse_wait_times; int32_t tmp_push_sparse_wait_times;
std::vector<std::string> skip_op;
std::map<uint64_t, std::vector<std::string>> dense_variable_name; std::map<uint64_t, std::vector<std::string>> dense_variable_name;
std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name; std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
std::vector<int> dense_table_id; std::vector<int> dense_table_id;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册