Unverified · Commit 76d2fd1d · Author: Chen Weihang · Committer: GitHub

[PTen] Compatible runtime performance optimization (#36946)

* resolve conflict with develop

* cache kernel context in tracer for perf up

* replace densetensor when build kernel context

* fix detail compile error

* append impl to static mode

* fix conflict error

* clear attrs after run kernel

* fix coverage failed

* fix cycle compile error

* remove multi-in&out adapt code

* remove tensor meta utils

* clear data when throw exception
Parent: ad44a40c
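The core of this change: rather than constructing a fresh pten::KernelContext on every op run, a single context is cached, created lazily on first use, its tensor slots rebound in place on later runs, and its data cleared after each run. A minimal stand-alone sketch of that pattern (stand-in types and names, not Paddle's actual API):

#include <memory>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> buffer;  // external storage reference
};

class KernelContext {
 public:
  size_t InputsSize() const { return inputs_.size(); }
  void EmplaceBackInput(Tensor t) { inputs_.push_back(std::move(t)); }
  Tensor* MutableInputAt(size_t i) { return &inputs_.at(i); }
  // Drop storage references after a run so the cached context does not
  // extend the lifetime of the tensors that were bound into it.
  void ClearData() {
    for (auto& t : inputs_) t.buffer.reset();
  }

 private:
  std::vector<Tensor> inputs_;  // slots persist across runs
};

void RunOp(KernelContext* cached, const std::vector<Tensor>& ins) {
  for (size_t i = 0; i < ins.size(); ++i) {
    if (cached->InputsSize() <= i) {
      cached->EmplaceBackInput(ins[i]);     // first run: create the slot
    } else {
      *cached->MutableInputAt(i) = ins[i];  // later runs: rebind in place
    }
  }
  // ... launch the kernel against *cached, then release references ...
  cached->ClearData();
}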
@@ -1131,7 +1131,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // phase
   if (FLAGS_run_pten_kernel &&
       pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
-    if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
+    if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
       ChoosePtenKernel(exe_ctx);
     }
     run_pten_kernel_ = pt_kernel_->IsValid();
@@ -1178,8 +1178,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("compute",
                                        platform::EventRole::kInnerOp);
     if (run_pten_kernel_) {
-      auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
-      (*pt_kernel_)(&op_kernel_ctx);
+      if (pt_kernel_context_ == nullptr) {
+        pt_kernel_context_.reset(new pten::KernelContext());
+      }
+      BuildPtenKernelContext(*runtime_ctx, dev_ctx);
+      (*pt_kernel_)(pt_kernel_context_.get());
+      pt_kernel_context_->ClearData();
     } else {
       (*kernel_func_)(
           ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -1765,8 +1769,8 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
   return KernelSignatureMap::Instance().Get(Type());
 }

-pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
-    const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
+void OperatorWithKernel::BuildPtenKernelContext(
+    const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
   // TODO(chenweihang): now only work for very simple case,
   // many cases need to be deal with later:
   // 1. the input and output are not tensor
@@ -1774,7 +1778,7 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
   // 3. needless attributes remove
   // 4. use pt Tensor directly
   // 5. kernel input is not DenseTensor
-  pten::KernelContext op_kernel_ctx(dev_ctx);
+  pt_kernel_context_->SetDeviceContext(dev_ctx);

   auto& input_names = std::get<0>(pt_kernel_signature_->args);
   auto& attr_names = std::get<1>(pt_kernel_signature_->args);
@@ -1803,30 +1807,53 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
                         attr_names.size(), attr_defs.size()));

   for (size_t i = 0; i < input_names.size(); ++i) {
-    auto in_def = input_defs.at(i);
-    VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", "
-            << in_def.layout;
-
-    auto ins_vector = ctx.inputs.at(input_names[i]);
-
-    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
-    for (auto var : ins_vector) {
-      tmp_inputs.emplace_back(
-          experimental::MakePtenTensorBaseFromVar(*var, in_def));
+    auto& in_def = input_defs.at(i);
+    auto& ins_vector = ctx.inputs.at(input_names[i]);
+    if (pt_kernel_context_->InputsSize() <= i) {
+      paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
+      for (auto* var : ins_vector) {
+        tmp_inputs.emplace_back(
+            experimental::MakePtenTensorBaseFromVar(*var, in_def));
+      }
+      pt_kernel_context_->EmplaceBackInputs(std::move(tmp_inputs));
+    } else {
+      size_t input_size = pt_kernel_context_->InputsSize();
+      for (size_t j = 0; j < ins_vector.size(); ++j) {
+        if (input_size > i + j) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              *ins_vector[j], in_def,
+              pt_kernel_context_->MutableInputAt<pten::DenseTensor>(i + j));
+        }
+        // TODO(chenweihang): adapt multi-input case later
+      }
+      pt_kernel_context_->MutableInputRangeAt(i) =
+          std::make_pair(i, i + ins_vector.size());
     }
-    op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
   }

   for (size_t i = 0; i < output_names.size(); ++i) {
-    auto out_def = output_defs.at(i);
-    auto outs_vector = ctx.outputs.at(output_names[i]);
-
-    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
-    for (auto var : outs_vector) {
-      tmp_outputs.emplace_back(
-          experimental::MakePtenTensorBaseFromVar(var, out_def));
+    auto& out_def = output_defs.at(i);
+    auto& outs_vector = ctx.outputs.at(output_names[i]);
+    if (pt_kernel_context_->OutputsSize() <= i) {
+      paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
+      for (auto* var : outs_vector) {
+        tmp_outputs.emplace_back(
+            experimental::MakePtenTensorBaseFromVar(var, out_def));
+      }
+      pt_kernel_context_->EmplaceBackOutputs(std::move(tmp_outputs));
+    } else {
+      size_t output_size = pt_kernel_context_->OutputsSize();
+      for (size_t j = 0; j < outs_vector.size(); ++j) {
+        if (output_size > i + j) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              outs_vector[j], out_def,
+              pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(i + j));
+        }
+        // TODO(chenweihang): adapt multi-output case later
+      }
+      pt_kernel_context_->MutableOutputRangeAt(i) =
+          std::make_pair(i, i + outs_vector.size());
     }
-    op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
   }

   for (size_t i = 0; i < attr_names.size(); ++i) {
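One subtlety in the reuse branch above: a named argument may expand to several tensors, so the context records a per-argument (start, end) range into its flat slot vector, and the rebind path has to refresh that range. A simplified illustration of the bookkeeping (hypothetical types, not the real KernelContext); note that the reuse path in this hunk still assumes one tensor per argument, with the TODO marking the multi-input case as future work:

#include <cstddef>
#include <utility>
#include <vector>

struct FlatArgs {
  std::vector<int> slots;                   // flat tensor slots for all args
  std::vector<std::pair<int, int>> ranges;  // per-argument [start, end)

  // Register an argument holding `n` tensors at the end of the slot vector.
  void AppendArg(std::size_t n) {
    int start = static_cast<int>(slots.size());
    slots.insert(slots.end(), n, 0);
    ranges.emplace_back(start, start + static_cast<int>(n));
  }
};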
@@ -1836,11 +1863,11 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
       // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
       // attribtue type by attr_defs
       if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(
+        pt_kernel_context_->EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
       } else if (std::type_index(attr.type()) ==
                  std::type_index(typeid(std::string))) {
-        op_kernel_ctx.EmplaceBackAttr(
+        pt_kernel_context_->EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
@@ -1851,11 +1878,11 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
     } else {
       // TODO(chenweihang): support other attrs later
       if (attr_defs[i].type_index == std::type_index(typeid(int))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
+        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
+        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` when construct "
@@ -1864,8 +1891,6 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
       }
     }
   }
-
-  return op_kernel_ctx;
 }

 } // namespace framework
......
@@ -586,8 +586,8 @@ class OperatorWithKernel : public OperatorBase {
   /* member functions for adapting to pten lib */
   void ChoosePtenKernel(const ExecutionContext& ctx) const;

-  pten::KernelContext BuildPtenKernelContext(
-      const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;
+  void BuildPtenKernelContext(const RuntimeContext& ctx,
+                              platform::DeviceContext* dev_ctx) const;

  protected:
   mutable std::unique_ptr<OpKernelType> kernel_type_;
@@ -605,6 +605,9 @@ class OperatorWithKernel : public OperatorBase {
   mutable bool run_pten_kernel_ = false;
   mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
   mutable std::unique_ptr<pten::Kernel> pt_kernel_;
+  // In order to reduce the compatibility phase
+  // performance overhead, temporarily cache KernelContext
+  mutable std::unique_ptr<pten::KernelContext> pt_kernel_context_;
 };

 extern bool OpSupportGPU(const std::string& op_type);
......
 cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
 IF(WITH_XPU)
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
 ELSE()
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
 ENDIF()
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
 add_subdirectory(jit)
 cc_library(amp SRCS amp_auto_cast.cc DEPS layer )
-cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal)
+cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector)
 cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
 cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
 cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
......
@@ -356,6 +356,8 @@ void VarBase::BumpInplaceVersion() {
   MutableVar()->BumpInplaceVersion();
 }

+pten::KernelContext OpBase::pt_kernel_context_;
+
 void OpBase::SetType(const std::string& type) {
   op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
 }
@@ -371,7 +373,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
                           const NameVarMap<VarType>& outs,
                           const framework::AttributeMap& attrs,
                           const framework::AttributeMap& default_attrs,
-                          const platform::Place& place) {
+                          const platform::Place& place,
+                          pten::KernelContext* pt_kernel_context) {
   auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
   PADDLE_ENFORCE_NOT_NULL(
       op_kernel, platform::errors::PermissionDenied(
@@ -412,8 +415,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
    * after the execution of op, but the original input is directly
    * overwritten in the previous dynamic graph implemention.
    */
-  auto prepared_op =
-      PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
+  auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs,
+                                         default_attrs, pt_kernel_context);
   auto tmp_ins_ptr =
       PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
   if (tmp_ins_ptr == nullptr) {
@@ -441,7 +444,8 @@ void OpBase::Run(const framework::OperatorBase& op,
                  const framework::AttributeMap& attrs,
                  const framework::AttributeMap& default_attrs,
                  const platform::Place& place) {
-  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
+  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place,
+                         &pt_kernel_context_);
 }

 void OpBase::Run(const framework::OperatorBase& op,
@@ -450,7 +454,8 @@ void OpBase::Run(const framework::OperatorBase& op,
                  const framework::AttributeMap& attrs,
                  const framework::AttributeMap& default_attrs,
                  const platform::Place& place) {
-  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
+  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place,
+                                 &pt_kernel_context_);
 }

 void ClearNoNeedBufferInputs(OpBase* op) {
......
@@ -36,6 +36,7 @@
 #include "paddle/fluid/imperative/variable_wrapper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/macros.h"
+#include "paddle/pten/include/core.h"

 namespace paddle {
 namespace framework {
......
@@ -25,6 +25,7 @@
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/imperative/variable_wrapper.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/pten/include/core.h"

 namespace paddle {
 namespace imperative {
@@ -183,6 +184,8 @@ class OpBase {
                   const framework::AttributeMap& default_attrs,
                   const platform::Place& place);

+  static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }
+
  private:
   static const std::string& UnknownOpType() {
     static std::string kUnknownOpType{"unknown"};
@@ -197,6 +200,9 @@ class OpBase {
   std::unique_ptr<framework::OperatorBase> op_;
   platform::Place place_;
   size_t id_{-1UL};
+  // In order to reduce the compatibility phase
+  // performance overhead, temporarily cache KernelContext
+  static pten::KernelContext pt_kernel_context_;
 };

 class GradOpNode {
......
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
+#include "paddle/fluid/imperative/tracer.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/utils/small_vector.h"
 #ifdef PADDLE_WITH_XPU
@@ -112,6 +113,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
                        const framework::OpKernelType& kernel_type,
                        const framework::KernelSignature& kernel_signature,
                        const pten::Kernel& pt_kernel,
+                       pten::KernelContext* pt_kernel_context,
                        platform::DeviceContext* dev_ctx)
     : op_(op),
       ctx_(ctx),
@@ -120,7 +122,8 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
       dev_ctx_(dev_ctx),
       run_pten_kernel_(true),
       pt_kernel_signature_(kernel_signature),
-      pt_kernel_(pt_kernel) {}
+      pt_kernel_(pt_kernel),
+      pt_kernel_context_(pt_kernel_context) {}

 template <typename VarType>
 PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
@@ -128,7 +131,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
                        const framework::OperatorWithKernel& op,
                        const platform::Place& place,
                        const framework::AttributeMap& attrs,
-                       const framework::AttributeMap& default_attrs) {
+                       const framework::AttributeMap& default_attrs,
+                       pten::KernelContext* pt_kernel_context) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
@@ -171,7 +175,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
     // TODO(chenweihang): using CPUKernel when miss device kernel case
     return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
-                      pt_kernel, dev_ctx);
+                      pt_kernel, pt_kernel_context, dev_ctx);
   } else {
     VLOG(1) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
             << "` not found.";
@@ -230,8 +234,10 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                                const framework::OperatorWithKernel& op,
                                const platform::Place& place,
                                const framework::AttributeMap& attrs,
-                               const framework::AttributeMap& default_attrs) {
-  return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
+                               const framework::AttributeMap& default_attrs,
+                               pten::KernelContext* pt_kernel_context) {
+  return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
+                              pt_kernel_context);
 }

 PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
@@ -239,18 +245,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                                const framework::OperatorWithKernel& op,
                                const platform::Place& place,
                                const framework::AttributeMap& attrs,
-                               const framework::AttributeMap& default_attrs) {
+                               const framework::AttributeMap& default_attrs,
+                               pten::KernelContext* pt_kernel_context) {
   return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
-                                      default_attrs);
+                                      default_attrs, pt_kernel_context);
 }

 template <typename VarType>
-static pten::KernelContext BuildDygraphPtenKernelContext(
+static void BuildDygraphPtenKernelContext(
     const framework::KernelSignature& pt_kernel_signature,
     const pten::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
     const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
     const framework::AttributeMap& default_attrs,
-    const platform::DeviceContext& dev_ctx) {
+    platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) {
   // TODO(chenweihang): now only work for very simple case,
   // many cases need to be deal with later:
   // 1. the input and output are not tensor
@@ -258,7 +265,7 @@ static pten::KernelContext BuildDygraphPtenKernelContext(
   // 3. needless attributes remove
   // 4. use pt Tensor directly
   // 5. kernel input is not DenseTensor
-  pten::KernelContext op_kernel_ctx(dev_ctx);
+  kernel_ctx->SetDeviceContext(dev_ctx);

   auto& input_names = std::get<0>(pt_kernel_signature.args);
   auto& attr_names = std::get<1>(pt_kernel_signature.args);
@@ -289,27 +296,53 @@ static pten::KernelContext BuildDygraphPtenKernelContext(
   for (size_t i = 0; i < input_names.size(); ++i) {
     auto& in_def = input_defs.at(i);
     auto& ins_vector = ins.at(input_names[i]);
-
-    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
-    for (auto var : ins_vector) {
-      const auto& variable = var->Var();
-      tmp_inputs.emplace_back(
-          experimental::MakePtenTensorBaseFromVar(variable, in_def));
+    if (kernel_ctx->InputsSize() <= i) {
+      paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
+      for (const auto& var : ins_vector) {
+        const auto& variable = var->Var();
+        tmp_inputs.emplace_back(
+            experimental::MakePtenTensorBaseFromVar(variable, in_def));
+      }
+      kernel_ctx->EmplaceBackInputs(std::move(tmp_inputs));
+    } else {
+      size_t input_size = kernel_ctx->InputsSize();
+      for (size_t j = 0; j < ins_vector.size(); ++j) {
+        if (input_size > i + j) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              ins_vector[j]->Var(), in_def,
+              kernel_ctx->MutableInputAt<pten::DenseTensor>(i + j));
+        }
+        // TODO(chenweihang): adapt multi-input case later
+      }
+      kernel_ctx->MutableInputRangeAt(i) =
+          std::make_pair(i, i + ins_vector.size());
     }
-    op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
   }

   for (size_t i = 0; i < output_names.size(); ++i) {
     auto& out_def = output_defs.at(i);
     auto& outs_vector = outs.at(output_names[i]);
-
-    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
-    for (auto var : outs_vector) {
-      auto* variable = var->MutableVar();
-      tmp_outputs.emplace_back(
-          experimental::MakePtenTensorBaseFromVar(variable, out_def));
+    if (kernel_ctx->OutputsSize() <= i) {
+      paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
+      for (auto& var : outs_vector) {
+        auto* variable = var->MutableVar();
+        tmp_outputs.emplace_back(
+            experimental::MakePtenTensorBaseFromVar(variable, out_def));
+      }
+      kernel_ctx->EmplaceBackOutputs(std::move(tmp_outputs));
+    } else {
+      size_t output_size = kernel_ctx->OutputsSize();
+      for (size_t j = 0; j < outs_vector.size(); ++j) {
+        if (output_size > i + j) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              outs_vector[j]->MutableVar(), out_def,
+              kernel_ctx->MutableOutputAt<pten::DenseTensor>(i + j));
+        }
+        // TODO(chenweihang): adapt multi-output case later
+      }
+      kernel_ctx->MutableOutputRangeAt(i) =
+          std::make_pair(i, i + outs_vector.size());
     }
-    op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
   }

   for (size_t i = 0; i < attr_names.size(); ++i) {
@@ -319,11 +352,11 @@ static pten::KernelContext BuildDygraphPtenKernelContext(
       // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
       // attribtue type by attr_defs
       if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(
+        kernel_ctx->EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
       } else if (std::type_index(attr.type()) ==
                  std::type_index(typeid(std::string))) {
-        op_kernel_ctx.EmplaceBackAttr(
+        kernel_ctx->EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
@@ -334,11 +367,11 @@ static pten::KernelContext BuildDygraphPtenKernelContext(
     } else {
       // TODO(chenweihang): support other attrs later
       if (attr_defs[i].type_index == std::type_index(typeid(int))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
+        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
+        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` when construct "
@@ -347,8 +380,6 @@ static pten::KernelContext BuildDygraphPtenKernelContext(
       }
     }
   }
-
-  return op_kernel_ctx;
 }

 template <typename VarType>
@@ -409,20 +440,23 @@ template <typename VarType>
 static void PreparedOpRunPtImpl(
     const framework::OperatorBase& op,
     const framework::KernelSignature& pt_kernel_signature,
-    const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx,
-    const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
-    const framework::AttributeMap& attrs,
+    const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
+    platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
+    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
     const framework::AttributeMap& default_attrs) {
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
                                                     &default_attrs, op.Type());
   static_cast<const framework::OperatorWithKernel&>(op).InferShape(
       &infer_shape_ctx);

-  auto op_kernel_ctx = BuildDygraphPtenKernelContext<VarType>(
-      pt_kernel_signature, pt_kernel, ins, outs, attrs, default_attrs,
-      *dev_ctx);
+  BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
+                                         outs, attrs, default_attrs, dev_ctx,
+                                         pt_kernel_context);
+
+  pt_kernel(pt_kernel_context);

-  pt_kernel(&op_kernel_ctx);
+  // Ensure that it does not affect the VarBase life cycle management
+  pt_kernel_context->ClearData();

   // TODO(chenweihang): add debug flags later
   // TODO(chenweihang): deal with complex cases later
@@ -434,7 +468,8 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                      const framework::AttributeMap& default_attrs) {
   if (run_pten_kernel_) {
     PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
-                                 dev_ctx_, ins, outs, attrs, default_attrs);
+                                 pt_kernel_context_, dev_ctx_, ins, outs, attrs,
+                                 default_attrs);
   } else {
     PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
                                outs, attrs, default_attrs);
@@ -447,8 +482,8 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                      const framework::AttributeMap& default_attrs) {
   if (run_pten_kernel_) {
     PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
-                                         dev_ctx_, ins, outs, attrs,
-                                         default_attrs);
+                                         pt_kernel_context_, dev_ctx_, ins,
+                                         outs, attrs, default_attrs);
   } else {
     PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
                                        ins, outs, attrs, default_attrs);
......
@@ -155,21 +155,25 @@ class PreparedOp {
              const framework::RuntimeContext& ctx,
              const framework::OpKernelType& kernel_type,
              const framework::KernelSignature& kernel_signature,
-             const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
+             const pten::Kernel& pt_kernel,
+             pten::KernelContext* pt_kernel_context,
+             platform::DeviceContext* dev_ctx);

   static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
                             const NameVarMap<VarBase>& outs,
                             const framework::OperatorWithKernel& op,
                             const platform::Place& place,
                             const framework::AttributeMap& attrs,
-                            const framework::AttributeMap& default_attrs);
+                            const framework::AttributeMap& default_attrs,
+                            pten::KernelContext* pt_kernel_context = nullptr);

   static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
                             const NameVarMap<VariableWrapper>& outs,
                             const framework::OperatorWithKernel& op,
                             const platform::Place& place,
                             const framework::AttributeMap& attrs,
-                            const framework::AttributeMap& default_attrs);
+                            const framework::AttributeMap& default_attrs,
+                            pten::KernelContext* pt_kernel_context = nullptr);

   void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
            const framework::AttributeMap& attrs,
@@ -194,6 +198,9 @@ class PreparedOp {
   bool run_pten_kernel_{false};
   framework::KernelSignature pt_kernel_signature_;
   pten::Kernel pt_kernel_;
+  // In order to reduce the compatibility phase
+  // performance overhead, temporarily cache KernelContext
+  pten::KernelContext* pt_kernel_context_;
 };

 } // namespace imperative
......
@@ -213,6 +213,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
     OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
   } catch (platform::EnforceNotMet& exception) {
     framework::AppendErrorOpHint(type, &exception);
+    // Compatible impl: clear pten kernel context data when throw error
+    OpBase::GetKernelContext()->ClearData();
     throw std::move(exception);
   } catch (std::exception& ex) {
     PADDLE_THROW(platform::errors::Fatal(
......
@@ -38,7 +38,7 @@ Tensor full(const std::vector<int64_t>& shape,
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-  auto kernel_context = pten::KernelContext(*dev_ctx);
+  auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
   kernel_context.EmplaceBackAttr(value);
@@ -75,7 +75,7 @@ Tensor full_like(const Tensor& x,
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-  auto kernel_context = pten::KernelContext(*dev_ctx);
+  auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
......
@@ -38,7 +38,7 @@ Tensor dot(const Tensor& x, const Tensor& y) {
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-  auto kernel_context = pten::KernelContext(*dev_ctx);
+  auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
@@ -76,7 +76,7 @@ Tensor matmul(const Tensor& x,
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-  auto kernel_context = pten::KernelContext(*dev_ctx);
+  auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
......
@@ -34,7 +34,7 @@ Tensor flatten(const Tensor& x, int start_axis, int stop_axis) {
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-  auto kernel_context = pten::KernelContext(*dev_ctx);
+  auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
......
@@ -36,7 +36,7 @@ Tensor mean(const Tensor& x) {
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-  auto kernel_context = pten::KernelContext(*dev_ctx);
+  auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
......
@@ -75,6 +75,24 @@ class SharedStorage : public pten::Storage {
     return allocation_;
   }

+  // Temporary method: For compatible with fluid Tensor and improve performance
+  void ResetAllocation(std::shared_ptr<paddle::memory::Allocation> allocation,
+                       size_t offset) {
+    allocation_ = allocation;
+    data_ = pten::Allocation(
+        reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(allocation->ptr()) +
+                                offset),
+        allocation->place());
+    size_ = allocation->size();
+  }
+
+  // Temporary method: For compatible with fluid Tensor and improve performance
+  void Reset() {
+    allocation_.reset();
+    data_.Clear();
+    size_ = 0;
+  }
+
  private:
   int64_t size_{0};
   std::shared_ptr<paddle::memory::Allocation> allocation_;
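ResetAllocation above aliases a byte range inside an existing allocation instead of copying it; the uintptr_t round-trip is the usual way to do byte-offset arithmetic on a void*. The same operation in isolation (illustrative helper, not part of the patch):

#include <cstddef>
#include <cstdint>

// Advance a type-erased pointer by `offset` bytes.
inline void* OffsetPtr(void* base, std::size_t offset) {
  return reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(base) +
                                 offset);
}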
......
@@ -14,6 +14,10 @@ limitations under the License. */

 #include "paddle/pten/api/lib/utils/tensor_utils.h"

+#include <vector>
+
+#include "paddle/pten/core/compat_utils.h"
+
 namespace paddle {
 namespace experimental {
@@ -126,5 +130,101 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst) {
   MovesStorage(src, static_cast<paddle::framework::Tensor*>(dst));
 }

+void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
+                           pten::DenseTensor* dst) {
+  auto* meta = pten::CompatibleDenseTensorUtils::GetMutableMeta(dst);
+  meta->dims = src.dims();
+  // Since the type of DenseTensorMeta is const, const_cast must be used
+  const_cast<DataType&>(meta->type) = pten::TransToPtenDataType(src.type());
+  // Since the type of DenseTensorMeta is const, const_cast must be used
+  const_cast<DataLayout&>(meta->layout) =
+      pten::TransToPtenDataLayout(src.layout());
+
+  auto* shared_storage = static_cast<SharedStorage*>(
+      pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
+  PADDLE_ENFORCE_NOT_NULL(
+      shared_storage,
+      platform::errors::NotFound(
+          "Target DenseTensor's shared storage is nullptr."));
+
+  shared_storage->ResetAllocation(src.Holder(), src.offset());
+}
+
+void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src,
+                           pten::DenseTensor* dst) {
+  auto* meta = pten::CompatibleDenseTensorUtils::GetMutableMeta(dst);
+  meta->dims = src.dims();
+  // Since the type of DenseTensorMeta is const, const_cast must be used
+  const_cast<DataType&>(meta->type) = pten::TransToPtenDataType(src.type());
+  // Since the type of DenseTensorMeta is const, const_cast must be used
+  const_cast<DataLayout&>(meta->layout) =
+      pten::TransToPtenDataLayout(src.layout());
+  SetLoD(&(meta->lod), src.lod());
+
+  auto* shared_storage = static_cast<SharedStorage*>(
+      pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
+  PADDLE_ENFORCE_NOT_NULL(
+      shared_storage,
+      platform::errors::NotFound(
+          "Target DenseTensor's shared storage is nullptr."));
+
+  shared_storage->ResetAllocation(src.Holder(), src.offset());
+}
+
+void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
+                                  const pten::TensorArgDef& arg_def,
+                                  pten::DenseTensor* dst) {
+  auto expected_place = pten::TransToFluidPlace(arg_def.backend);
+
+  if (variable.IsType<framework::LoDTensor>()) {
+    const auto& tensor = variable.Get<framework::LoDTensor>();
+    if (!platform::is_same_place(tensor.place(), expected_place)) {
+      framework::LoDTensor tmp_tensor;
+      framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
+      ReMakePtenDenseTensor(tmp_tensor, dst);
+    } else {
+      ReMakePtenDenseTensor(tensor, dst);
+    }
+  } else if (variable.IsType<framework::SelectedRows>()) {
+    // TODO(chenweihang): now we don't deal with row and height
+    // by xiaowei's advice
+    const auto& tensor = variable.Get<framework::SelectedRows>();
+    if (!platform::is_same_place(tensor.value().place(), expected_place)) {
+      framework::Tensor tmp_tensor;
+      TensorCopySync(tensor.value(), expected_place, &tmp_tensor);
+      // TODO(chenweihang): adapt SelectedRows by xiaowei's design
+      ReMakePtenDenseTensor(tmp_tensor, dst);
+    } else {
+      ReMakePtenDenseTensor(tensor.value(), dst);
+    }
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Unsupported shared input `%s` type now when call pt kernel.",
+        framework::ToTypeName(variable.Type())));
+  }
+}
+
+void ReMakePtenDenseTensorFromVar(framework::Variable* variable,
+                                  const pten::TensorArgDef& arg_def,
+                                  pten::DenseTensor* dst) {
+  // mutable_data before run kernel, to avoid share output form
+  // KernelContext to original tensor
+  if (variable->template IsType<framework::LoDTensor>()) {
+    auto* tensor = variable->template GetMutable<framework::LoDTensor>();
+    // TODO(chenweihang): use original var type if arg_def.dtype is UNDEFINED
+    tensor->mutable_data(pten::TransToFluidPlace(arg_def.backend),
+                         pten::TransToProtoVarType(arg_def.dtype));
+    ReMakePtenDenseTensor(*tensor, dst);
+  } else if (variable->template IsType<framework::SelectedRows>()) {
+    auto* tensor = variable->template GetMutable<framework::SelectedRows>();
+    tensor->mutable_value()->mutable_data(
+        pten::TransToFluidPlace(arg_def.backend),
+        pten::TransToProtoVarType(arg_def.dtype));
+    // TODO(chenweihang): adapt SelectedRows by xiaowei's design,
+    // here the row and height will lost in output!
+    ReMakePtenDenseTensor(tensor->value(), dst);
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Unsupported shared output `%s` type now when call pt kernel.",
+        framework::ToTypeName(variable->Type())));
+  }
+}
+
 } // namespace experimental
 } // namespace paddle
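A caution on the const_cast used in the ReMake functions above: in standard C++, writing through const_cast to an object that was itself declared const (as a const data member is) is undefined behavior; the cast is only well-defined when the pointee was never const to begin with. This is part of why these helpers are explicitly labelled temporary compatibility tricks. The well-defined form of the idiom, for contrast (illustrative sketch):

// Casting away constness is fine when the underlying object is not const:
void Bump(const int* view) {
  *const_cast<int*>(view) += 1;  // OK only if `view` points at a non-const int
}

int main() {
  int x = 41;
  Bump(&x);  // well-defined: x is a non-const object
  // const int y = 41; Bump(&y);  // compiles, but is undefined behavior
  return x == 42 ? 0 : 1;
}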
@@ -44,5 +44,29 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst);

 void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst);

+/**
+ * In order to improve the compatibility state performance, some tricky tool
+ * functions are added.
+ *
+ * The ReMake** function takes out the LoDTensor information and directly
+ * replaces it with the corresponding member of the DenseTensor to avoid
+ * the overhead caused by frequent construction and destruction of the
+ * DenseTensor.
+ */
+void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
+                           pten::DenseTensor* dst);
+
+void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src,
+                           pten::DenseTensor* dst);
+
+void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
+                                  const pten::TensorArgDef& arg_def,
+                                  pten::DenseTensor* dst);
+
+void ReMakePtenDenseTensorFromVar(framework::Variable* variable,
+                                  const pten::TensorArgDef& arg_def,
+                                  pten::DenseTensor* dst);
+
 } // namespace experimental
 } // namespace paddle
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/pten/api/lib/utils/storage.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/storage.h"
+#include "paddle/pten/core/tensor_meta.h"
+
+namespace pten {
+
+/**
+ * In order to meet some adaptation requirements of the compatible state,
+ * these class is added to provide some tool functions.
+ *
+ * These utility functions may be deleted in the future, It is not recommended
+ * to be widely used in the framework
+ */
+class CompatibleDenseTensorUtils {
+ public:
+  static Storage* UnsafeGetMutableStorage(DenseTensor* tensor) {
+    return tensor->storage_.get();
+  }
+
+  static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) {
+    return &(tensor->meta_);
+  }
+
+  // only can deal with SharedStorage now
+  static void ClearStorage(DenseTensor* tensor) {
+    // use static_cast to improve performance, replace by dynamic_cast later
+    static_cast<paddle::experimental::SharedStorage*>(tensor->storage_.get())
+        ->Reset();
+  }
+};
+
+} // namespace pten
@@ -21,6 +21,8 @@ limitations under the License. */

 namespace pten {

+class CompatibleDenseTensorUtils;
+
 /// \brief The Dense tensor store values in a contiguous sequential block
 /// of memory where all values are represented. Tensors or multi-dimensional
 /// arrays are used in math operators.
@@ -164,6 +166,9 @@ class DenseTensor : public TensorBase,
   /// \return The const data pointer value of raw type.
   const void* data() const;

+ private:
+  friend class CompatibleDenseTensorUtils;
+
  private:
   DenseTensorMeta meta_;
   intrusive_ptr<Storage> storage_;
......
@@ -14,8 +14,10 @@
 #pragma once

+#include <iterator>
 #include <utility>

+#include "paddle/pten/core/compat_utils.h"
 #include "paddle/pten/core/tensor_base.h"
 #include "paddle/utils/any.h"
 #include "paddle/utils/small_vector.h"
@@ -39,16 +41,14 @@ using DataLayout = paddle::experimental::DataLayout;
  */
 class KernelContext {
  public:
-  explicit KernelContext(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {}
-  KernelContext(const DeviceContext& dev_ctx,
-                const paddle::SmallVector<std::shared_ptr<TensorBase>>& inputs,
-                const paddle::SmallVector<std::shared_ptr<TensorBase>>& outputs,
-                const paddle::SmallVector<paddle::any>& attrs)
-      : dev_ctx_(dev_ctx), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
+  KernelContext() = default;
+  explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {}
+
+  void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }

   template <typename CtxType>
   const CtxType& GetDeviceContext() const {
-    return static_cast<const CtxType&>(dev_ctx_);
+    return static_cast<const CtxType&>(*dev_ctx_);
   }

   void EmplaceBackInput(std::shared_ptr<TensorBase> input) {
@@ -59,14 +59,14 @@ class KernelContext {
   }

   void EmplaceBackInputs(
-      const paddle::SmallVector<std::shared_ptr<TensorBase>>& inputs) {
+      paddle::SmallVector<std::shared_ptr<TensorBase>> inputs) {
     int index = inputs_.size();
-    for (auto in : inputs) {
-      inputs_.emplace_back(std::move(in));
-    }
     // Record the start and end index of the input
     input_range_.emplace_back(
         std::pair<int, int>(index, index + inputs.size()));
+    inputs_.insert(inputs_.end(),
+                   std::make_move_iterator(inputs.begin()),
+                   std::make_move_iterator(inputs.end()));
   }

   void EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
@@ -77,14 +77,14 @@ class KernelContext {
   }

   void EmplaceBackOutputs(
-      const paddle::SmallVector<std::shared_ptr<TensorBase>>& outputs) {
+      paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
     int index = outputs_.size();
-    for (auto out : outputs) {
-      outputs_.emplace_back(std::move(out));
-    }
     // Record the start and end index of the input
     output_range_.emplace_back(
         std::pair<int, int>(index, index + outputs.size()));
+    outputs_.insert(outputs_.end(),
+                    std::make_move_iterator(outputs.begin()),
+                    std::make_move_iterator(outputs.end()));
   }

   void EmplaceBackAttr(paddle::any attr) {
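The EmplaceBackInputs/EmplaceBackOutputs rewrite above replaces a copy-taking loop with a single ranged insert; std::make_move_iterator makes the insert move each shared_ptr instead of copying it, avoiding atomic reference-count churn per element. The same idiom on std::vector (standing in for paddle::SmallVector):

#include <iterator>
#include <memory>
#include <vector>

using Ptr = std::shared_ptr<int>;

// Append a batch, moving every element rather than copying it.
void AppendAll(std::vector<Ptr>* dst, std::vector<Ptr> src) {
  dst->insert(dst->end(),
              std::make_move_iterator(src.begin()),
              std::make_move_iterator(src.end()));
}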
@@ -115,6 +115,19 @@ class KernelContext {
     return output_range_.at(idx);
   }

+  std::pair<int, int>& MutableInputRangeAt(size_t idx) {
+    return input_range_[idx];
+  }
+
+  std::pair<int, int>& MutableOutputRangeAt(size_t idx) {
+    return output_range_[idx];
+  }
+
+  template <typename TensorType>
+  TensorType* MutableInputAt(size_t idx) {
+    return static_cast<TensorType*>(inputs_.at(idx).get());
+  }
+
   template <typename TensorType>
   TensorType* MutableOutputAt(size_t idx) {
     return static_cast<TensorType*>(outputs_.at(idx).get());
@@ -140,12 +153,30 @@ class KernelContext {
     }
   }

+  // Temporary method: For compatible with fluid Tensor and improve performance
+  // Only deal with DenseTensor now
+  void ClearData() {
+    for (auto& in : inputs_) {
+      CompatibleDenseTensorUtils::ClearStorage(
+          static_cast<DenseTensor*>(in.get()));
+    }
+    for (auto& out : outputs_) {
+      CompatibleDenseTensorUtils::ClearStorage(
+          static_cast<DenseTensor*>(out.get()));
+    }
+    attrs_.clear();
+  }
+
+  size_t InputsSize() const { return inputs_.size(); }
+  size_t OutputsSize() const { return outputs_.size(); }
+  size_t AttrsSize() const { return attrs_.size(); }
+
  private:
   bool IsDuplicable() const { return input_range_.size() != inputs_.size(); }

  private:
   // DeviceContext base class
-  const DeviceContext& dev_ctx_;
+  DeviceContext* dev_ctx_;
   // TODO(chenweihang): Tensor -> Tensor*, Tensor should by managed `scope`
   // Note: can't use API Tensor here, the inference don't use this API Tensor
@@ -156,11 +187,6 @@ class KernelContext {
   // Only contains input like list[Tensor] need `range`
   paddle::SmallVector<std::pair<int, int>> input_range_;
   paddle::SmallVector<std::pair<int, int>> output_range_;
-
-  // Only static graph need `name`
-  // TODO(chenweihang): replaced by paddle::string_view
-  paddle::SmallVector<std::string> input_names_;
-  paddle::SmallVector<std::string> output_names_;
 };

 } // namespace pten
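Note also why dev_ctx_ changes from a reference member to a pointer: a cached, reusable context must be default-constructible and rebindable to a different device context on each run, and a reference member rules both out. In miniature (illustrative types only):

struct DeviceContext {};

// A reference member: cannot be default-constructed, binding is permanent.
// Fine for a per-run object, unusable for a cache.
class PerRunCtx {
 public:
  explicit PerRunCtx(const DeviceContext& dc) : dev_ctx_(dc) {}
  const DeviceContext& ctx() const { return dev_ctx_; }

 private:
  const DeviceContext& dev_ctx_;
};

// A pointer member: the cached object can exist before it is bound and can
// be re-pointed at another device context on every run.
class CachedCtx {
 public:
  CachedCtx() = default;
  void SetDeviceContext(DeviceContext* dc) { dev_ctx_ = dc; }

 private:
  DeviceContext* dev_ctx_{nullptr};
};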