提交 4abef501 编写于 作者: C chengduoZH

code refine

上级 2aaa75ec
......@@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {}
void BroadcastOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
// the input and output may have dummy var.
std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_);
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one.");
// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
// Wait input done, this Wait is asynchronous operationplatform::Place
// &in_place;
WaitEvents(out_var_handles, in_var_handle);
//
auto in_place = in_var_handle[0]->place_;
auto in_scope_idx = in_var_handle[0]->scope_idx_;
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
......@@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() {
}
}
void BroadcastOpHandle::WaitEvents(
const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle) {
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
}
std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
return in_var_handle;
}
std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details
} // namespace framework
......
......@@ -41,6 +41,12 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs);
void WaitEvents(const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle);
};
} // namespace details
......
......@@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs_) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
// the input and output may have dummy var.
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
......@@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() {
"The place of input and output should be the same.");
// Wait input done, this Wait is asynchronous operation
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
WaitEvents(in_var_handles);
std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors;
......@@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() {
// copy
auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] {
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
......@@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() {
});
}
void GatherOpHandle::WaitEvents(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
}
std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
return in_var_handles;
}
std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details
} // namespace framework
......
......@@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &);
void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
};
} // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册