Unverified commit 23a21c86, authored by chengduo, committed by GitHub

Merge pull request #9922 from chengduoZH/feature/refine_gather_reduce

Refine gather and broadcast
paddle/fluid/framework/details/CMakeLists.txt
@@ -21,8 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
                                        simple_threadpool device_context)
-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
-cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
+cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
......
paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -13,95 +13,72 @@
 // limitations under the License.

 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"

 namespace paddle {
 namespace framework {
 namespace details {

-Tensor *GetTensorFromVar(Variable *in_var) {
-  if (in_var->IsType<LoDTensor>()) {
-    return in_var->GetMutable<LoDTensor>();
-  } else if (in_var->IsType<SelectedRows>()) {
-    return in_var->GetMutable<SelectedRows>()->mutable_value();
-  } else {
-    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
-  }
-  return nullptr;
-}
-
 BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}

 void BroadcastOpHandle::RunImpl() {
-  // the input may have dummy var.
-  std::vector<VarHandle *> in_var_handle;
-  for (auto *in : inputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(in);
-    if (out_handle) {
-      in_var_handle.push_back(out_handle);
-    }
-  }
-  PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
-                    "The number of input should be one.");
-
-  // the output may have dummy var.
-  std::vector<VarHandle *> out_var_handles;
-  for (auto *out : outputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(out);
-    if (out_handle) {
-      out_var_handles.push_back(out_handle);
-    }
-  }
+  // the input and output may have dummy var.
+  VarHandle *in_var_handle;
+
+  {
+    auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+    PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
+                      "The number of input should be one.");
+    in_var_handle = in_var_handles[0];
+  }
+
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);

   PADDLE_ENFORCE_EQ(
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");

   // Wait input done, this Wait is asynchronous operation
-  auto &in_place = in_var_handle[0]->place_;
-  if (in_var_handle[0]->generated_op_) {
-    for (auto *out : out_var_handles) {
-      auto &out_p = out->place_;
-      in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
-    }
-  }
+  WaitInputVarGenerated(*in_var_handle);

-  auto in_scope_idx = in_var_handle[0]->scope_idx_;
-  auto in_var =
-      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
-  Tensor *in_tensor = GetTensorFromVar(in_var);
+  auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
+                     ->FindVar(in_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(in_var);
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

   for (auto *out : out_var_handles) {
+    if (*out == *in_var_handle) {
+      continue;
+    }
+
     auto &out_p = out->place_;
-    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
+    auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

-    PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
+    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
                       "Places must be all on CPU or all on CUDA.");

-    if (in_var->IsType<framework::SelectedRows>()) {
-      auto &in_sr = in_var->Get<framework::SelectedRows>();
-      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
-      if (&in_sr == out_sr) continue;
-      out_sr->set_height(in_sr.height());
-      out_sr->set_rows(in_sr.rows());
-      out_sr->mutable_value()->Resize(in_sr.value().dims());
-      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
-    } else if (in_var->IsType<framework::LoDTensor>()) {
-      auto in_lod = in_var->Get<framework::LoDTensor>();
-      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
-      if (&in_lod == out_lod) continue;
-      out_lod->set_lod(in_lod.lod());
-      out_lod->Resize(in_lod.dims());
-      out_lod->mutable_data(out_p, in_lod.type());
-    } else {
-      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
-    }
-
-    Tensor *out_tensor = GetTensorFromVar(out_var);
-    paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
-                                  out_tensor);
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var)
+        .Resize(in_tensor.dims())
+        .mutable_data(out_p, in_tensor.type());
+
+    auto dev_ctx = dev_ctxes_[out_p];
+    RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
+      paddle::framework::TensorCopy(
+          in_tensor, out_p, *(dev_ctx),
+          &VariableVisitor::GetMutableTensor(out_var));
+    });
   }
 }
+
+void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
+  if (in_var.generated_op_) {
+    for (auto &pair : dev_ctxes_) {
+      in_var.generated_op_->Wait(pair.second);
+    }
+  }
+}
......
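Note on the RunAndRecordEvent change above: the copy is no longer executed inline; it is handed over as a lambda that captures the tensor, the destination variable, the device context pointer and the place by value, so the work can be scheduled (and an event recorded) even after RunImpl's locals are gone. Below is a self-contained sketch of that enqueue-by-value pattern using toy types — plain std::function and a fake DeviceContext, not Paddle's RunAndRecordEvent or event machinery.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in; Paddle's real type would be platform::DeviceContext.
struct DeviceContext {
  std::string place;
};

// Deferred "copy" jobs, standing in for RunAndRecordEvent's role of scheduling
// the work instead of running it inline.
std::vector<std::function<void()>> pending;

void EnqueueCopy(const std::string &src, std::string *dst, DeviceContext *ctx) {
  // Capture the payload by value and the destination/context as raw pointers:
  // by the time the job actually runs, the caller's stack frame is gone.
  pending.emplace_back([src, dst, ctx] {
    *dst = src;  // stands in for TensorCopy(in_tensor, out_p, *dev_ctx, out_tensor)
    std::cout << "copied to " << ctx->place << ": " << *dst << "\n";
  });
}

int main() {
  DeviceContext gpu0{"CUDAPlace(0)"};
  std::string out;
  {
    std::string in = "broadcast payload";  // local, like in_tensor in RunImpl
    EnqueueCopy(in, &out, &gpu0);
  }  // `in` is destroyed here, but the lambda kept its own copy
  for (auto &job : pending) job();  // later: the scheduler runs the copies
  return 0;
}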
paddle/fluid/framework/details/broadcast_op_handle.h
@@ -39,12 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {
 protected:
  void RunImpl() override;
+  void WaitInputVarGenerated(const VarHandle &in_var);

 private:
  const std::vector<Scope *> &local_scopes_;
  const std::vector<platform::Place> &places_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/details/container_cast.h (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
template <typename ResultType, typename ElemType>
std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
static_assert(std::is_base_of<ElemType, ResultType>::value,
"ElementType must be a base class of ResultType");
std::vector<ResultType*> res;
for (auto* ptr : container) {
auto* derived = dynamic_cast<ResultType*>(ptr);
if (derived) {
res.emplace_back(derived);
}
}
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
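As a usage note: DynamicCast is the helper both op handles above now call to drop dummy dependency variables from inputs_/outputs_. The sketch below is self-contained and uses stand-in VarHandleBase/VarHandle/DummyVarHandle types (only the template itself is copied from the header above), just to show what the filtering does.

#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

// Copied from container_cast.h above.
template <typename ResultType, typename ElemType>
std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
  static_assert(std::is_base_of<ElemType, ResultType>::value,
                "ElementType must be a base class of ResultType");
  std::vector<ResultType*> res;
  for (auto* ptr : container) {
    auto* derived = dynamic_cast<ResultType*>(ptr);
    if (derived) {
      res.emplace_back(derived);
    }
  }
  return res;
}

// Stand-ins for the real VarHandleBase / VarHandle / DummyVarHandle.
struct VarHandleBase { virtual ~VarHandleBase() = default; };
struct VarHandle : VarHandleBase { std::string name_; };
struct DummyVarHandle : VarHandleBase {};

int main() {
  VarHandle real;
  real.name_ = "w@GRAD";
  DummyVarHandle dummy;
  std::vector<VarHandleBase*> inputs{&real, &dummy};

  // Keeps only the real variable handle; the dummy dependency var is dropped.
  auto var_handles = DynamicCast<VarHandle>(inputs);
  std::cout << var_handles.size() << " real var handle(s), first: "
            << var_handles[0]->name_ << "\n";  // prints: 1 real var handle(s), first: w@GRAD
  return 0;
}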
paddle/fluid/framework/details/gather_op_handle.cc
@@ -13,6 +13,8 @@
 // limitations under the License.

 #include "paddle/fluid/framework/details/gather_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"

 namespace paddle {
 namespace framework {
@@ -23,30 +25,23 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
     : local_scopes_(local_scopes), places_(places) {}

 void GatherOpHandle::RunImpl() {
-  // the input may have dummy var.
-  std::vector<VarHandle *> in_var_handles;
-  for (auto *in : inputs_) {
-    auto *in_handle = dynamic_cast<VarHandle *>(in);
-    if (in_handle) {
-      in_var_handles.push_back(in_handle);
-    }
-  }
+  // the input and output may have dummy var.
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+
   PADDLE_ENFORCE_EQ(
       in_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");

-  // the output may have dummy var.
-  std::vector<VarHandle *> out_var_handles;
-  for (auto *out : outputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(out);
-    if (out_handle) {
-      out_var_handles.push_back(out_handle);
-    }
-  }
-  PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
-                    "The number of output should be one.");
+  VarHandle *out_var_handle;
+  {
+    auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+    PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
+                      "The number of output should be one.");
+    out_var_handle = out_var_handles.front();
+  }

-  auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
+  auto in_0_handle = in_var_handles[0];
   auto pre_in_var =
       local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
   auto pre_place = in_0_handle->place_;
@@ -54,15 +49,11 @@ void GatherOpHandle::RunImpl() {
   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                  "Currently, gather_op only can gather SelectedRows.");

-  PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(),
+  PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
                     "The place of input and output should be the same.");

   // Wait input done, this Wait is asynchronous operation
-  for (auto *in : in_var_handles) {
-    if (in->generated_op_) {
-      in->generated_op_->Wait(dev_ctxes_[in->place_]);
-    }
-  }
+  WaitInputVarGenerated(in_var_handles);

   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
@@ -70,13 +61,12 @@ void GatherOpHandle::RunImpl() {
   auto &pre_in = pre_in_var->Get<framework::SelectedRows>();

   // gather the inputs
-  for (auto *in : in_var_handles) {
-    auto in_handle = static_cast<VarHandle *>(in);
+  for (auto *in_handle : in_var_handles) {
     auto in_p = in_handle->place_;
     in_places.push_back(in_p);
     PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
                       "Places must be all on CPU or all on CUDA.");
-    auto in_var =
+    auto *in_var =
         local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
     auto &in_sr = in_var->Get<framework::SelectedRows>();
@@ -84,20 +74,19 @@ void GatherOpHandle::RunImpl() {
                       "The type of input is not consistent.");
     PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
                       "The height of inputs is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
+    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
                       "The dims of inputs is not consistent.");

-    auto in_sr_rows = in_sr.rows();
+    auto &in_sr_rows = in_sr.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
     in_tensors.emplace_back(in_sr.value());
   }

   // write the output
-  auto &out_place = out_var_handles[0]->place_;
-  auto out_scope_idx = out_var_handles[0]->scope_idx_;
-  auto out_var =
-      local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
+  auto &out_place = out_var_handle->place_;
+  auto out_scope_idx = out_var_handle->scope_idx_;
+  auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);
   auto out = out_var->GetMutable<framework::SelectedRows>();
   out->set_height(pre_in.height());
@@ -110,14 +99,28 @@ void GatherOpHandle::RunImpl() {
   Tensor *out_tensor = out->mutable_value();

   // copy
-  int s = 0, e = 0;
-  for (size_t j = 0; j < in_tensors.size(); ++j) {
-    e += in_tensors[j].dims()[0];
-    auto sub_out = out_tensor->Slice(s, e);
-    paddle::framework::TensorCopy(in_tensors[j], out_place,
-                                  *(dev_ctxes_[in_places[j]]), &sub_out);
-    s = e;
-  }
+  auto dev_ctx = dev_ctxes_[out_place];
+  RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
+    int s = 0, e = 0;
+    for (size_t j = 0; j < in_tensors.size(); ++j) {
+      e += in_tensors[j].dims()[0];
+      auto sub_out = out_tensor->Slice(s, e);
+      paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
+                                    &sub_out);
+      s = e;
+    }
+  });
+}
+
+void GatherOpHandle::WaitInputVarGenerated(
+    const std::vector<VarHandle *> &in_var_handles) {
+  for (auto *in : in_var_handles) {
+    if (in->generated_op_) {
+      for (auto pair : dev_ctxes_) {
+        in->generated_op_->Wait(pair.second);
+      }
+    }
+  }
 }

 std::string GatherOpHandle::Name() const { return "gather"; }
......
paddle/fluid/framework/details/gather_op_handle.h
@@ -39,6 +39,7 @@ struct GatherOpHandle : public OpHandleBase {
 protected:
  void RunImpl() override;
+  void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);

 private:
  const std::vector<Scope *> &local_scopes_;
......
paddle/fluid/framework/details/var_handle.h
@@ -61,6 +61,11 @@ struct VarHandle : public VarHandleBase {
   size_t scope_idx_;
   std::string name_;
   platform::Place place_;
+
+  bool operator==(const VarHandle& o) const {
+    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
+           o.scope_idx_ == scope_idx_;
+  }
 };

 // Dummy Variable. It is used to represent dependencies between operators
......
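The operator== added above exists to support the new early-exit in BroadcastOpHandle::RunImpl ("if (*out == *in_var_handle) continue;"), which skips copying the source variable onto itself. A tiny stand-alone illustration of that skip-self check follows, with a simplified struct rather than the real VarHandle.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for VarHandle: equality on the identifying fields only.
struct Handle {
  const void* generated_op;  // stands in for generated_op_
  std::string name;          // stands in for name_
  size_t scope_idx;          // stands in for scope_idx_
  bool operator==(const Handle& o) const {
    return o.generated_op == generated_op && o.name == name &&
           o.scope_idx == scope_idx;
  }
};

int main() {
  Handle in{nullptr, "w", 0};
  std::vector<Handle> outs{{nullptr, "w", 0}, {nullptr, "w", 1}, {nullptr, "w", 2}};
  for (const auto& out : outs) {
    if (out == in) continue;  // don't copy the input variable onto itself
    std::cout << "copy to scope " << out.scope_idx << "\n";  // scopes 1 and 2 only
  }
  return 0;
}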
paddle/fluid/framework/details/variable_visitor.cc (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace framework {
namespace details {
template <typename Func>
static void VisitVariable(Variable* var, Func* func) {
if (var->IsType<LoDTensor>()) {
(*func)(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
(*func)(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
}
}
template <typename Func>
static void VisitVariable(const Variable& var, Func* func) {
if (var.IsType<LoDTensor>()) {
(*func)(var.Get<LoDTensor>());
} else if (var.IsType<SelectedRows>()) {
(*func)(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
}
}
struct TensorVisitor {
Tensor* result_{nullptr};
void operator()(LoDTensor* tensor) { result_ = tensor; }
void operator()(SelectedRows* selected_rows) {
result_ = selected_rows->mutable_value();
}
template <typename T>
void operator()() {
PADDLE_THROW("Not Support to get LoDTensor from %s", typeid(T).name());
}
};
Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
TensorVisitor visitor;
VisitVariable(var, &visitor);
return *visitor.result_;
}
struct ShareDimsAndLoDVisitor {
Variable* trg_;
void operator()(const LoDTensor& val) {
auto* tensor = trg_->GetMutable<LoDTensor>();
tensor->set_layout(val.layout());
tensor->set_lod(val.lod());
tensor->Resize(val.dims());
}
void operator()(const SelectedRows& val) {
auto* selected_rows = trg_->GetMutable<SelectedRows>();
selected_rows->set_rows(val.rows());
selected_rows->set_height(val.height());
selected_rows->mutable_value()->Resize(val.value().dims());
}
template <typename T>
void operator()(const T&) {
  PADDLE_THROW("ShareDimsAndLoD is not supported by type %s",
               typeid(T).name());
}
};
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
ShareDimsAndLoDVisitor visitor{trg};
VisitVariable(src, &visitor);
}
} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/details/variable_visitor.h (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
namespace details {
class VariableVisitor {
public:
static Tensor &GetMutableTensor(Variable *var);
static void ShareDimsAndLoD(const Variable &src, Variable *trg);
};
} // namespace details
} // namespace framework
} // namespace paddle
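VariableVisitor centralizes the "LoDTensor or SelectedRows?" dispatch that broadcast_op_handle.cc previously hand-rolled. The sketch below shows the same dispatch idea with self-contained stand-in types (not Paddle's Variable/LoDTensor/SelectedRows); the real class additionally shares LoD, rows, height and dims when preparing the destination variable.

#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

// Stand-ins: a "dense tensor" and a "selected rows" that wraps one.
struct DenseTensor { std::vector<float> data; };
struct SparseRows { std::vector<long> rows; DenseTensor value; };

// Toy Variable: holds exactly one of the two payload kinds.
struct Variable {
  std::unique_ptr<DenseTensor> dense;
  std::unique_ptr<SparseRows> sparse;
};

// Mirrors the role of VariableVisitor::GetMutableTensor: whichever alternative
// the variable holds, hand back the underlying tensor to copy into or out of.
DenseTensor& GetMutableTensor(Variable* var) {
  if (var->dense) return *var->dense;
  if (var->sparse) return var->sparse->value;
  throw std::runtime_error("Variable holds neither a dense tensor nor selected rows");
}

int main() {
  Variable dense_var;
  dense_var.dense = std::make_unique<DenseTensor>(DenseTensor{{1.f, 2.f}});

  Variable sparse_var;
  sparse_var.sparse = std::make_unique<SparseRows>(SparseRows{{0, 3}, {{5.f}}});

  // The same call works for both payload kinds; the caller never branches on type.
  std::cout << GetMutableTensor(&dense_var).data.size() << " "
            << GetMutableTensor(&sparse_var).data.size() << "\n";  // prints: 2 1
  return 0;
}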