提交 8e3fe2d7 编写于 作者: H heqiaozhi

add skip op

上级 86e1044a
......@@ -95,8 +95,12 @@ void AsyncExecutor::InitParamConfig() {
}
}
_param_config.slot_dim = _param_config.fea_dim - 2; //TODO
_param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().pull_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) {
_param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
//sparse
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
......
......@@ -345,12 +345,17 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() {
if (op->Type().find("sgd") != std::string::npos) {
continue;
}
if (op->Type().find("lookup_table") != std::string::npos ||
op->Type().find("lookup_table_grad") != std::string::npos) {
continue;
bool need_skip = false;
for (auto t = 0u; t < _param_config->skip_op.size(); ++t) {
if (op->Type().find(_param_config->skip_op[t]) != std::string::npos) {
need_skip = true;
break;
}
}
if (!need_skip) {
op->Run(*thread_scope_, place_);
}
}
UpdateParams();
}
......
......@@ -40,6 +40,8 @@ struct AsyncWorkerParamConfig {
int32_t tmp_push_dense_wait_times;
int32_t tmp_push_sparse_wait_times;
std::vector<std::string> skip_op;
std::map<uint64_t, std::vector<std::string>> dense_variable_name;
std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
std::vector<int> dense_table_id;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册