未验证 提交 eb8e14c9 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #10081 from chengduoZH/refine/gather_broadcast

Fix scope of gather and broadcast, and code clean
......@@ -8,27 +8,28 @@ cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope plac
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_GPU)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim dynload_cuda)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
else()
set(multi_devices_graph_builder_deps)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base scope ddim)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
endif()
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
......
......@@ -44,9 +44,15 @@ void BroadcastOpHandle::RunImpl() {
// &in_place;
WaitInputVarGenerated(*in_var_handle);
auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
->FindVar(in_var_handle->name_);
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
for (auto *out : out_var_handles) {
......@@ -55,17 +61,16 @@ void BroadcastOpHandle::RunImpl() {
}
auto &out_p = out->place_;
auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_);
PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
"Places must be all on CPU or all on CUDA.");
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var)
.Resize(in_tensor.dims())
.mutable_data(out_p, in_tensor.type());
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
in_tensor.type());
auto dev_ctx = dev_ctxes_[out_p];
auto dev_ctx = dev_ctxes_.at(out_p);
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
paddle::framework::TensorCopy(
in_tensor, out_p, *(dev_ctx),
......
......@@ -30,6 +30,7 @@ const f::DDim kDims = {20, 20};
struct TestBroadcastOpHandle {
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_;
Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
......@@ -72,11 +73,17 @@ struct TestBroadcastOpHandle {
void InitBroadcastOp(size_t input_scope_idx) {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
local_scopes_[j]->Var("out");
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("out");
param_scopes_.emplace_back(&local_scope);
}
local_scopes_[input_scope_idx]->Var("input");
param_scopes_[input_scope_idx]->Var("input");
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
auto* in_var_handle =
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
......@@ -105,7 +112,8 @@ struct TestBroadcastOpHandle {
}
void TestBroadcastLodTensor(size_t input_scope_idx) {
auto in_var = local_scopes_[input_scope_idx]->Var("input");
auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
in_lod_tensor->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
......@@ -117,6 +125,7 @@ struct TestBroadcastOpHandle {
paddle::framework::TensorFromVector<float>(
send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor);
in_lod_tensor->set_lod(lod);
in_lod_tensor->Resize(kDims);
op_handle_->Run(false);
......@@ -124,7 +133,8 @@ struct TestBroadcastOpHandle {
p::CPUPlace cpu_place;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
auto out_var = local_scopes_[j]->Var("out");
auto out_var = param_scopes_[j]->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_tensor = out_var->Get<f::LoDTensor>();
PADDLE_ENFORCE_EQ(out_tensor.lod(), lod, "lod is not equal.");
......@@ -139,7 +149,8 @@ struct TestBroadcastOpHandle {
}
void TestBroadcastSelectedRows(size_t input_scope_idx) {
auto in_var = local_scopes_[input_scope_idx]->Var("input");
auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
auto value = in_selected_rows->mutable_value();
value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
......@@ -162,7 +173,8 @@ struct TestBroadcastOpHandle {
p::CPUPlace cpu_place;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
auto out_var = local_scopes_[j]->Var("out");
auto out_var = param_scopes_[j]->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto& out_select_rows = out_var->Get<f::SelectedRows>();
auto rt = out_select_rows.value();
......
......@@ -41,14 +41,19 @@ void GatherOpHandle::RunImpl() {
out_var_handle = out_var_handles.front();
}
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto in_0_handle = in_var_handles[0];
auto pre_in_var =
local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
auto pre_place = in_0_handle->place_;
var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
PADDLE_ENFORCE_NOT_NULL(pre_in_var);
PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
"Currently, gather_op only can gather SelectedRows.");
auto pre_place = in_0_handle->place_;
PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
"The place of input and output should be the same.");
......@@ -67,7 +72,7 @@ void GatherOpHandle::RunImpl() {
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
auto *in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
......@@ -86,7 +91,7 @@ void GatherOpHandle::RunImpl() {
// write the output
auto &out_place = out_var_handle->place_;
auto out_scope_idx = out_var_handle->scope_idx_;
auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);
auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_);
auto out = out_var->GetMutable<framework::SelectedRows>();
out->set_height(pre_in.height());
......
......@@ -29,6 +29,7 @@ const f::DDim kDims = {20, 20};
struct TestGatherOpHandle {
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_;
Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
......@@ -71,9 +72,14 @@ struct TestGatherOpHandle {
void InitGatherOp(size_t input_scope_idx) {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
local_scopes_[j]->Var("out");
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("input");
param_scopes_.emplace_back(&local_scope);
}
local_scopes_[input_scope_idx]->Var("input");
param_scopes_[input_scope_idx]->Var("out");
op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
// add input
......@@ -115,7 +121,8 @@ struct TestGatherOpHandle {
for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size();
++input_scope_idx) {
auto in_var = local_scopes_[input_scope_idx]->Var("input");
auto in_var = param_scopes_.at(input_scope_idx)->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
auto value = in_selected_rows->mutable_value();
value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
......@@ -128,10 +135,11 @@ struct TestGatherOpHandle {
value->Resize(kDims);
}
auto out_var = local_scopes_[output_scope_idx]->Var("out");
auto out_var = param_scopes_.at(output_scope_idx)->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_selected_rows = out_var->GetMutable<f::SelectedRows>();
auto in_var = local_scopes_[output_scope_idx]->Var("input");
auto in_var = param_scopes_.at(output_scope_idx)->FindVar("input");
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
out_selected_rows->mutable_value()->ShareDataWith(
......@@ -155,7 +163,8 @@ struct TestGatherOpHandle {
f::TensorCopy(rt, cpu_place, *(ctxs_[output_scope_idx]), &result_tensor);
float* ct = result_tensor.data<float>();
for (int64_t j = 0; j < f::product(kDims); ++j) {
for (int64_t j = 0;
j < f::product(kDims) * static_cast<int64_t>(gpu_list_.size()); ++j) {
ASSERT_NEAR(ct[j], send_vector[j % send_vector.size()], 1e-5);
}
}
......
......@@ -43,21 +43,21 @@ void NCCLAllReduceOpHandle::RunImpl() {
int dtype = -1;
size_t numel = 0;
std::vector<LoDTensor> lod_tensors;
std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &lod_tensor = local_scope.FindVar(var_name)->Get<LoDTensor>();
lod_tensors.emplace_back(lod_tensor);
lod_tensors.emplace_back(&lod_tensor);
}
if (platform::is_gpu_place(lod_tensors[0].place())) {
if (platform::is_gpu_place(lod_tensors[0]->place())) {
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
auto &lod_tensor = lod_tensors[i];
auto &lod_tensor = *lod_tensors[i];
void *buffer = const_cast<void *>(lod_tensor.data<void>());
if (dtype == -1) {
......@@ -93,7 +93,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0].type()), func);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &scope =
......
......@@ -24,23 +24,23 @@ namespace framework {
namespace details {
struct ReduceLoDTensor {
const std::vector<LoDTensor> &src_tensors_;
const std::vector<const LoDTensor *> &src_tensors_;
LoDTensor &dst_tensor_;
ReduceLoDTensor(const std::vector<LoDTensor> &src, LoDTensor *dst)
ReduceLoDTensor(const std::vector<const LoDTensor *> &src, LoDTensor *dst)
: src_tensors_(src), dst_tensor_(*dst) {}
template <typename T>
void operator()() const {
PADDLE_ENFORCE(!src_tensors_.empty());
auto &t0 = src_tensors_[0];
auto &t0 = *src_tensors_[0];
PADDLE_ENFORCE_NE(t0.numel(), 0);
dst_tensor_.Resize(t0.dims());
T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
for (size_t i = 1; i < src_tensors_.size(); ++i) {
auto &t = src_tensors_[i];
auto &t = *src_tensors_[i];
PADDLE_ENFORCE_EQ(t.dims(), t0.dims());
PADDLE_ENFORCE_EQ(t.type(), t0.type());
std::transform(t.data<T>(), t.data<T>() + t.numel(), dst, dst,
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
namespace paddle {
namespace framework {
......@@ -21,85 +23,84 @@ namespace details {
void ReduceOpHandle::RunImpl() {
// the input and output may have dummy var.
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
// Wait input done, this Wait is asynchronous operation
WaitEvents(in_var_handles);
VarHandle *out_var_handle;
{
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
out_var_handle = out_var_handles.front();
}
// check in the same place
auto in_0_handle = in_var_handles[0];
auto pre_place = in_0_handle->place_;
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto pre_in_var =
var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles);
auto pre_place = in_0_handle->place_;
std::vector<platform::Place> in_places;
auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var);
for (auto *in_handle : in_var_handles) {
auto in_p = in_handle->place_;
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
in_places.emplace_back(in_p);
}
auto out_var = local_scopes_[out_var_handles[0]->scope_idx_]->FindVar(
out_var_handles[0]->name_);
auto in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
auto pre_in_var =
local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
if (pre_in_var->IsType<framework::SelectedRows>()) {
auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
std::vector<const SelectedRows *> in_selected_rows;
auto in_tensor = VariableVisitor::GetMutableTensor(in_var);
PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(),
"The type of input is not consistent.");
}
for (auto *in_handle : in_var_handles) {
auto in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>();
auto out_var =
var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
"The type of input is not consistent.");
if (pre_in_var->IsType<framework::SelectedRows>()) {
std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes);
in_selected_rows.emplace_back(&in_sr);
}
auto trg = out_var->GetMutable<framework::SelectedRows>();
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_,
out_var_handles[0]->place_, trg);
out_var_handle->place_,
out_var->GetMutable<framework::SelectedRows>());
} else {
auto pre_in = pre_in_var->Get<framework::LoDTensor>();
std::vector<LoDTensor> lod_tensors;
// can be refined
for (auto *in_handle : in_var_handles) {
auto in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(in_sr.type(), pre_in.type(),
"The type of input is not consistent.");
lod_tensors.emplace_back(in_sr);
}
auto trg = out_var->GetMutable<framework::LoDTensor>();
trg->Resize(pre_in.dims());
trg->mutable_data(out_var_handles[0]->place_, pre_in.type());
std::vector<const LoDTensor *> lod_tensors =
GetInputValues<LoDTensor>(in_var_handles, var_scopes);
if (paddle::platform::is_cpu_place(pre_place)) {
ReduceLoDTensor func(lod_tensors, trg);
VisitDataType(ToDataType(lod_tensors[0].type()), func);
ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
} else if (paddle::platform::is_gpu_place(pre_place)) {
#ifdef PADDLE_WITH_CUDA
auto out_p = out_var_handles[0]->place_;
int root = boost::get<platform::CUDAPlace>(out_p).device;
auto pre_in = pre_in_var->Get<framework::LoDTensor>();
VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(
out_var_handle->place_, pre_in.type());
auto out_p = out_var_handle->place_;
int root = boost::get<platform::CUDAPlace>(out_p).device;
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
for (size_t i = 0; i < var_scopes.size(); ++i) {
auto &p = in_places[i];
auto &lod_tensor = lod_tensors[i];
auto &lod_tensor = *lod_tensors[i];
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
......@@ -109,14 +110,16 @@ void ReduceOpHandle::RunImpl() {
void *buffer = const_cast<void *>(lod_tensor.data<void>());
void *recvbuffer = nullptr;
if (root == dev_id) {
recvbuffer = trg->mutable_data(out_var_handles[0]->place_);
recvbuffer =
out_var->GetMutable<framework::LoDTensor>()->mutable_data(
out_var_handle->place_);
}
int type = platform::ToNCCLDataType(lod_tensor.type());
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclReduce(
buffer, recvbuffer, static_cast<size_t>(lod_tensor.numel()),
platform::ToNCCLDataType(lod_tensor.type()), ncclSum, root, comm,
stream));
static_cast<ncclDataType_t>(type), ncclSum, root, comm, stream));
});
}
......@@ -135,26 +138,31 @@ void ReduceOpHandle::RunImpl() {
}
}
void ReduceOpHandle::WaitEvents(
const std::vector<VarHandle *> &in_var_handles) {
if (in_var_handles[0]->generated_op_) {
for (auto *in : in_var_handles) {
in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in->place_]);
}
template <typename T>
std::vector<const T *> ReduceOpHandle::GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
const std::vector<const Scope *> &var_scopes) const {
std::vector<const T *> in_selected_rows;
for (auto *in_handle : in_var_handles) {
auto &in_sr = var_scopes.at(in_handle->scope_idx_)
->FindVar(in_handle->name_)
->Get<T>();
in_selected_rows.emplace_back(&in_sr);
}
return in_selected_rows;
}
std::vector<VarHandle *> ReduceOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
void ReduceOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
return in_var_handles;
}
std::string ReduceOpHandle::Name() const { return "reduce"; }
} // namespace details
} // namespace framework
......
......@@ -59,10 +59,13 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs);
void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
template <typename T>
std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
const std::vector<const Scope *> &var_scopes) const;
};
} // namespace details
......
......@@ -14,7 +14,6 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......@@ -30,6 +29,7 @@ struct TestReduceOpHandle {
bool use_gpu_;
Scope g_scope_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> param_scopes_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
std::vector<p::Place> gpu_list_;
......@@ -83,12 +83,18 @@ struct TestReduceOpHandle {
}
}
void InitReduceOp(size_t input_scope_idx) {
void InitReduceOp(size_t out_scope_idx) {
// init scope
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
local_scopes_[j]->Var("out");
Scope &local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope *>() = &local_scope;
local_scope.Var("input");
param_scopes_.emplace_back(&local_scope);
}
local_scopes_[input_scope_idx]->Var("input");
param_scopes_[out_scope_idx]->Var("out");
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
......@@ -106,6 +112,7 @@ struct TestReduceOpHandle {
#endif
}
// init op handle
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) {
......@@ -126,7 +133,7 @@ struct TestReduceOpHandle {
// add output
auto *out_var_handle =
new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
......@@ -148,7 +155,8 @@ struct TestReduceOpHandle {
for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size();
++input_scope_idx) {
auto in_var = local_scopes_[input_scope_idx]->Var("input");
auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
auto value = in_selected_rows->mutable_value();
value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
......@@ -161,10 +169,11 @@ struct TestReduceOpHandle {
value->Resize(kDims);
}
auto out_var = local_scopes_[output_scope_idx]->Var("out");
auto out_var = param_scopes_[output_scope_idx]->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_selected_rows = out_var->GetMutable<f::SelectedRows>();
auto in_var = local_scopes_[output_scope_idx]->Var("input");
auto in_var = param_scopes_[output_scope_idx]->FindVar("input");
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
out_selected_rows->mutable_value()->ShareDataWith(
......@@ -202,7 +211,8 @@ struct TestReduceOpHandle {
for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size();
++input_scope_idx) {
auto in_var = local_scopes_[input_scope_idx]->Var("input");
auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
in_lod_tensor->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
in_lod_tensor->set_lod(lod);
......@@ -211,10 +221,11 @@ struct TestReduceOpHandle {
send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor);
}
auto out_var = local_scopes_[output_scope_idx]->Var("out");
auto out_var = param_scopes_[output_scope_idx]->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_lodtensor = out_var->GetMutable<f::LoDTensor>();
auto in_var = local_scopes_[output_scope_idx]->Var("input");
auto in_var = param_scopes_[output_scope_idx]->FindVar("input");
auto in_lodtensor = in_var->Get<f::LoDTensor>();
out_lodtensor->ShareDataWith(in_lodtensor);
......@@ -239,34 +250,34 @@ struct TestReduceOpHandle {
TEST(ReduceTester, TestCPUReduceTestSelectedRows) {
TestReduceOpHandle test_op;
size_t input_scope_idx = 0;
size_t out_scope_idx = 0;
test_op.InitCtxOnGpu(false);
test_op.InitReduceOp(input_scope_idx);
test_op.TestReduceSelectedRows(input_scope_idx);
test_op.InitReduceOp(out_scope_idx);
test_op.TestReduceSelectedRows(out_scope_idx);
}
TEST(ReduceTester, TestCPUReduceTestLodTensor) {
TestReduceOpHandle test_op;
size_t input_scope_idx = 0;
size_t out_scope_idx = 0;
test_op.InitCtxOnGpu(false);
test_op.InitReduceOp(input_scope_idx);
test_op.TestReduceLodTensors(input_scope_idx);
test_op.InitReduceOp(out_scope_idx);
test_op.TestReduceLodTensors(out_scope_idx);
}
#ifdef PADDLE_WITH_CUDA
TEST(ReduceTester, TestGPUReduceTestSelectedRows) {
TestReduceOpHandle test_op;
size_t input_scope_idx = 0;
size_t out_scope_idx = 0;
test_op.InitCtxOnGpu(true);
test_op.InitReduceOp(input_scope_idx);
test_op.TestReduceSelectedRows(input_scope_idx);
test_op.InitReduceOp(out_scope_idx);
test_op.TestReduceSelectedRows(out_scope_idx);
}
TEST(ReduceTester, TestGPUReduceTestLodTensor) {
TestReduceOpHandle test_op;
size_t input_scope_idx = 0;
size_t out_scope_idx = 0;
test_op.InitCtxOnGpu(true);
test_op.InitReduceOp(input_scope_idx);
test_op.TestReduceLodTensors(input_scope_idx);
test_op.InitReduceOp(out_scope_idx);
test_op.TestReduceLodTensors(out_scope_idx);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册