提交 540b4535 编写于 作者: Y Yancey1989

use req_count as atomic type

上级 6debbcd9
......@@ -298,16 +298,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id);
}
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
......
......@@ -107,12 +107,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const;
......
......@@ -59,7 +59,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
});
req_count_++;
return true;
......
......@@ -197,7 +197,7 @@ class RPCClient {
private:
grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0;
std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
};
......
......@@ -403,7 +403,6 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]},
attrs={"axis": 0})
# TODO(Yancey1989): check dist lookup table
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册