提交 540b4535 编写于 作者: Y Yancey1989

use req_count as atomic type

上级 6debbcd9
...@@ -298,16 +298,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, ...@@ -298,16 +298,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id); CreateOpHandleIOs(result, op, dev_id);
} }
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const { SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -107,12 +107,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -107,12 +107,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient( bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const; const std::string &og) const;
......
...@@ -59,7 +59,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -59,7 +59,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
}); });
req_count_++; req_count_++;
return true; return true;
......
...@@ -197,7 +197,7 @@ class RPCClient { ...@@ -197,7 +197,7 @@ class RPCClient {
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_; std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0; std::atomic<int64_t> req_count_{0};
std::mutex mutex_; std::mutex mutex_;
}; };
......
...@@ -403,7 +403,6 @@ class DistributeTranspiler: ...@@ -403,7 +403,6 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]}, outputs={"Out": [orig_param]},
attrs={"axis": 0}) attrs={"axis": 0})
# TODO(Yancey1989): check dist lookup table
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist) eplist)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册