Commit 881e063e authored by chengduoZH

follow comments

Parent ff599b92
...
@@ -53,42 +53,39 @@ void BroadcastOpHandle::RunImpl() {
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  // NOTE(zcd): the Place of input can get from in_tensor and in_var_handle,
-  // maybe they are different, because the Place that getting from in_tensor is
-  // determined at runtime, the other is determined at building SSA graph stage.
-  // If they are different, DataTransform should be applied. Currently, it has
-  // not been done yet.
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
   for (auto *out_var_handle : out_var_handles) {
-    if (*out_var_handle == *in_var_handle) {
+    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
       continue;
     }
-    auto &out_p = out_var_handle->place_;
+    auto t_out_p = out_var_handle->place_;
     auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
                         ->FindVar(out_var_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(out_var);
-    PADDLE_ENFORCE_EQ(
-        out_p.which(), in_tensor.place().which(),
-        "Currently, Places of input and output must be all on CPU "
-        "or all on GPU.");
+    if (platform::is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
                                                             in_tensor.type());
   }
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
-      if (*out_var_handle == *in_var_handle) {
+      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
         continue;
       }
       auto &out_p = out_var_handle->place_;
-      auto dev_ctx = dev_ctxes_.at(out_p);
       auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
                           ->FindVar(out_var_handle->name_);
-      RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
+      RunAndRecordEvent(out_p, [in_tensor, out_var] {
         paddle::framework::TensorCopy(
-            in_tensor, out_p, *dev_ctx,
+            in_tensor, platform::CPUPlace(),
             &VariableVisitor::GetMutableTensor(out_var));
       });
     }
...
@@ -134,8 +131,8 @@ void BroadcastOpHandle::RunImpl() {
         call();
       }
     }
 
-    // TODO(zcd): Maybe the unequal operator is not appropriate here.
-    if (*out_handle != *in_var_handle) {
+    if (!out_handle->IsTheSameVar(*in_var_handle)) {
       auto out_var = var_scopes.at(in_var_handle->scope_idx_)
                          ->FindVar(out_var_handles[0]->name_);
       paddle::framework::TensorCopy(
...
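The substantive change in BroadcastOpHandle::RunImpl is the place check: instead of comparing variant indices via place().which(), the input tensor's place now drives the decision, and a CPU input forces the output place to CPUPlace() even if the SSA graph recorded a different place at build time. Below is a minimal standalone sketch of that rule, with a simplified std::variant-based Place standing in for paddle::platform::Place; the ResolveOutputPlace helper is illustrative and not part of the commit.

```cpp
#include <cassert>
#include <variant>

struct CPUPlace {};
struct GPUPlace { int device = 0; };
using Place = std::variant<CPUPlace, GPUPlace>;

bool is_gpu_place(const Place &p) { return std::holds_alternative<GPUPlace>(p); }

// Mirrors the new branch: a GPU input requires a GPU output place; a CPU
// input overrides the build-time place and forces the output onto CPU.
Place ResolveOutputPlace(const Place &in_place, const Place &handle_place) {
  if (is_gpu_place(in_place)) {
    assert(is_gpu_place(handle_place) &&
           "Places of input and output must be all on GPU.");
    return handle_place;
  }
  return CPUPlace{};
}

int main() {
  Place cpu = CPUPlace{}, gpu = GPUPlace{0};
  assert(is_gpu_place(ResolveOutputPlace(gpu, gpu)));
  assert(!is_gpu_place(ResolveOutputPlace(cpu, gpu)));  // forced back to CPU
  return 0;
}
```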
...
@@ -75,14 +75,15 @@ void GatherOpHandle::RunImpl() {
     in_tensors.emplace_back(in_sr_value.value());
   }
 
-  // TODO(zcd): The Place of var_handle is determined at building SSA graph
-  // stage, while the Place of var is determined at runtime. If they are
-  // different, DataTransform should be applied. Currently, it has not been done
-  // yet.
-  auto &out_place = out_var_handle->place_;
-  PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(),
-                    "Currently, Places of input and output must be all on CPU "
-                    "or all on GPU.");
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
+  platform::Place t_out_p = out_var_handle->place_;
+  if (platform::is_gpu_place(pre_in_value.place())) {
+    PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                   "Places of input and output must be all on GPU.");
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
 
   auto out_var =
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);
...
@@ -93,18 +94,18 @@ void GatherOpHandle::RunImpl() {
   DDim out_dim = pre_in_value.GetCompleteDims();
   out_dim[0] = static_cast<int64_t>(rows);
   out_value->mutable_value()->Resize(out_dim).mutable_data(
-      out_place, pre_in_value.value().type());
+      t_out_p, pre_in_value.value().type());
   Tensor *out_tensor = out_value->mutable_value();
 
   // copy
-  auto dev_ctx = dev_ctxes_[out_place];
-  RunAndRecordEvent(out_place, [in_tensors, out_tensor, &dev_ctx, out_place] {
+  auto dev_ctx = dev_ctxes_[out_var_handle->place_];
+  RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
+                                             t_out_p] {
     int s = 0, e = 0;
     for (size_t j = 0; j < in_tensors.size(); ++j) {
       e += in_tensors[j].dims()[0];
       auto sub_out = out_tensor->Slice(s, e);
-      paddle::framework::TensorCopy(in_tensors[j], out_place, *dev_ctx,
-                                    &sub_out);
+      paddle::framework::TensorCopy(in_tensors[j], t_out_p, *dev_ctx, &sub_out);
       s = e;
     }
   });
...
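The gather loop above concatenates the input tensors row-wise into the output: each input occupies rows [s, e) of out_tensor, and TensorCopy fills that slice. A small self-contained sketch of the same sliding-window bookkeeping, with std::vector rows standing in for LoDTensor slices (Gather and RowBlock are illustrative names, not Paddle's API):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Row = std::vector<float>;
using RowBlock = std::vector<Row>;  // rows of equal width

// Each input block lands in rows [s, e) of the output, exactly the
// bookkeeping done by the lambda in the diff above.
RowBlock Gather(const std::vector<RowBlock> &inputs) {
  RowBlock out;
  std::size_t s = 0, e = 0;
  for (const auto &in : inputs) {
    e += in.size();
    out.resize(e);
    std::copy(in.begin(), in.end(), out.begin() + s);  // TensorCopy into Slice(s, e)
    s = e;
  }
  return out;
}

int main() {
  RowBlock a = {{1, 2}, {3, 4}}, b = {{5, 6}};
  RowBlock g = Gather({a, b});
  assert(g.size() == 3 && g[2] == Row({5, 6}));
  return 0;
}
```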
...
@@ -53,6 +53,7 @@ void ReduceOpHandle::RunImpl() {
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated(in_var_handles);
 
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
   std::vector<platform::Place> in_places;  // used to get dev_ctx
   for (auto *in_handle : in_var_handles) {
     in_places.emplace_back(in_handle->place_);
...
@@ -66,22 +67,23 @@ void ReduceOpHandle::RunImpl() {
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);
 
-  // TODO(zcd): The Place of var_handle is determined at building SSA graph
-  // stage, while the Place of var is determined at runtime. If they are
-  // different, DataTransform should be applied. Currently, it has not been done
-  // yet.
-  PADDLE_ENFORCE_EQ(
-      VariableVisitor::GetMutableTensor(pre_in_var).place().which(),
-      out_var_handle->place_.which(),
-      "Currently, Places of input and output must be all on CPU or all on "
-      "GPU.");
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place();
+  platform::Place t_out_p;
+  if (platform::is_gpu_place(in_p)) {
+    PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_),
+                   "Places of input and output must be all on GPU.");
+    t_out_p = out_var_handle->place_;
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
 
   if (pre_in_var->IsType<framework::SelectedRows>()) {
     std::vector<const SelectedRows *> in_selected_rows =
         GetInputValues<SelectedRows>(in_var_handles, var_scopes);
 
-    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_,
-                       out_var_handle->place_,
+    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
                        out_var->GetMutable<framework::SelectedRows>());
   } else {
     std::vector<const LoDTensor *> lod_tensors =
...
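ReduceOpHandle::RunImpl applies the same place-resolution rule as broadcast and gather, then dispatches on the runtime type of the input variable: SelectedRows inputs go through GatherSelectedRows to the resolved t_out_p, dense inputs through the element-wise reduce path. A sketch of that dispatch, with std::variant standing in for framework::Variable's type erasure (Reduce here is illustrative, not Paddle's API):

```cpp
#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

struct LoDTensor { std::vector<float> data; };
struct SelectedRows { std::vector<int64_t> rows; LoDTensor value; };
using Variable = std::variant<LoDTensor, SelectedRows>;

// Sparse inputs take the gather path to the resolved output place;
// dense inputs take the element-wise reduce path.
void Reduce(const Variable &pre_in) {
  if (std::holds_alternative<SelectedRows>(pre_in)) {
    std::cout << "GatherSelectedRows(..., t_out_p, ...)\n";
  } else {
    std::cout << "reduce LoDTensors element-wise\n";
  }
}

int main() {
  Reduce(Variable{SelectedRows{}});
  Reduce(Variable{LoDTensor{}});
  return 0;
}
```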
...
@@ -48,10 +48,6 @@ class SSAGraphBuilder {
                                     const platform::Place &place,
                                     size_t place_offset);
 
-  static VarHandle *GetLatestVarHandle(SSAGraph *graph,
-                                       const std::string &each_var_name,
-                                       size_t place_offset);
-
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
   static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
...
@@ -62,19 +62,10 @@ struct VarHandle : public VarHandleBase {
   std::string name_;
   platform::Place place_;
 
-  // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
-  // member variables(version_, scope_id_, name_, place_) must be equal. But
-  // sometimes judging whether the two var_handle is equal is actually to
-  // determine whether the two Variables that represented by var_handle is the
-  // same. And the same Variable may have many different var_handles, the
-  // version_ of these var_handles is different. So I don't take care of
-  // version_ temporarily when overloading equal.
-  bool operator==(const VarHandle& o) const {
+  bool IsTheSameVar(const VarHandle& o) const {
     return o.generated_op_ == generated_op_ && o.name_ == name_ &&
            o.scope_idx_ == scope_idx_;
   }
-
-  bool operator!=(const VarHandle& o) const { return !this->operator==(o); }
 };
 
 // Dummy Variable. It is used to represent dependencies between operators
...
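The rename from operator==/operator!= to IsTheSameVar keeps the semantics the removed NOTE described: version_ is deliberately left out of the comparison, so the predicate asks whether two handles denote the same underlying Variable, not whether the handles themselves are equal. A compilable sketch of the simplified struct, reduced to the fields the predicate actually reads:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

struct OpHandleBase;  // opaque here; compared only by pointer identity

struct VarHandle {
  std::size_t version_;
  std::size_t scope_idx_;
  std::string name_;
  OpHandleBase *generated_op_;

  // Same comparison as the commit: version_ is intentionally ignored,
  // per the rationale in the NOTE this commit removes.
  bool IsTheSameVar(const VarHandle &o) const {
    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
           o.scope_idx_ == scope_idx_;
  }
};

int main() {
  VarHandle v0{0, 1, "w", nullptr};
  VarHandle v1{1, 1, "w", nullptr};  // another handle for the same Variable
  assert(v0.IsTheSameVar(v1));       // matches despite differing version_
  return 0;
}
```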
...
@@ -88,7 +88,7 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
   VisitVariable(src, &visitor);
 }
 
-struct EnforceEqualShapeAndDTypeVisitor {
+struct EnforceShapeAndDTypeEQVisitor {
   const Variable* trg_;
 
   void operator()(const LoDTensor& src) {
...
@@ -130,7 +130,7 @@ struct EnforceEqualShapeAndDTypeVisitor {
 void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1,
                                              const Variable& var2) {
-  EnforceEqualShapeAndDTypeVisitor visitor{&var1};
+  EnforceShapeAndDTypeEQVisitor visitor{&var1};
   VisitVariable(var2, &visitor);
 }
...
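The renamed EnforceShapeAndDTypeEQVisitor is applied through VisitVariable, which unwraps the Variable and calls the operator() overload matching its concrete type. A minimal sketch of that visitor dispatch, with std::variant and std::visit standing in for Paddle's Variable and VisitVariable (all types here are simplified stand-ins, and only shape is checked):

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

struct LoDTensor { std::vector<int64_t> dims; };
struct SelectedRows { LoDTensor value; };
using Variable = std::variant<LoDTensor, SelectedRows>;

// Returns the dims of whichever alternative the Variable holds.
const std::vector<int64_t> &DimsOf(const Variable &v) {
  if (auto *t = std::get_if<LoDTensor>(&v)) return t->dims;
  return std::get<SelectedRows>(v).value.dims;
}

struct EnforceShapeAndDTypeEQVisitor {
  const Variable *trg_;
  void operator()(const LoDTensor &src) const {
    assert(DimsOf(*trg_) == src.dims && "dims of input and output differ");
  }
  void operator()(const SelectedRows &src) const { (*this)(src.value); }
};

void EnforceShapeAndDTypeEQ(const Variable &var1, const Variable &var2) {
  EnforceShapeAndDTypeEQVisitor visitor{&var1};
  std::visit(visitor, var2);  // plays the role of VisitVariable
}

int main() {
  Variable a = LoDTensor{{2, 3}}, b = LoDTensor{{2, 3}};
  EnforceShapeAndDTypeEQ(a, b);  // passes: same shape
  return 0;
}
```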