提交 4abef501 编写于 作者: C chengduoZH

code refine

上级 2aaa75ec
...@@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes, ...@@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {} : local_scopes_(local_scopes), places_(places) {}
void BroadcastOpHandle::RunImpl() { void BroadcastOpHandle::RunImpl() {
// the input may have dummy var. // the input and output may have dummy var.
std::vector<VarHandle *> in_var_handle; std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_);
for (auto *in : inputs_) { std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1, PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one."); "The number of input should be one.");
// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(), out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places."); "The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operationplatform::Place
auto &in_place = in_var_handle[0]->place_; // &in_place;
if (in_var_handle[0]->generated_op_) { WaitEvents(out_var_handles, in_var_handle);
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
// auto in_place = in_var_handle[0]->place_;
auto in_scope_idx = in_var_handle[0]->scope_idx_; auto in_scope_idx = in_var_handle[0]->scope_idx_;
auto in_var = auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_); local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
...@@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() { ...@@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() {
} }
} }
void BroadcastOpHandle::WaitEvents(
const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle) {
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
}
std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
return in_var_handle;
}
std::string BroadcastOpHandle::Name() const { return "broadcast"; } std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -41,6 +41,12 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -41,6 +41,12 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs);
void WaitEvents(const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle);
}; };
} // namespace details } // namespace details
......
...@@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes, ...@@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {} : local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() { void GatherOpHandle::RunImpl() {
// the input may have dummy var. // the input and output may have dummy var.
std::vector<VarHandle *> in_var_handles; std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
for (auto *in : inputs_) { std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(), in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places."); "The number of output should equal to the number of places.");
// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one."); "The number of output should be one.");
...@@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() { ...@@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() {
"The place of input and output should be the same."); "The place of input and output should be the same.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
for (auto *in : in_var_handles) { WaitEvents(in_var_handles);
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors; std::vector<Tensor> in_tensors;
...@@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() { ...@@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() {
// copy // copy
auto dev_ctx = dev_ctxes_[out_place]; auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] { RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
int s = 0, e = 0; int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) { for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0]; e += in_tensors[j].dims()[0];
...@@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() { ...@@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() {
}); });
} }
void GatherOpHandle::WaitEvents(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
}
std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
return in_var_handles;
}
std::string GatherOpHandle::Name() const { return "gather"; } std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase { ...@@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &);
void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册