“a11144113901ea3121a55f5fbb1c9446789376d3”上不存在“develop/doc/build_and_install/build_from_source_en.html”
未验证 提交 2d2c31a6 编写于 作者: W wanghuancoder 提交者: GitHub

Add FetchAsyncOpHandle, and use it in FastThreadedExecutor (#26643)

* optimized transformation form tensor to numpy, test=develop

* Modify fetch op handle, from memcpy Sync to memcpy Async, test=develop

* modify CUDAPinnedPlace to CPUPlace, test=develop

* modify CPUPlace to CUDAPinnedPlace, and set default inplace to false, test=develop

* revert fetch_op_handle, add fetch_async_op_handle, test=develop

* revert fetch_op_handle, add fetch_async_op_handle, test=develop

* fix error msg report, test=develop

* fix bug in cpuplace, test=develop

* fix bug in unmerge and tensorarray modle, test=develop

* fix bug, double copy gpu memory, test=develop

* fix chenweihang¡¯s review advice, test=develop
上级 52057484
...@@ -3,6 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context ...@@ -3,6 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_async_op_handle SRCS fetch_async_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(share_tensor_buffer_functor SRCS share_tensor_buffer_functor.cc DEPS framework_proto scope place operator op_registry) cc_library(share_tensor_buffer_functor SRCS share_tensor_buffer_functor.cc DEPS framework_proto scope place operator op_registry)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
...@@ -98,7 +99,7 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu ...@@ -98,7 +99,7 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle ) # device_context reduce_op_handle )
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle) cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
cc_test(exception_holder_test SRCS exception_holder_test.cc ) cc_test(exception_holder_test SRCS exception_holder_test.cc )
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -120,6 +121,11 @@ FetchResultType FastThreadedSSAGraphExecutor::Run( ...@@ -120,6 +121,11 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
} }
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops); ClearFetchOp(graph_, &fetch_ops);
for (auto &place : places_) {
fetch_ctxs_.Get(place)->Wait();
}
return fetches; return fetches;
} }
...@@ -162,8 +168,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -162,8 +168,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node = ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_, auto *op = new FetchAsyncOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged); &local_exec_scopes_, return_merged);
fetch_ops->emplace_back(op); fetch_ops->emplace_back(op);
for (auto &p : places_) { for (auto &p : places_) {
...@@ -174,6 +180,14 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -174,6 +180,14 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var); op->AddInput(var);
} }
for (auto *var : vars) {
auto *op = var->GeneratedOp();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op) {
compute_op->SetLockAndRecordEventFree(false);
}
}
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep; (*op_deps)[op] = dep;
if (dep == 0) { if (dep == 0) {
...@@ -261,7 +275,7 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { ...@@ -261,7 +275,7 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; } const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) { void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) { if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchAsyncOpHandle *>(op)) {
traced_ops_.emplace_back(op); traced_ops_.emplace_back(op);
} }
} }
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
FetchAsyncOpHandle::FetchAsyncOpHandle(ir::Node *node, FetchResultType *data,
size_t offset,
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes,
bool return_merged)
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
return_merged_(return_merged) {}
FetchAsyncOpHandle::~FetchAsyncOpHandle() {}
void FetchAsyncOpHandle::RecordWaitEventOnCtx(
platform::DeviceContext *waited_ctx) {
PADDLE_THROW(platform::errors::PermissionDenied(
"No nodes need to wait FetchAsyncOp. Unexpceted Error."));
}
static void CheckTensorAttrs(const LoDTensor *tensor,
const proto::VarType::Type &type,
const DataLayout &layout, const DDim &dims,
const LoD &lod, const size_t offset) {
if (tensor->numel() && tensor->IsInitialized()) {
// step1: check type
PADDLE_ENFORCE_EQ(
type, tensor->type(),
platform::errors::InvalidArgument(
"The data type of fetched Tensors or the items of fetched "
"LoDTensorArray are different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
DataTypeToString(type), DataTypeToString(tensor->type()), offset));
// step2: check layout
PADDLE_ENFORCE_EQ(
layout, tensor->layout(),
platform::errors::InvalidArgument(
"The layout of fetched Tensors or the items of fetched "
"LoDTensorArray are different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
DataLayoutToString(layout), DataLayoutToString(tensor->layout()),
offset));
}
// step3: check dims
auto tensor_dims = tensor->dims();
PADDLE_ENFORCE_EQ(dims.size(), tensor_dims.size(),
platform::errors::InvalidArgument(
"The dimension sizes of fetched Tensors or "
"the items of fetched LoDTensorArray are "
"different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
dims, tensor_dims, offset));
for (int j = 1; j < dims.size(); j++) {
PADDLE_ENFORCE_EQ(dims[j], tensor_dims[j],
platform::errors::InvalidArgument(
"The dimensions of fetched Tensors or "
"the items of fetched LoDTensorArray are "
"different from each other on different "
"devices(%s vs %s). And the error is caused by the "
"%zu (th) fetched variable. Please set the "
"parameter `return_merged = False` when "
"you call the `Executor.run()` method.",
dims, tensor_dims, offset));
}
// step4: check lod
PADDLE_ENFORCE_EQ(
lod.size(), tensor->lod().size(),
platform::errors::InvalidArgument(
"The LoD information of fetched Tensors or the items of fetched "
"LoDTensorArray are different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
lod, tensor->lod(), offset));
}
static void TransData(const framework::Tensor *src_item,
framework::Tensor *dst_item,
const platform::DeviceContext &ctx) {
if (src_item->IsInitialized() && src_item->numel() > 0) {
if (platform::is_gpu_place(src_item->place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(*src_item, platform::CUDAPinnedPlace(), ctx, dst_item);
#endif
} else {
TensorCopy(*src_item, platform::CPUPlace(), dst_item);
}
}
}
void FetchAsyncOpHandle::FetchMergedLodTensor(
const std::vector<const LoDTensor *> &src_lodtensors,
LoDTensor *dst_lodtensor) {
// calc dst type,layout,dim,lod and calc check dim
proto::VarType::Type new_type = proto::VarType::FP32;
framework::DataLayout new_layout;
framework::DDim new_dim;
LoD new_lod = src_lodtensors[0]->lod();
framework::DDim check_dim;
for (auto *t : src_lodtensors) {
if (t->numel() && t->IsInitialized()) {
check_dim = t->dims();
new_type = t->type();
new_layout = t->layout();
break;
}
}
bool find_first_dims = false;
for (auto *t : src_lodtensors) {
if (t->numel() && t->IsInitialized()) {
if (!find_first_dims) {
new_dim = t->dims();
find_first_dims = true;
} else {
new_dim[0] += t->dims()[0];
}
}
}
// check src type,layout,dim,lod consistence
for (size_t i = 1; i < src_lodtensors.size(); ++i) {
CheckTensorAttrs(src_lodtensors[i], new_type, new_layout, check_dim,
new_lod, offset_);
}
// set dst tensor
dst_lodtensor->Resize(new_dim);
dst_lodtensor->set_layout(src_lodtensors[0]->layout());
dst_lodtensor->set_lod(src_lodtensors[0]->lod());
if (platform::is_gpu_place(src_lodtensors[0]->place())) {
dst_lodtensor->mutable_data(platform::CUDAPinnedPlace(),
src_lodtensors[0]->type());
} else {
dst_lodtensor->mutable_data(platform::CPUPlace(),
src_lodtensors[0]->type());
}
// slice and memcpy
int begin = 0;
for (auto *src : src_lodtensors) {
int end = begin + src->dims()[0];
if (end == begin) {
continue;
}
auto dst = dst_lodtensor->Slice(begin, end);
TransData(src, &dst, *dev_ctxes_[src->place()]);
begin = end;
}
}
void FetchAsyncOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
// get src vars
auto &scopes = *local_exec_scopes_;
std::vector<Variable *> src_vars;
src_vars.reserve(inputs_.size());
for (size_t i = 0; i < inputs_.size(); ++i) {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx());
auto *var = scope->FindVar(var_handle->name());
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"Cannot find variable %s in execution scope.", var_handle->name()));
src_vars.emplace_back(var);
}
if (return_merged_) {
auto &val = BOOST_GET(FetchList, *data_);
if (src_vars[0]->IsType<LoDTensor>()) {
// to lodtensor type
std::vector<const LoDTensor *> src_lodtensors;
src_lodtensors.reserve(src_vars.size());
for (size_t i = 0; i < src_vars.size(); ++i) {
src_lodtensors.emplace_back(&src_vars[i]->Get<framework::LoDTensor>());
}
LoDTensor dst_lodtensor;
FetchMergedLodTensor(src_lodtensors, &dst_lodtensor);
val.at(offset_) = std::move(dst_lodtensor);
} else {
// to lodtensorarray type
std::vector<const LoDTensorArray *> src_lodtensor_arrays;
src_lodtensor_arrays.reserve(src_vars.size());
for (size_t i = 0; i < src_vars.size(); ++i) {
src_lodtensor_arrays.emplace_back(
&src_vars[i]->Get<framework::LoDTensorArray>());
}
LoDTensorArray dst_lodtensor_array;
dst_lodtensor_array.resize(src_lodtensor_arrays[0]->size());
for (size_t i = 0; i < dst_lodtensor_array.size(); ++i) {
std::vector<const LoDTensor *> src_lodtensors;
src_lodtensors.reserve(src_lodtensor_arrays.size());
for (size_t j = 0; j < src_lodtensor_arrays.size(); ++j) {
src_lodtensors.emplace_back(&(*src_lodtensor_arrays[j])[i]);
}
FetchMergedLodTensor(src_lodtensors, &dst_lodtensor_array[i]);
}
val.at(offset_) = std::move(dst_lodtensor_array);
}
} else {
auto &val = BOOST_GET(FetchUnmergedList, *data_);
auto &dst_tensors = val.at(offset_);
dst_tensors.reserve(src_vars.size());
for (size_t i = 0; i < src_vars.size(); ++i) {
if (src_vars[i]->IsType<LoDTensor>()) {
auto &t = src_vars[i]->Get<framework::LoDTensor>();
LoDTensor item;
TransData(&t, &item, *dev_ctxes_[t.place()]);
dst_tensors.emplace_back(std::move(item));
} else {
auto &t = src_vars[i]->Get<framework::LoDTensorArray>();
LoDTensorArray item;
item.resize(t.size());
for (size_t j = 0; j < t.size(); ++j) {
TransData(&t[j], &item[j], *dev_ctxes_[t[j].place()]);
}
dst_tensors.emplace_back(std::move(item));
}
}
}
}
bool FetchAsyncOpHandle::IsMultiDeviceTransfer() { return true; }
std::string FetchAsyncOpHandle::Name() const { return "FetchAsync"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FetchAsyncOpHandle : public OpHandleBase {
public:
FetchAsyncOpHandle(ir::Node *node, FetchResultType *data, size_t offset,
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes,
bool return_merged);
~FetchAsyncOpHandle();
void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
std::string Name() const override;
bool IsMultiDeviceTransfer() override;
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
void FetchMergedLodTensor(
const std::vector<const LoDTensor *> &src_lodtensors,
LoDTensor *dst_lodtensor);
private:
FetchResultType *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<Scope *> *local_exec_scopes_;
bool return_merged_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -36,7 +36,8 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FetchResultType *data, ...@@ -36,7 +36,8 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FetchResultType *data,
FetchOpHandle::~FetchOpHandle() {} FetchOpHandle::~FetchOpHandle() {}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); PADDLE_THROW(platform::errors::PermissionDenied(
"No nodes need to wait FetchOp. Unexpceted Error."));
} }
static void CheckDims(const framework::DDim &tensor_dims, static void CheckDims(const framework::DDim &tensor_dims,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -23,9 +24,11 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) { ...@@ -23,9 +24,11 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_EQ(dynamic_cast<FetchOpHandle*>(op) != nullptr ||
dynamic_cast<FetchOpHandle*>(op), dynamic_cast<FetchAsyncOpHandle*>(op) != nullptr,
"The input ops of ClearFetchOp function should be FetchOpHandle."); true,
"The input ops of ClearFetchOp function should be "
"FetchOpHandle or FetchAsyncOpHandle.");
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册