提交 d24ef931 编写于 作者: Y Yu Yang

Clean Code

上级 4abef501
...@@ -24,6 +24,9 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS ...@@ -24,6 +24,9 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle) device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
......
...@@ -13,102 +13,70 @@ ...@@ -13,102 +13,70 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
Tensor *GetTensorFromVar(Variable *in_var) {
if (in_var->IsType<LoDTensor>()) {
return in_var->GetMutable<LoDTensor>();
} else if (in_var->IsType<SelectedRows>()) {
return in_var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
return nullptr;
}
BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes, BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : local_scopes_(local_scopes), places_(places) {}
void BroadcastOpHandle::RunImpl() { void BroadcastOpHandle::RunImpl() {
// the input and output may have dummy var. // the input and output may have dummy var.
std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_); VarHandle *in_var_handle;
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
{
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
"The number of input should be one.");
in_var_handle = in_var_handles[0];
}
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(), out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places."); "The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operationplatform::Place // Wait input done, this Wait is asynchronous operation platform::Place
// &in_place; // &in_place;
WaitEvents(out_var_handles, in_var_handle); WaitInputVarGenerated(*in_var_handle);
auto in_place = in_var_handle[0]->place_; auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
auto in_scope_idx = in_var_handle[0]->scope_idx_; ->FindVar(in_var_handle->name_);
auto in_var = PADDLE_ENFORCE_NOT_NULL(in_var);
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
Tensor *in_tensor = GetTensorFromVar(in_var);
for (auto *out : out_var_handles) { for (auto *out : out_var_handles) {
if (*out == *in_var_handle) {
continue;
}
auto &out_p = out->place_; auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_); auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(), PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
"Places must be all on CPU or all on CUDA."); "Places must be all on CPU or all on CUDA.");
if (in_var->IsType<framework::SelectedRows>()) { VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
auto &in_sr = in_var->Get<framework::SelectedRows>(); VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
auto out_sr = out_var->GetMutable<framework::SelectedRows>(); in_tensor.type());
if (&in_sr == out_sr) continue;
out_sr->set_height(in_sr.height());
out_sr->set_rows(in_sr.rows());
out_sr->mutable_value()->Resize(in_sr.value().dims());
out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto in_lod = in_var->Get<framework::LoDTensor>();
auto out_lod = out_var->GetMutable<framework::LoDTensor>();
if (&in_lod == out_lod) continue;
out_lod->set_lod(in_lod.lod());
out_lod->Resize(in_lod.dims());
out_lod->mutable_data(out_p, in_lod.type());
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
}
auto dev_ctx = dev_ctxes_[out_p]; auto dev_ctx = dev_ctxes_[out_p];
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
Tensor *out_tensor = GetTensorFromVar(out_var); paddle::framework::TensorCopy(
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctx), out_tensor); in_tensor, out_p, *(dev_ctx),
&VariableVisitor::GetMutableTensor(out_var));
}); });
} }
} }
void BroadcastOpHandle::WaitEvents( void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
const std::vector<VarHandle *> &out_var_handles, for (auto &pair : dev_ctxes_) {
const std::vector<VarHandle *> &in_var_handle) { in_var.generated_op_->Wait(pair.second);
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
}
std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
} }
return in_var_handle;
} }
std::string BroadcastOpHandle::Name() const { return "broadcast"; } std::string BroadcastOpHandle::Name() const { return "broadcast"; }
......
...@@ -42,11 +42,7 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -42,11 +42,7 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles( void WaitInputVarGenerated(const VarHandle &in_var);
const std::vector<VarHandleBase *> &inputs);
void WaitEvents(const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle);
}; };
} // namespace details } // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
template <typename ResultType, typename ElemType>
std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
static_assert(std::is_base_of<ElemType, ResultType>::value,
"ElementType must be a base class of ResultType");
std::vector<ResultType*> res;
for (auto* ptr : container) {
auto* derived = dynamic_cast<ResultType*>(ptr);
if (derived) {
res.emplace_back(derived);
}
}
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/gather_op_handle.h" #include "paddle/fluid/framework/details/gather_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -24,16 +26,22 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes, ...@@ -24,16 +26,22 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
void GatherOpHandle::RunImpl() { void GatherOpHandle::RunImpl() {
// the input and output may have dummy var. // the input and output may have dummy var.
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(), in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places."); "The number of output should equal to the number of places.");
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]); VarHandle *out_var_handle;
{
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
out_var_handle = out_var_handles.front();
}
auto in_0_handle = in_var_handles[0];
auto pre_in_var = auto pre_in_var =
local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_); local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
auto pre_place = in_0_handle->place_; auto pre_place = in_0_handle->place_;
...@@ -41,11 +49,11 @@ void GatherOpHandle::RunImpl() { ...@@ -41,11 +49,11 @@ void GatherOpHandle::RunImpl() {
PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(), PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
"Currently, gather_op only can gather SelectedRows."); "Currently, gather_op only can gather SelectedRows.");
PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(), PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
"The place of input and output should be the same."); "The place of input and output should be the same.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitEvents(in_var_handles); WaitInputVarGenerated(in_var_handles);
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors; std::vector<Tensor> in_tensors;
...@@ -53,13 +61,12 @@ void GatherOpHandle::RunImpl() { ...@@ -53,13 +61,12 @@ void GatherOpHandle::RunImpl() {
auto &pre_in = pre_in_var->Get<framework::SelectedRows>(); auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
// gather the inputs // gather the inputs
for (auto *in : in_var_handles) { for (auto *in_handle : in_var_handles) {
auto in_handle = static_cast<VarHandle *>(in);
auto in_p = in_handle->place_; auto in_p = in_handle->place_;
in_places.push_back(in_p); in_places.push_back(in_p);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA."); "Places must be all on CPU or all on CUDA.");
auto in_var = auto *in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_); local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>(); auto &in_sr = in_var->Get<framework::SelectedRows>();
...@@ -70,17 +77,16 @@ void GatherOpHandle::RunImpl() { ...@@ -70,17 +77,16 @@ void GatherOpHandle::RunImpl() {
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
"The dims of inputs is not consistent."); "The dims of inputs is not consistent.");
auto in_sr_rows = in_sr.rows(); auto &in_sr_rows = in_sr.rows();
out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end()); out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
in_tensors.emplace_back(in_sr.value()); in_tensors.emplace_back(in_sr.value());
} }
// write the output // write the output
auto &out_place = out_var_handles[0]->place_; auto &out_place = out_var_handle->place_;
auto out_scope_idx = out_var_handles[0]->scope_idx_; auto out_scope_idx = out_var_handle->scope_idx_;
auto out_var = auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);
local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
auto out = out_var->GetMutable<framework::SelectedRows>(); auto out = out_var->GetMutable<framework::SelectedRows>();
out->set_height(pre_in.height()); out->set_height(pre_in.height());
...@@ -106,25 +112,15 @@ void GatherOpHandle::RunImpl() { ...@@ -106,25 +112,15 @@ void GatherOpHandle::RunImpl() {
}); });
} }
void GatherOpHandle::WaitEvents( void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) { const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) { for (auto *in : in_var_handles) {
if (in->generated_op_) { if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]); for (auto pair : dev_ctxes_) {
} in->generated_op_->Wait(pair.second);
} }
}
std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
} }
} }
return in_var_handles;
} }
std::string GatherOpHandle::Name() const { return "gather"; } std::string GatherOpHandle::Name() const { return "gather"; }
......
...@@ -42,10 +42,7 @@ struct GatherOpHandle : public OpHandleBase { ...@@ -42,10 +42,7 @@ struct GatherOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles( void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
const std::vector<VarHandleBase *> &);
void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
}; };
} // namespace details } // namespace details
......
...@@ -53,6 +53,11 @@ struct VarHandle : public VarHandleBase { ...@@ -53,6 +53,11 @@ struct VarHandle : public VarHandleBase {
size_t scope_idx_; size_t scope_idx_;
std::string name_; std::string name_;
platform::Place place_; platform::Place place_;
bool operator==(const VarHandle &o) const {
return o.generated_op_ == generated_op_ && o.name_ == name_ &&
o.scope_idx_ == scope_idx_;
}
}; };
// Dummy Variable. It is used to represent dependencies between operators // Dummy Variable. It is used to represent dependencies between operators
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace framework {
namespace details {
template <typename Func>
static void VisitVariable(Variable* var, Func func) {
if (var->IsType<LoDTensor>()) {
func(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
func(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
}
}
template <typename Func>
static void VisitVariable(const Variable& var, Func func) {
if (var.IsType<LoDTensor>()) {
func(var.Get<LoDTensor>());
} else if (var.IsType<SelectedRows>()) {
func(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
}
}
struct TensorVisitor {
Tensor* result_{nullptr};
void operator()(LoDTensor* tensor) { result_ = tensor; }
void operator()(SelectedRows* selected_rows) {
result_ = selected_rows->mutable_value();
}
template <typename T>
void operator()() {
PADDLE_THROW("Not Support to get LoDTensor from %s", typeid(T).name());
}
};
Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
TensorVisitor vistor;
VisitVariable(var, vistor);
return *vistor.result_;
}
struct ShareDimsAndLoDVisitor {
Variable* trg_;
void operator()(const LoDTensor& val) {
auto* tensor = trg_->GetMutable<LoDTensor>();
tensor->set_layout(val.layout());
tensor->set_lod(val.lod());
tensor->Resize(val.dims());
}
void operator()(const SelectedRows& val) {
auto* selected_rows = trg_->GetMutable<SelectedRows>();
selected_rows->set_rows(val.rows());
selected_rows->set_height(val.height());
selected_rows->mutable_value()->Resize(val.value().dims());
}
template <typename T>
void operator()(const T&) {
PADDLE_ENFORCE("ShareDimsAndLoD is not supported by type %s",
typeid(T).name());
}
};
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
ShareDimsAndLoDVisitor visitor{trg};
VisitVariable(src, visitor);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
namespace details {
class VariableVisitor {
public:
static Tensor &GetMutableTensor(Variable *var);
static void ShareDimsAndLoD(const Variable &src, Variable *trg);
};
} // namespace details
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册