未验证 提交 9495708a 编写于 作者: C Chen Weihang 提交者: GitHub

[Cherry-pick2.3] Optimize dygraph performance part3 (#42256)

* Change small vector size (#42202)

* change small vector size

* Update type_defs.h

* Optimize dygraph InferShape perf (#42155)

* init commit

* remove two hash impl

* fix bug

* polish details

* fix compile failed

* fix compile failed

* fix compile failed

* add default kernel sig cache

* fix get kernel arg defs error

* remove kernel arg defs cache

* fix origin op execute
上级 f16087e0
...@@ -402,11 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween( ...@@ -402,11 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) { const std::string& op_type) {
// 1. get kernel args // 1. get kernel args
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
InferShapeArgumentMappingContext arg_map_context(*ctx); InferShapeArgumentMappingContext arg_map_context(*ctx);
KernelSignature signature = phi::KernelSignature signature = arg_map_fn
arg_map_fn ? (*arg_map_fn)(arg_map_context) ? (*arg_map_fn)(arg_map_context)
: phi::DefaultKernelSignatureMap::Instance().Get(op_type); : *ctx->GetPhiDefaultKernelSignature();
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context // 2. build infermeta context
......
...@@ -393,6 +393,16 @@ void InterpretercoreInferShapeContext::SetOutputsDim( ...@@ -393,6 +393,16 @@ void InterpretercoreInferShapeContext::SetOutputsDim(
SetDims(vars, dims); SetDims(vars, dims);
} }
// Fetches the phi argument-mapping function that phi::OpUtilsMap associates
// with this operator's type (op_.Type()).
// The returned pointer may be null; BuildInferMetaContext tests it before
// invoking the function.
const phi::ArgumentMappingFn*
InterpretercoreInferShapeContext::GetPhiArgumentMappingFn() const {
  auto& op_utils = phi::OpUtilsMap::Instance();
  return op_utils.GetArgumentMappingFn(op_.Type());
}
// Returns the address of the default kernel signature that
// phi::DefaultKernelSignatureMap holds for this operator's type (op_.Type()).
// NOTE(review): what Get() does when the type has no entry is not visible
// here -- confirm before expecting a null result from this override.
const phi::KernelSignature*
InterpretercoreInferShapeContext::GetPhiDefaultKernelSignature() const {
  auto& default_signatures = phi::DefaultKernelSignatureMap::Instance();
  const phi::KernelSignature& signature = default_signatures.Get(op_.Type());
  return &signature;
}
void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) { void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) {
can_skip_lod_ = skip; can_skip_lod_ = skip;
} }
......
...@@ -111,6 +111,10 @@ class InterpretercoreInferShapeContext : public InferShapeContext { ...@@ -111,6 +111,10 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
void SetOutputsDim(const std::string& name, void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override; const std::vector<DDim>& dims) override;
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
void SetSkipLoD(bool skip); void SetSkipLoD(bool skip);
protected: protected:
......
...@@ -271,6 +271,14 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -271,6 +271,14 @@ class CompileTimeInferShapeContext : public InferShapeContext {
SetDims(names, dims); SetDims(names, dims);
} }
// Fetches the phi argument-mapping function that phi::OpUtilsMap associates
// with this operator's type (op_.Type()).
// The returned pointer may be null; BuildInferMetaContext tests it before
// invoking the function.
const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
  auto &op_utils = phi::OpUtilsMap::Instance();
  return op_utils.GetArgumentMappingFn(op_.Type());
}
// Returns the address of the default kernel signature that
// phi::DefaultKernelSignatureMap holds for this operator's type (op_.Type()).
// NOTE(review): what Get() does when the type has no entry is not visible
// here -- confirm before expecting a null result from this override.
const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
  auto &default_signatures = phi::DefaultKernelSignatureMap::Instance();
  const phi::KernelSignature &signature = default_signatures.Get(op_.Type());
  return &signature;
}
protected: protected:
std::vector<proto::VarType::Type> GetVarTypes( std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
......
...@@ -1007,6 +1007,14 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -1007,6 +1007,14 @@ class RuntimeInferShapeContext : public InferShapeContext {
SetDims(vars, dims); SetDims(vars, dims);
} }
// Fetches the phi argument-mapping function that phi::OpUtilsMap associates
// with this operator's type (op_.Type()).
// The returned pointer may be null; BuildInferMetaContext tests it before
// invoking the function.
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
  auto& op_utils = phi::OpUtilsMap::Instance();
  return op_utils.GetArgumentMappingFn(op_.Type());
}
// Returns the address of the default kernel signature that
// phi::DefaultKernelSignatureMap holds for this operator's type (op_.Type()).
// NOTE(review): what Get() does when the type has no entry is not visible
// here -- confirm before expecting a null result from this override.
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
  auto& default_signatures = phi::DefaultKernelSignatureMap::Instance();
  const phi::KernelSignature& signature = default_signatures.Get(op_.Type());
  return &signature;
}
protected: protected:
DDim GetDim(Variable* var) const { DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -1279,16 +1287,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1279,16 +1287,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi::KernelKey pt_kernel_key; phi::KernelKey pt_kernel_key;
std::string pt_kernel_name; std::string pt_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { if (kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
pt_kernel_signature_.reset( kernel_signature_.reset(new phi::KernelSignature(
new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx)))); std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *pt_kernel_signature_.get(); VLOG(6) << *kernel_signature_.get();
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_); dev_ctx = pool.Get(kernel_type_->place_);
pt_kernel_name = pt_kernel_signature_->name; pt_kernel_name = kernel_signature_->name;
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset( pt_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
...@@ -1303,7 +1311,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1303,7 +1311,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
<< "` not found."; << "` not found.";
} }
} else { } else {
pt_kernel_name = pt_kernel_signature_->name; pt_kernel_name = kernel_signature_->name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work. // library_type here, otherwise it can't work.
...@@ -1449,8 +1457,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1449,8 +1457,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
// Do data transform before building KernelContext // Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack // TODO(zhiqiu): support TransferInplaceVarsBack
PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_, PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
runtime_ctx);
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
(*pt_kernel_)(&pt_kernel_context); (*pt_kernel_)(&pt_kernel_context);
} else { } else {
...@@ -1545,14 +1552,14 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -1545,14 +1552,14 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
phi::KernelKey OperatorWithKernel::ChoosePhiKernel( phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
pt_kernel_signature_.reset( kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
VLOG(6) << *pt_kernel_signature_.get(); VLOG(6) << *kernel_signature_.get();
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto pt_kernel_name = pt_kernel_signature_->name; auto pt_kernel_name = kernel_signature_->name;
auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key))); pt_kernel_name, pt_kernel_key)));
...@@ -2153,7 +2160,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( ...@@ -2153,7 +2160,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout()); tensor.layout());
} }
KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
ExecutionArgumentMappingContext arg_mapping_ctx(ctx); ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
if (arg_map_fn_ == nullptr) { if (arg_map_fn_ == nullptr) {
...@@ -2161,8 +2168,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( ...@@ -2161,8 +2168,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
if (arg_map_fn) { if (arg_map_fn) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn)); arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
} else { } else {
auto func = auto func = [this](
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature { const phi::ArgumentMappingContext& ctx) -> phi::KernelSignature {
return phi::DefaultKernelSignatureMap::Instance().Get(type_); return phi::DefaultKernelSignatureMap::Instance().Get(type_);
}; };
arg_map_fn_.reset(new phi::ArgumentMappingFn(func)); arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
...@@ -2173,7 +2180,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( ...@@ -2173,7 +2180,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
Scope* OperatorWithKernel::PreparePhiData( Scope* OperatorWithKernel::PreparePhiData(
const Scope& scope, const phi::Kernel& pt_kernel, const Scope& scope, const phi::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const { const phi::KernelSignature& pt_kernel_signature,
RuntimeContext* ctx) const {
const auto& input_names = pt_kernel_signature.input_names; const auto& input_names = pt_kernel_signature.input_names;
auto input_defs = pt_kernel.args_def().input_defs(); auto input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
...@@ -2271,9 +2279,9 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2271,9 +2279,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi::KernelContext* pt_kernel_context) const { phi::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx); pt_kernel_context->SetDeviceContext(dev_ctx);
auto& input_names = pt_kernel_signature_->input_names; auto& input_names = kernel_signature_->input_names;
auto& attr_names = pt_kernel_signature_->attr_names; auto& attr_names = kernel_signature_->attr_names;
auto& output_names = pt_kernel_signature_->output_names; auto& output_names = kernel_signature_->output_names;
auto input_defs = pt_kernel_->args_def().input_defs(); auto input_defs = pt_kernel_->args_def().input_defs();
auto attr_defs = pt_kernel_->args_def().attribute_defs(); auto attr_defs = pt_kernel_->args_def().attribute_defs();
......
...@@ -632,7 +632,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -632,7 +632,7 @@ class OperatorWithKernel : public OperatorBase {
phi::KernelContext* pt_kernel_context) const; phi::KernelContext* pt_kernel_context) const;
phi::KernelSignature* PhiKernelSignature() const { phi::KernelSignature* PhiKernelSignature() const {
return pt_kernel_signature_.get(); return kernel_signature_.get();
} }
phi::Kernel* PhiKernel() const { return pt_kernel_.get(); } phi::Kernel* PhiKernel() const { return pt_kernel_.get(); }
...@@ -704,7 +704,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -704,7 +704,7 @@ class OperatorWithKernel : public OperatorBase {
// we may polish the implementation here // we may polish the implementation here
mutable bool run_phi_kernel_ = false; mutable bool run_phi_kernel_ = false;
mutable bool run_kp_kernel = false; mutable bool run_kp_kernel = false;
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_; mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
mutable std::unique_ptr<phi::Kernel> pt_kernel_; mutable std::unique_ptr<phi::Kernel> pt_kernel_;
mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_; mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
}; };
......
...@@ -45,7 +45,7 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { ...@@ -45,7 +45,7 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
const paddle::SmallVector<const char*>& GetOutputArgsNames() override; const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
const paddle::SmallVector<const char*>& GetAttrsArgsNames() override; const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature(); phi::KernelSignature GetKernelSignature();
private: private:
DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto); DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);
...@@ -221,10 +221,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -221,10 +221,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
return attr_names_; return attr_names_;
} }
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { phi::KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(), return phi::KernelSignature(
GetInputArgsNames(), GetAttrsArgsNames(), phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(),
GetOutputArgsNames()); GetAttrsArgsNames(), GetOutputArgsNames());
} }
std::once_flag kernel_sig_map_init_flag; std::once_flag kernel_sig_map_init_flag;
......
...@@ -40,8 +40,6 @@ limitations under the License. */ ...@@ -40,8 +40,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using KernelSignature = phi::KernelSignature;
/* Kernel Key translate */ /* Kernel Key translate */
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key); OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
......
...@@ -113,6 +113,10 @@ class InferShapeContext { ...@@ -113,6 +113,10 @@ class InferShapeContext {
virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string &name) const = 0; GetOutputVarPtrs(const std::string &name) const = 0;
virtual const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const = 0;
virtual const phi::KernelSignature *GetPhiDefaultKernelSignature() const = 0;
protected: protected:
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0; virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name, virtual void SetRepeatedDims(const std::string &name,
......
...@@ -37,13 +37,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -37,13 +37,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* in, const NameVarMap<VarType>* out, const NameVarMap<VarType>* in, const NameVarMap<VarType>* out,
const framework::AttributeMap* attr, const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr, const std::string op_type, const framework::AttributeMap* default_attr, const std::string op_type,
const framework::OpKernelType* op_kernel_type = nullptr) const framework::OpKernelType* op_kernel_type = nullptr,
const phi::ArgumentMappingFn* arg_map_fn = nullptr,
const phi::KernelSignature* default_kernel_signature = nullptr)
: var_map_in_(in), : var_map_in_(in),
var_map_out_(out), var_map_out_(out),
attrs_(attr), attrs_(attr),
default_attrs_(default_attr), default_attrs_(default_attr),
op_type_(op_type), op_type_(op_type),
op_kernel_type_(op_kernel_type) {} op_kernel_type_(op_kernel_type),
arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
// has only one input // has only one input
...@@ -377,6 +381,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -377,6 +381,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
"SetLoDLevel function not support in dygraph mode")); "SetLoDLevel function not support in dygraph mode"));
} }
// Returns the argument-mapping function pointer that was handed to the
// constructor and stored in arg_map_fn_. No lookup is performed here.
// The constructor parameter defaults to nullptr, so the result may be null.
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
return arg_map_fn_;
}
// Returns the default kernel signature pointer that was handed to the
// constructor and stored in default_kernel_signature_. No lookup is
// performed here. The constructor parameter defaults to nullptr, so the
// result may be null and callers must not dereference it unchecked.
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
return default_kernel_signature_;
}
protected: protected:
DDim GetDim(framework::Variable* var) const { DDim GetDim(framework::Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
...@@ -438,6 +450,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -438,6 +450,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* default_attrs_; const framework::AttributeMap* default_attrs_;
const std::string op_type_; const std::string op_type_;
const framework::OpKernelType* op_kernel_type_; const framework::OpKernelType* op_kernel_type_;
// arg_map_fn_ and default_kernel_signature_ may be nullptr
const phi::ArgumentMappingFn* arg_map_fn_;
const phi::KernelSignature* default_kernel_signature_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -107,19 +107,25 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -107,19 +107,25 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx) platform::DeviceContext* dev_ctx)
: op_(op), : op_(op),
ctx_(ctx), ctx_(ctx),
kernel_type_(kernel_type), kernel_type_(kernel_type),
func_(func), func_(func),
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx),
pt_kernel_(empty_kernel) {} arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature),
phi_kernel_(empty_kernel) {}
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
framework::KernelSignature&& kernel_signature, const phi::ArgumentMappingFn* arg_map_fn,
const phi::Kernel& pt_kernel, const phi::KernelSignature* default_kernel_signature,
phi::KernelSignature&& kernel_signature,
const phi::Kernel& phi_kernel,
platform::DeviceContext* dev_ctx) platform::DeviceContext* dev_ctx)
: op_(op), : op_(op),
ctx_(ctx), ctx_(ctx),
...@@ -127,8 +133,10 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -127,8 +133,10 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
func_(nullptr), func_(nullptr),
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx),
run_phi_kernel_(true), run_phi_kernel_(true),
pt_kernel_signature_(std::move(kernel_signature)), arg_map_fn_(arg_map_fn),
pt_kernel_(pt_kernel) {} default_kernel_signature_(default_kernel_signature),
kernel_signature_(std::move(kernel_signature)),
phi_kernel_(phi_kernel) {}
template <typename VarType> template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
...@@ -161,7 +169,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -161,7 +169,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs); op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
framework::KernelSignature pt_kernel_signature; const phi::KernelSignature* default_kernel_signature = nullptr;
phi::KernelSignature kernel_signature;
phi::KernelKey pt_kernel_key; phi::KernelKey pt_kernel_key;
std::string pt_kernel_name; std::string pt_kernel_name;
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
...@@ -179,20 +188,20 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -179,20 +188,20 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type()); phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) { if (arg_map_fn) {
has_phi_kernel = true; has_phi_kernel = true;
pt_kernel_signature = (*arg_map_fn)( kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else { } else {
const auto* kernel_sig = default_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type()); phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
if (kernel_sig) { if (default_kernel_signature) {
has_phi_kernel = true; has_phi_kernel = true;
pt_kernel_signature = *kernel_sig; kernel_signature = *default_kernel_signature;
} }
} }
if (has_phi_kernel) { if (has_phi_kernel) {
VLOG(6) << pt_kernel_signature; VLOG(6) << kernel_signature;
pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work. // library_type here, otherwise it can't work.
...@@ -230,24 +239,25 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -230,24 +239,25 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel( auto& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key); pt_kernel_name, pt_kernel_key);
if (pt_kernel.IsValid() if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&& !is_xpu_unsupport && !is_xpu_unsupport
#endif #endif
) { ) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key << " | kernel key: " << pt_kernel_key
<< " | kernel: " << pt_kernel; << " | kernel: " << phi_kernel;
if (expected_kernel_key.place_ != place) { if (expected_kernel_key.place_ != place) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
return PreparedOp(op, empty_ctx, expected_kernel_key, return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn,
std::move(pt_kernel_signature), pt_kernel, dev_ctx); default_kernel_signature, std::move(kernel_signature),
phi_kernel, dev_ctx);
} else { } else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found."; << "` not found.";
...@@ -295,9 +305,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -295,9 +305,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel; << " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, empty_ctx, expected_kernel_key, return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn,
std::move(pt_kernel_signature), pt_cpu_kernel, default_kernel_signature, std::move(kernel_signature),
cpu_ctx); pt_cpu_kernel, cpu_ctx);
} }
} }
} }
...@@ -389,7 +399,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -389,7 +399,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
} }
return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second, return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
dev_ctx); arg_map_fn, default_kernel_signature, dev_ctx);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
...@@ -425,6 +435,8 @@ static void PreparedOpRunImpl( ...@@ -425,6 +435,8 @@ static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins, platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
...@@ -436,7 +448,8 @@ static void PreparedOpRunImpl( ...@@ -436,7 +448,8 @@ static void PreparedOpRunImpl(
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx( DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type); &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type,
arg_map_fn, default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
} }
...@@ -483,17 +496,19 @@ template <typename VarType> ...@@ -483,17 +496,19 @@ template <typename VarType>
static void PreparedOpRunPtImpl( static void PreparedOpRunPtImpl(
const framework::OperatorBase& op, const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature, const phi::ArgumentMappingFn* arg_map_fn,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx, const phi::KernelSignature* default_kernel_signature,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs, const phi::KernelSignature& kernel_signature, const phi::Kernel& phi_kernel,
const framework::AttributeMap& attrs, platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
{ {
platform::RecordEvent record_event(op.Type() + "::infer_shape", platform::RecordEvent record_event(op.Type() + "::infer_shape",
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx( DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type); &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type,
arg_map_fn, default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
} }
...@@ -502,14 +517,14 @@ static void PreparedOpRunPtImpl( ...@@ -502,14 +517,14 @@ static void PreparedOpRunPtImpl(
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
PreparePhiData<VarType>(pt_kernel, pt_kernel_signature, ins); PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext pt_kernel_context; phi::KernelContext pt_kernel_context;
BuildDygraphPhiKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins, BuildDygraphPhiKernelContext<VarType>(kernel_signature, phi_kernel, ins,
outs, attrs, default_attrs, dev_ctx, outs, attrs, default_attrs, dev_ctx,
&pt_kernel_context); &pt_kernel_context);
pt_kernel(&pt_kernel_context); phi_kernel(&pt_kernel_context);
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
...@@ -535,12 +550,14 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins, ...@@ -535,12 +550,14 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) { if (run_phi_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_, PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, arg_map_fn_,
pt_kernel_, dev_ctx_, ins, outs, attrs, default_kernel_signature_, kernel_signature_,
phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs); default_attrs);
} else { } else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, arg_map_fn_,
outs, attrs, default_attrs); default_kernel_signature_, dev_ctx_, ins, outs,
attrs, default_attrs);
} }
} }
...@@ -550,11 +567,13 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -550,11 +567,13 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) { if (run_phi_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>( PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, op_, kernel_type_, arg_map_fn_, default_kernel_signature_,
outs, attrs, default_attrs); kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
} else { } else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_, PreparedOpRunImpl<VariableWrapper>(
ins, outs, attrs, default_attrs); op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_,
dev_ctx_, ins, outs, attrs, default_attrs);
} }
} }
...@@ -564,12 +583,13 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins, ...@@ -564,12 +583,13 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) { if (run_phi_kernel_) {
PreparedOpRunPtImpl<egr::EagerVariable>( PreparedOpRunPtImpl<egr::EagerVariable>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, op_, kernel_type_, arg_map_fn_, default_kernel_signature_,
outs, attrs, default_attrs); kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs,
} else {
PreparedOpRunImpl<egr::EagerVariable>(op_, ctx_, kernel_type_, func_,
dev_ctx_, ins, outs, attrs,
default_attrs); default_attrs);
} else {
PreparedOpRunImpl<egr::EagerVariable>(
op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_,
dev_ctx_, ins, outs, attrs, default_attrs);
} }
} }
......
...@@ -150,13 +150,17 @@ class PreparedOp { ...@@ -150,13 +150,17 @@ class PreparedOp {
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx); platform::DeviceContext* dev_ctx);
PreparedOp(const framework::OperatorBase& op, PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
framework::KernelSignature&& kernel_signature, const phi::ArgumentMappingFn* arg_map_fn,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); const phi::KernelSignature* default_kernel_signature,
phi::KernelSignature&& kernel_signature,
const phi::Kernel& phi_kernel, platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins, static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
...@@ -206,8 +210,10 @@ class PreparedOp { ...@@ -206,8 +210,10 @@ class PreparedOp {
// we may polish the implementation here // we may polish the implementation here
bool run_phi_kernel_{false}; bool run_phi_kernel_{false};
bool run_kp_kernel_{false}; bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_; const phi::ArgumentMappingFn* arg_map_fn_;
const phi::Kernel& pt_kernel_; const phi::KernelSignature* default_kernel_signature_;
phi::KernelSignature kernel_signature_;
const phi::Kernel& phi_kernel_;
}; };
const inline framework::Attribute& GetAttr( const inline framework::Attribute& GetAttr(
...@@ -226,21 +232,23 @@ const inline framework::Attribute& GetAttr( ...@@ -226,21 +232,23 @@ const inline framework::Attribute& GetAttr(
} }
template <typename VarType> template <typename VarType>
void BuildDygraphPhiKernelContext( void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
const framework::KernelSignature& pt_kernel_signature, const phi::Kernel& phi_kernel,
const phi::Kernel& pt_kernel, const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) { platform::DeviceContext* dev_ctx,
phi::KernelContext* kernel_ctx) {
kernel_ctx->SetDeviceContext(dev_ctx); kernel_ctx->SetDeviceContext(dev_ctx);
const auto& input_names = pt_kernel_signature.input_names; const auto& input_names = kernel_signature.input_names;
const auto& attr_names = pt_kernel_signature.attr_names; const auto& attr_names = kernel_signature.attr_names;
const auto& output_names = pt_kernel_signature.output_names; const auto& output_names = kernel_signature.output_names;
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = phi_kernel.args_def().input_defs();
auto& output_defs = pt_kernel.args_def().output_defs(); auto& output_defs = phi_kernel.args_def().output_defs();
auto& attr_defs = pt_kernel.args_def().attribute_defs(); auto& attr_defs = phi_kernel.args_def().attribute_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -286,7 +294,7 @@ void BuildDygraphPhiKernelContext( ...@@ -286,7 +294,7 @@ void BuildDygraphPhiKernelContext(
"Can not find input variable '%s' for %s OP, please check whether " "Can not find input variable '%s' for %s OP, please check whether "
"the name setting in OpArgumentMapping is consistent with that in " "the name setting in OpArgumentMapping is consistent with that in "
"OpMaker.", "OpMaker.",
input_names[i], pt_kernel_signature.name)); input_names[i], kernel_signature.name));
} }
} }
...@@ -568,11 +576,11 @@ void BuildDygraphPhiKernelContext( ...@@ -568,11 +576,11 @@ void BuildDygraphPhiKernelContext(
} }
template <typename VarType> template <typename VarType>
void PreparePhiData(const phi::Kernel& pt_kernel, void PreparePhiData(const phi::Kernel& phi_kernel,
const framework::KernelSignature& pt_kernel_signature, const phi::KernelSignature& kernel_signature,
const NameVarMap<VarType>& ins) { const NameVarMap<VarType>& ins) {
const auto& input_names = pt_kernel_signature.input_names; const auto& input_names = kernel_signature.input_names;
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = phi_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -38,8 +38,8 @@ using ArgumentMappingFn = ...@@ -38,8 +38,8 @@ using ArgumentMappingFn =
using InferMetaFn = void (*)(InferMetaContext* ctx); using InferMetaFn = void (*)(InferMetaContext* ctx);
// Global SmallVector size setting // Global SmallVector size setting
constexpr size_t kInputSmallVectorSize = 10U; constexpr size_t kInputSmallVectorSize = 15U;
constexpr size_t kAttrSmallVectorSize = 10U; constexpr size_t kAttrSmallVectorSize = 15U;
constexpr size_t kOutputSmallVectorSize = 5U; constexpr size_t kOutputSmallVectorSize = 15U;
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册