提交 63bd38bd 编写于 作者: Q qiaolongfei

code optimize

上级 63bf82dd
...@@ -63,7 +63,9 @@ class VariableResponse { ...@@ -63,7 +63,9 @@ class VariableResponse {
// other: number of error field. // other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer); int Parse(const ::grpc::ByteBuffer& byte_buffer);
framework::Scope& GetLocalScope() const { return *local_scope_; } const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); } inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); } inline std::string OutVarname() { return meta_.out_varname(); }
......
...@@ -207,18 +207,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -207,18 +207,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::BlockDesc *prefetch_block) const { framework::BlockDesc *prefetch_block) const {
VLOG(3) << "RunAsyncLoop in"; VLOG(3) << "RunAsyncLoop in";
// grad name to block id // grad name to block id
std::unordered_map<std::string, int32_t> grad_to_id; std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id"); auto grad_to_block_id_str =
for (auto &grad_and_id : grad_to_id_str) { Attr<std::vector<std::string>>("grad_to_block_id");
for (auto &grad_and_id : grad_to_block_id_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces); split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0); PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
grad_to_id[pieces[0]] = block_id; grad_to_block_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0]; id_to_grad[block_id] = pieces[0];
} }
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
...@@ -232,9 +233,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -232,9 +233,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
auto optimize_prepared = executor->Prepare(*program, block_list); auto optimize_prepared = executor->Prepare(*program, block_list);
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared; grad_to_prepared_block;
for (size_t i = 0; i < block_list.size(); ++i) { for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i]; grad_to_prepared_block[id_to_grad[block_list[i]]] = optimize_prepared[i];
} }
VLOG(3) << "RunAsyncLoop into while"; VLOG(3) << "RunAsyncLoop into while";
...@@ -253,8 +254,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -253,8 +254,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
LOG(ERROR) << "Can not find server side var: " << recv_var_name; LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var"); PADDLE_THROW("Can not find server side var");
} }
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(), AsyncExecuteBlock(executor, grad_to_prepared_block[recv_var_name].get(),
&(v.second->GetLocalScope())); v.second->GetMutableLocalScope());
// TODO(qiao): explain why // TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
...@@ -328,7 +329,7 @@ from send_op and send back variables to recv_op. ...@@ -328,7 +329,7 @@ from send_op and send back variables to recv_op.
.SetDefault("127.0.0.1:6164") .SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"grad_to_id", "grad_to_block_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
......
...@@ -137,7 +137,7 @@ void StartServerNet(bool is_sparse) { ...@@ -137,7 +137,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_id", std::vector<std::string>({""})}); attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true}); attrs.insert({"sync_mode", true});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
......
...@@ -185,6 +185,9 @@ class DistributeTranspiler: ...@@ -185,6 +185,9 @@ class DistributeTranspiler:
:param split_method: A function to determin how to split variables :param split_method: A function to determin how to split variables
to different servers equally. to different servers equally.
:type split_method: function :type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
""" """
assert (callable(split_method)) assert (callable(split_method))
if program is None: if program is None:
...@@ -479,7 +482,7 @@ class DistributeTranspiler: ...@@ -479,7 +482,7 @@ class DistributeTranspiler:
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block, "PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_id": grad_to_block_id "grad_to_block_id": grad_to_block_id
}) })
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册