提交 63bd38bd 编写于 作者: Q qiaolongfei

code optimize

上级 63bf82dd
......@@ -63,7 +63,9 @@ class VariableResponse {
// other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer);
framework::Scope& GetLocalScope() const { return *local_scope_; }
const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
......
......@@ -207,18 +207,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::BlockDesc *prefetch_block) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
for (auto &grad_and_id : grad_to_id_str) {
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
for (auto &grad_and_id : grad_to_block_id_str) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_id[pieces[0]] = block_id;
grad_to_block_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
......@@ -232,9 +233,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
auto optimize_prepared = executor->Prepare(*program, block_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared;
grad_to_prepared_block;
for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
grad_to_prepared_block[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
VLOG(3) << "RunAsyncLoop into while";
......@@ -253,8 +254,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
&(v.second->GetLocalScope()));
AsyncExecuteBlock(executor, grad_to_prepared_block[recv_var_name].get(),
v.second->GetMutableLocalScope());
// TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
......@@ -328,7 +329,7 @@ from send_op and send back variables to recv_op.
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_to_id",
"grad_to_block_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
......
......@@ -137,7 +137,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_id", std::vector<std::string>({""})});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
......
......@@ -185,6 +185,9 @@ class DistributeTranspiler:
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert (callable(split_method))
if program is None:
......@@ -479,7 +482,7 @@ class DistributeTranspiler:
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode,
"grad_to_id": grad_to_block_id
"grad_to_block_id": grad_to_block_id
})
pserver_program.sync_with_cpp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册