diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 4d8bd101258664f6cafd71784ae070e0cb8b9215..a3cc4d1721e20a72817606bd773129230a8154ce 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -3,6 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
+cc_library(fetch_async_op_handle SRCS fetch_async_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(share_tensor_buffer_functor SRCS share_tensor_buffer_functor.cc DEPS framework_proto scope place operator op_registry)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -98,7 +99,7 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
 #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
 #        device_context reduce_op_handle )
 cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
-        DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
+        DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
 cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
 cc_test(exception_holder_test SRCS exception_holder_test.cc )
diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
index f5ec78f44b5ebb780cc569c24ccdca6336195961..e440dff2af6b5649d34f47c3b696edeb8a1ba0a2 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -18,7 +18,8 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "paddle/fluid/framework/details/fetch_op_handle.h"
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -120,6 +121,11 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
   }
   // Wait FetchOps.
   ClearFetchOp(graph_, &fetch_ops);
+
+  for (auto &place : places_) {
+    fetch_ctxs_.Get(place)->Wait();
+  }
+
   return fetches;
 }
@@ -162,8 +168,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
     ir::Node *fetch_node =
         graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
-    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
-                                 &local_exec_scopes_, return_merged);
+    auto *op = new FetchAsyncOpHandle(fetch_node, fetches, i, &local_scopes_,
+                                      &local_exec_scopes_, return_merged);
     fetch_ops->emplace_back(op);
 
     for (auto &p : places_) {
@@ -174,6 +180,14 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
       op->AddInput(var);
     }
 
+    for (auto *var : vars) {
+      auto *op = var->GeneratedOp();
+      auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
+      if (compute_op) {
+        compute_op->SetLockAndRecordEventFree(false);
+      }
+    }
+
     int dep = static_cast<int>(op->NotReadyInputSize());
     (*op_deps)[op] = dep;
     if (dep == 0) {
@@ -261,7 +275,7 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
 const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
 
 void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
-  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
+  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchAsyncOpHandle *>(op)) {
     traced_ops_.emplace_back(op);
   }
 }
diff --git a/paddle/fluid/framework/details/fetch_async_op_handle.cc b/paddle/fluid/framework/details/fetch_async_op_handle.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6aae523365ed50e78a78b318ac0990490c801eb3
--- /dev/null
+++ b/paddle/fluid/framework/details/fetch_async_op_handle.cc
@@ -0,0 +1,275 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
+#include <string>
+#include <utility>
+#include "paddle/fluid/platform/profiler.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+FetchAsyncOpHandle::FetchAsyncOpHandle(ir::Node *node, FetchResultType *data,
+                                       size_t offset,
+                                       std::vector<Scope *> *local_scopes,
+                                       std::vector<Scope *> *local_exec_scopes,
+                                       bool return_merged)
+    : OpHandleBase(node),
+      data_(data),
+      offset_(offset),
+      local_scopes_(local_scopes),
+      local_exec_scopes_(local_exec_scopes),
+      return_merged_(return_merged) {}
+
+FetchAsyncOpHandle::~FetchAsyncOpHandle() {}
+
+void FetchAsyncOpHandle::RecordWaitEventOnCtx(
+    platform::DeviceContext *waited_ctx) {
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "No nodes need to wait FetchAsyncOp. Unexpected Error."));
+}
+
+static void CheckTensorAttrs(const LoDTensor *tensor,
+                             const proto::VarType::Type &type,
+                             const DataLayout &layout, const DDim &dims,
+                             const LoD &lod, const size_t offset) {
+  if (tensor->numel() && tensor->IsInitialized()) {
+    // step1: check type
+    PADDLE_ENFORCE_EQ(
+        type, tensor->type(),
+        platform::errors::InvalidArgument(
+            "The data type of fetched Tensors or the items of fetched "
+            "LoDTensorArray are different from each other on different "
+            "devices(%s vs %s). And the error is caused by the %zu "
+            "(th) fetched variable. Please set the "
+            "parameter `return_merged = False` when you "
+            "call the `Executor.run()` method.",
+            DataTypeToString(type), DataTypeToString(tensor->type()), offset));
+
+    // step2: check layout
+    PADDLE_ENFORCE_EQ(
+        layout, tensor->layout(),
+        platform::errors::InvalidArgument(
+            "The layout of fetched Tensors or the items of fetched "
+            "LoDTensorArray are different from each other on different "
+            "devices(%s vs %s). And the error is caused by the %zu "
+            "(th) fetched variable. Please set the "
+            "parameter `return_merged = False` when you "
+            "call the `Executor.run()` method.",
+            DataLayoutToString(layout), DataLayoutToString(tensor->layout()),
+            offset));
+  }
+
+  // step3: check dims
+  auto tensor_dims = tensor->dims();
+  PADDLE_ENFORCE_EQ(dims.size(), tensor_dims.size(),
+                    platform::errors::InvalidArgument(
+                        "The dimension sizes of fetched Tensors or "
+                        "the items of fetched LoDTensorArray are "
+                        "different from each other on different "
+                        "devices(%s vs %s). And the error is caused by the %zu "
+                        "(th) fetched variable. Please set the "
+                        "parameter `return_merged = False` when you "
+                        "call the `Executor.run()` method.",
+                        dims, tensor_dims, offset));
+  for (int j = 1; j < dims.size(); j++) {
+    PADDLE_ENFORCE_EQ(dims[j], tensor_dims[j],
+                      platform::errors::InvalidArgument(
+                          "The dimensions of fetched Tensors or "
+                          "the items of fetched LoDTensorArray are "
+                          "different from each other on different "
+                          "devices(%s vs %s). And the error is caused by the "
+                          "%zu (th) fetched variable. Please set the "
+                          "parameter `return_merged = False` when "
+                          "you call the `Executor.run()` method.",
+                          dims, tensor_dims, offset));
+  }
+
+  // step4: check lod
+  PADDLE_ENFORCE_EQ(
+      lod.size(), tensor->lod().size(),
+      platform::errors::InvalidArgument(
+          "The LoD information of fetched Tensors or the items of fetched "
+          "LoDTensorArray are different from each other on different "
+          "devices(%s vs %s). And the error is caused by the %zu "
+          "(th) fetched variable. Please set the "
+          "parameter `return_merged = False` when you "
+          "call the `Executor.run()` method.",
+          lod, tensor->lod(), offset));
+}
+
+static void TransData(const framework::Tensor *src_item,
+                      framework::Tensor *dst_item,
+                      const platform::DeviceContext &ctx) {
+  if (src_item->IsInitialized() && src_item->numel() > 0) {
+    if (platform::is_gpu_place(src_item->place())) {
+#ifdef PADDLE_WITH_CUDA
+      TensorCopy(*src_item, platform::CUDAPinnedPlace(), ctx, dst_item);
+#endif
+    } else {
+      TensorCopy(*src_item, platform::CPUPlace(), dst_item);
+    }
+  }
+}
+
+void FetchAsyncOpHandle::FetchMergedLodTensor(
+    const std::vector<const LoDTensor *> &src_lodtensors,
+    LoDTensor *dst_lodtensor) {
+  // calc dst type, layout, dim, lod and calc check dim
+  proto::VarType::Type new_type = proto::VarType::FP32;
+  framework::DataLayout new_layout;
+  framework::DDim new_dim;
+  LoD new_lod = src_lodtensors[0]->lod();
+
+  framework::DDim check_dim;
+
+  for (auto *t : src_lodtensors) {
+    if (t->numel() && t->IsInitialized()) {
+      check_dim = t->dims();
+      new_type = t->type();
+      new_layout = t->layout();
+      break;
+    }
+  }
+
+  bool find_first_dims = false;
+  for (auto *t : src_lodtensors) {
+    if (t->numel() && t->IsInitialized()) {
+      if (!find_first_dims) {
+        new_dim = t->dims();
+        find_first_dims = true;
+      } else {
+        new_dim[0] += t->dims()[0];
+      }
+    }
+  }
+
+  // check src type, layout, dim, lod consistency
+  for (size_t i = 1; i < src_lodtensors.size(); ++i) {
+    CheckTensorAttrs(src_lodtensors[i], new_type, new_layout, check_dim,
+                     new_lod, offset_);
+  }
+
+  // set dst tensor
+  dst_lodtensor->Resize(new_dim);
+  dst_lodtensor->set_layout(src_lodtensors[0]->layout());
+  dst_lodtensor->set_lod(src_lodtensors[0]->lod());
+  if (platform::is_gpu_place(src_lodtensors[0]->place())) {
+    dst_lodtensor->mutable_data(platform::CUDAPinnedPlace(),
+                                src_lodtensors[0]->type());
+  } else {
+    dst_lodtensor->mutable_data(platform::CPUPlace(),
+                                src_lodtensors[0]->type());
+  }
+
+  // slice and memcpy
+  int begin = 0;
+  for (auto *src : src_lodtensors) {
+    int end = begin + src->dims()[0];
+    if (end == begin) {
+      continue;
+    }
+    auto dst = dst_lodtensor->Slice(begin, end);
+    TransData(src, &dst, *dev_ctxes_[src->place()]);
+    begin = end;
+  }
+}
+
+void FetchAsyncOpHandle::RunImpl() {
+  platform::RecordEvent record_event(Name());
+  WaitInputVarGenerated();
+
+  // get src vars
+  auto &scopes = *local_exec_scopes_;
+  std::vector<Variable *> src_vars;
+  src_vars.reserve(inputs_.size());
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
+    auto &scope = scopes.at(var_handle->scope_idx());
+    auto *var = scope->FindVar(var_handle->name());
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::NotFound(
+            "Cannot find variable %s in execution scope.", var_handle->name()));
+    src_vars.emplace_back(var);
+  }
+
+  if (return_merged_) {
+    auto &val = BOOST_GET(FetchList, *data_);
+    if (src_vars[0]->IsType<LoDTensor>()) {
+      // to lodtensor type
+      std::vector<const LoDTensor *> src_lodtensors;
+      src_lodtensors.reserve(src_vars.size());
+      for (size_t i = 0; i < src_vars.size(); ++i) {
+        src_lodtensors.emplace_back(&src_vars[i]->Get<LoDTensor>());
+      }
+
+      LoDTensor dst_lodtensor;
+      FetchMergedLodTensor(src_lodtensors, &dst_lodtensor);
+      val.at(offset_) = std::move(dst_lodtensor);
+    } else {
+      // to lodtensorarray type
+      std::vector<const LoDTensorArray *> src_lodtensor_arrays;
+      src_lodtensor_arrays.reserve(src_vars.size());
+      for (size_t i = 0; i < src_vars.size(); ++i) {
+        src_lodtensor_arrays.emplace_back(
+            &src_vars[i]->Get<LoDTensorArray>());
+      }
+
+      LoDTensorArray dst_lodtensor_array;
+      dst_lodtensor_array.resize(src_lodtensor_arrays[0]->size());
+
+      for (size_t i = 0; i < dst_lodtensor_array.size(); ++i) {
+        std::vector<const LoDTensor *> src_lodtensors;
+        src_lodtensors.reserve(src_lodtensor_arrays.size());
+        for (size_t j = 0; j < src_lodtensor_arrays.size(); ++j) {
+          src_lodtensors.emplace_back(&(*src_lodtensor_arrays[j])[i]);
+        }
+        FetchMergedLodTensor(src_lodtensors, &dst_lodtensor_array[i]);
+      }
+      val.at(offset_) = std::move(dst_lodtensor_array);
+    }
+  } else {
+    auto &val = BOOST_GET(FetchUnmergedList, *data_);
+    auto &dst_tensors = val.at(offset_);
+    dst_tensors.reserve(src_vars.size());
+
+    for (size_t i = 0; i < src_vars.size(); ++i) {
+      if (src_vars[i]->IsType<LoDTensor>()) {
+        auto &t = src_vars[i]->Get<LoDTensor>();
+        LoDTensor item;
+        TransData(&t, &item, *dev_ctxes_[t.place()]);
+        dst_tensors.emplace_back(std::move(item));
+      } else {
+        auto &t = src_vars[i]->Get<LoDTensorArray>();
+        LoDTensorArray item;
+        item.resize(t.size());
+        for (size_t j = 0; j < t.size(); ++j) {
+          TransData(&t[j], &item[j], *dev_ctxes_[t[j].place()]);
+        }
+        dst_tensors.emplace_back(std::move(item));
+      }
+    }
+  }
+}
+
+bool FetchAsyncOpHandle::IsMultiDeviceTransfer() { return true; }
+
+std::string FetchAsyncOpHandle::Name() const { return "FetchAsync"; }
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/fetch_async_op_handle.h b/paddle/fluid/framework/details/fetch_async_op_handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..691a3286c270badad938610811cc6e73d63c2c04
--- /dev/null
+++ b/paddle/fluid/framework/details/fetch_async_op_handle.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/feed_fetch_type.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct FetchAsyncOpHandle : public OpHandleBase {
+ public:
+  FetchAsyncOpHandle(ir::Node *node, FetchResultType *data, size_t offset,
+                     std::vector<Scope *> *local_scopes,
+                     std::vector<Scope *> *local_exec_scopes,
+                     bool return_merged);
+
+  ~FetchAsyncOpHandle();
+
+  void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
+
+  std::string Name() const override;
+
+  bool IsMultiDeviceTransfer() override;
+
+ protected:
+  void RunImpl() override;
+
+  std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
+
+  void FetchMergedLodTensor(
+      const std::vector<const LoDTensor *> &src_lodtensors,
+      LoDTensor *dst_lodtensor);
+
+ private:
+  FetchResultType *data_;
+  size_t offset_;
+  std::vector<Scope *> *local_scopes_;
+  std::vector<Scope *> *local_exec_scopes_;
+  bool return_merged_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index 5574a55e18c6d9806cb878dc69ec597f81da97d8..ae69960ef78c3e35143c66226133bd0dceac8b79 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -36,7 +36,8 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FetchResultType *data,
 FetchOpHandle::~FetchOpHandle() {}
 
 void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
-  PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
+  PADDLE_THROW(platform::errors::PermissionDenied(
+      "No nodes need to wait FetchOp. Unexpected Error."));
 }
 
 static void CheckDims(const framework::DDim &tensor_dims,
diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc
index 4f1e44ca26cb65468da6eded74653f34dbf00336..71123f708e3ca149d9fd634f55652cede5a57b50 100644
--- a/paddle/fluid/framework/details/ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/ssa_graph_executor.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
+#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
 
 namespace paddle {
 namespace framework {
@@ -23,9 +24,11 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
   if (fetch_ops->empty()) return;
 
   for (auto& op : *fetch_ops) {
-    PADDLE_ENFORCE_NOT_NULL(
-        dynamic_cast<FetchOpHandle*>(op),
-        "The input ops of ClearFetchOp function should be FetchOpHandle.");
+    PADDLE_ENFORCE_EQ(dynamic_cast<FetchOpHandle*>(op) != nullptr ||
+                          dynamic_cast<FetchAsyncOpHandle*>(op) != nullptr,
+                      true,
+                      "The input ops of ClearFetchOp function should be "
+                      "FetchOpHandle or FetchAsyncOpHandle.");
     for (auto& out_var : op->Node()->outputs) {
       graph->RemoveNode(out_var);
     }
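
Note on the change: `FastThreadedSSAGraphExecutor::Run` now only enqueues the fetch copies on the per-place `fetch_ctxs_` device contexts and issues a single `Wait` per place after `ClearFetchOp`, while `FetchAsyncOpHandle` copies GPU results into CUDA pinned memory (`CUDAPinnedPlace`) so the device-to-host transfer can overlap with computation. The merge done by `FetchMergedLodTensor` is a concatenation of the per-device partial results along dimension 0. The sketch below is a minimal, self-contained C++ illustration of that dim-0 merge only, not the Paddle implementation; `SimpleTensor` and `MergeAlongDim0` are hypothetical names.

    #include <cassert>
    #include <cstring>
    #include <iostream>
    #include <vector>

    // A tensor stand-in: `rows` x `row_size` floats in row-major order.
    struct SimpleTensor {
      int rows = 0;
      int row_size = 0;
      std::vector<float> data;
    };

    // Concatenate along dimension 0, mirroring FetchMergedLodTensor's logic:
    // sum the first dimension, then copy each source into its [begin, end) slice.
    SimpleTensor MergeAlongDim0(const std::vector<SimpleTensor>& srcs) {
      SimpleTensor dst;
      dst.row_size = srcs.at(0).row_size;
      for (const auto& src : srcs) {
        assert(src.row_size == dst.row_size);  // stand-in for CheckTensorAttrs
        dst.rows += src.rows;                  // new_dim[0] += t->dims()[0]
      }
      dst.data.resize(static_cast<size_t>(dst.rows) * dst.row_size);

      int begin = 0;
      for (const auto& src : srcs) {
        int end = begin + src.rows;
        if (end == begin) continue;  // skip empty (uninitialized) sources
        std::memcpy(dst.data.data() + static_cast<size_t>(begin) * dst.row_size,
                    src.data.data(), src.data.size() * sizeof(float));
        begin = end;
      }
      return dst;
    }

    int main() {
      // Two "devices" fetched batches of 2 and 3 rows (row_size = 2).
      SimpleTensor a{2, 2, {0, 1, 2, 3}};
      SimpleTensor b{3, 2, {4, 5, 6, 7, 8, 9}};
      SimpleTensor merged = MergeAlongDim0({a, b});
      std::cout << "merged rows: " << merged.rows << "\n";  // prints 5
      return 0;
    }

In the real handle, the `std::memcpy` step corresponds to a `TensorCopy` into the `dst_lodtensor->Slice(begin, end)` view, issued asynchronously on the fetch context of each source tensor's place; the per-place `Wait` at the end of `Run` is then the only synchronization point.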