提交 8acad27e 编写于 作者: T typhoonzero

refine code

上级 4b91cb52
......@@ -75,8 +75,8 @@ class ListenAndServOp : public framework::OperatorBase {
server_thread_->join();
}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
......@@ -101,7 +101,6 @@ class ListenAndServOp : public framework::OperatorBase {
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
size_t update_param_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
......@@ -128,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
if (exit_flag) {
rpc_service_->ShutDown();
}
VLOG(3) << "run optimize graph...";
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear();
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size());
sparse_vars.clear();
} // while(true)
}
......@@ -158,7 +154,6 @@ class ListenAndServOp : public framework::OperatorBase {
protected:
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
std::shared_ptr<std::thread> server_thread_;
mutable std::unordered_map<std::string, int> grads_counter_;
};
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -32,8 +32,8 @@ class RecvOp : public framework::OperatorBase {
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
......
......@@ -48,8 +48,8 @@ class SendOp : public framework::OperatorBase {
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册