未验证 提交 6721376b 编写于 作者: C Chen Weihang 提交者: GitHub

Optimize dygraph InferShape perf (#42155)

* init commit

* remove two hash impl

* fix bug

* polish details

* fix compile failed

* fix compile failed

* fix compile failed

* add default kernel sig cache

* fix get kernel arg defs error

* remove kernel arg defs cache

* fix origin op execute
上级 192a5af5
......@@ -402,11 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
InferShapeArgumentMappingContext arg_map_context(*ctx);
KernelSignature signature =
arg_map_fn ? (*arg_map_fn)(arg_map_context)
: phi::DefaultKernelSignatureMap::Instance().Get(op_type);
phi::KernelSignature signature = arg_map_fn
? (*arg_map_fn)(arg_map_context)
: *ctx->GetPhiDefaultKernelSignature();
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context
......
......@@ -393,6 +393,16 @@ void InterpretercoreInferShapeContext::SetOutputsDim(
SetDims(vars, dims);
}
const phi::ArgumentMappingFn*
InterpretercoreInferShapeContext::GetPhiArgumentMappingFn() const {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature*
InterpretercoreInferShapeContext::GetPhiDefaultKernelSignature() const {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) {
can_skip_lod_ = skip;
}
......
......@@ -111,6 +111,10 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override;
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
void SetSkipLoD(bool skip);
protected:
......
......@@ -271,6 +271,14 @@ class CompileTimeInferShapeContext : public InferShapeContext {
SetDims(names, dims);
}
const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
protected:
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const {
......
......@@ -1005,6 +1005,14 @@ class RuntimeInferShapeContext : public InferShapeContext {
SetDims(vars, dims);
}
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
}
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
}
protected:
DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
......@@ -1277,16 +1285,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi::KernelKey pt_kernel_key;
std::string pt_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *pt_kernel_signature_.get();
if (kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *kernel_signature_.get();
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_);
pt_kernel_name = pt_kernel_signature_->name;
pt_kernel_name = kernel_signature_->name;
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
......@@ -1301,7 +1309,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
<< "` not found.";
}
} else {
pt_kernel_name = pt_kernel_signature_->name;
pt_kernel_name = kernel_signature_->name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
......@@ -1447,8 +1455,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi::KernelContext pt_kernel_context;
// Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack
PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
runtime_ctx);
PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
(*pt_kernel_)(&pt_kernel_context);
} else {
......@@ -1543,14 +1550,14 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
VLOG(6) << *pt_kernel_signature_.get();
kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
VLOG(6) << *kernel_signature_.get();
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto pt_kernel_name = pt_kernel_signature_->name;
auto pt_kernel_name = kernel_signature_->name;
auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));
......@@ -2151,7 +2158,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout());
}
KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const {
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
if (arg_map_fn_ == nullptr) {
......@@ -2159,8 +2166,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
if (arg_map_fn) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
} else {
auto func =
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature {
auto func = [this](
const phi::ArgumentMappingContext& ctx) -> phi::KernelSignature {
return phi::DefaultKernelSignatureMap::Instance().Get(type_);
};
arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
......@@ -2171,7 +2178,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
Scope* OperatorWithKernel::PreparePhiData(
const Scope& scope, const phi::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
const phi::KernelSignature& pt_kernel_signature,
RuntimeContext* ctx) const {
const auto& input_names = pt_kernel_signature.input_names;
auto input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
......@@ -2269,9 +2277,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx);
auto& input_names = pt_kernel_signature_->input_names;
auto& attr_names = pt_kernel_signature_->attr_names;
auto& output_names = pt_kernel_signature_->output_names;
auto& input_names = kernel_signature_->input_names;
auto& attr_names = kernel_signature_->attr_names;
auto& output_names = kernel_signature_->output_names;
auto input_defs = pt_kernel_->args_def().input_defs();
auto attr_defs = pt_kernel_->args_def().attribute_defs();
......
......@@ -632,7 +632,7 @@ class OperatorWithKernel : public OperatorBase {
phi::KernelContext* pt_kernel_context) const;
phi::KernelSignature* PhiKernelSignature() const {
return pt_kernel_signature_.get();
return kernel_signature_.get();
}
phi::Kernel* PhiKernel() const { return pt_kernel_.get(); }
......@@ -704,7 +704,7 @@ class OperatorWithKernel : public OperatorBase {
// we may polish the implementation here
mutable bool run_phi_kernel_ = false;
mutable bool run_kp_kernel = false;
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
mutable std::unique_ptr<phi::Kernel> pt_kernel_;
mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
};
......
......@@ -45,7 +45,7 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
phi::KernelSignature GetKernelSignature();
private:
DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);
......@@ -221,10 +221,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
return attr_names_;
}
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
phi::KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return phi::KernelSignature(
phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(),
GetAttrsArgsNames(), GetOutputArgsNames());
}
std::once_flag kernel_sig_map_init_flag;
......
......@@ -40,8 +40,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
using KernelSignature = phi::KernelSignature;
/* Kernel Key translate */
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
......
......@@ -113,6 +113,10 @@ class InferShapeContext {
virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string &name) const = 0;
virtual const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const = 0;
virtual const phi::KernelSignature *GetPhiDefaultKernelSignature() const = 0;
protected:
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name,
......
......@@ -37,13 +37,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* in, const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr, const std::string op_type,
const framework::OpKernelType* op_kernel_type = nullptr)
const framework::OpKernelType* op_kernel_type = nullptr,
const phi::ArgumentMappingFn* arg_map_fn = nullptr,
const phi::KernelSignature* default_kernel_signature = nullptr)
: var_map_in_(in),
var_map_out_(out),
attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type),
op_kernel_type_(op_kernel_type) {}
op_kernel_type_(op_kernel_type),
arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature) {}
bool HasInput(const std::string& name) const override {
// has only one input
......@@ -377,6 +381,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
"SetLoDLevel function not support in dygraph mode"));
}
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
return arg_map_fn_;
}
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
return default_kernel_signature_;
}
protected:
DDim GetDim(framework::Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
......@@ -438,6 +450,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* default_attrs_;
const std::string op_type_;
const framework::OpKernelType* op_kernel_type_;
// arg_map_fn_ and default_kernel_signature_ may be nullptr
const phi::ArgumentMappingFn* arg_map_fn_;
const phi::KernelSignature* default_kernel_signature_;
};
} // namespace imperative
......
......@@ -107,19 +107,25 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx)
: op_(op),
ctx_(ctx),
kernel_type_(kernel_type),
func_(func),
dev_ctx_(dev_ctx),
pt_kernel_(empty_kernel) {}
arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature),
phi_kernel_(empty_kernel) {}
PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
phi::KernelSignature&& kernel_signature,
const phi::Kernel& phi_kernel,
platform::DeviceContext* dev_ctx)
: op_(op),
ctx_(ctx),
......@@ -127,8 +133,10 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
func_(nullptr),
dev_ctx_(dev_ctx),
run_phi_kernel_(true),
pt_kernel_signature_(std::move(kernel_signature)),
pt_kernel_(pt_kernel) {}
arg_map_fn_(arg_map_fn),
default_kernel_signature_(default_kernel_signature),
kernel_signature_(std::move(kernel_signature)),
phi_kernel_(phi_kernel) {}
template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
......@@ -161,7 +169,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
framework::KernelSignature pt_kernel_signature;
const phi::KernelSignature* default_kernel_signature = nullptr;
phi::KernelSignature kernel_signature;
phi::KernelKey pt_kernel_key;
std::string pt_kernel_name;
#if defined(PADDLE_WITH_XPU)
......@@ -179,20 +188,20 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) {
has_phi_kernel = true;
pt_kernel_signature = (*arg_map_fn)(
kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
const auto* kernel_sig =
default_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
if (kernel_sig) {
if (default_kernel_signature) {
has_phi_kernel = true;
pt_kernel_signature = *kernel_sig;
kernel_signature = *default_kernel_signature;
}
}
if (has_phi_kernel) {
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
VLOG(6) << kernel_signature;
pt_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
......@@ -230,24 +239,25 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel(
auto& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
if (pt_kernel.IsValid()
if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&& !is_xpu_unsupport
#endif
) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << pt_kernel;
<< " | kernel: " << phi_kernel;
if (expected_kernel_key.place_ != place) {
dev_ctx = pool.Get(expected_kernel_key.place_);
}
return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx);
return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn,
default_kernel_signature, std::move(kernel_signature),
phi_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
......@@ -295,9 +305,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx);
return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn,
default_kernel_signature, std::move(kernel_signature),
pt_cpu_kernel, cpu_ctx);
}
}
}
......@@ -389,7 +399,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
}
return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
dev_ctx);
arg_map_fn, default_kernel_signature, dev_ctx);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
......@@ -425,6 +435,8 @@ static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
......@@ -436,7 +448,8 @@ static void PreparedOpRunImpl(
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type,
arg_map_fn, default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx);
}
......@@ -483,17 +496,19 @@ template <typename VarType>
static void PreparedOpRunPtImpl(
const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
const phi::KernelSignature& kernel_signature, const phi::Kernel& phi_kernel,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
{
platform::RecordEvent record_event(op.Type() + "::infer_shape",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type,
arg_map_fn, default_kernel_signature);
op.Info().infer_shape_(&infer_shape_ctx);
}
......@@ -502,14 +517,14 @@ static void PreparedOpRunPtImpl(
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
PreparePhiData<VarType>(pt_kernel, pt_kernel_signature, ins);
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext pt_kernel_context;
BuildDygraphPhiKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
BuildDygraphPhiKernelContext<VarType>(kernel_signature, phi_kernel, ins,
outs, attrs, default_attrs, dev_ctx,
&pt_kernel_context);
pt_kernel(&pt_kernel_context);
phi_kernel(&pt_kernel_context);
}
if (FLAGS_check_nan_inf) {
......@@ -535,12 +550,14 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, dev_ctx_, ins, outs, attrs,
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, arg_map_fn_,
default_kernel_signature_, kernel_signature_,
phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
} else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs, default_attrs);
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, arg_map_fn_,
default_kernel_signature_, dev_ctx_, ins, outs,
attrs, default_attrs);
}
}
......@@ -550,11 +567,13 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs);
op_, kernel_type_, arg_map_fn_, default_kernel_signature_,
kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
} else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs, default_attrs);
PreparedOpRunImpl<VariableWrapper>(
op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_,
dev_ctx_, ins, outs, attrs, default_attrs);
}
}
......@@ -564,12 +583,13 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
const framework::AttributeMap& default_attrs) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<egr::EagerVariable>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs);
op_, kernel_type_, arg_map_fn_, default_kernel_signature_,
kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
} else {
PreparedOpRunImpl<egr::EagerVariable>(op_, ctx_, kernel_type_, func_,
dev_ctx_, ins, outs, attrs,
default_attrs);
PreparedOpRunImpl<egr::EagerVariable>(
op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_,
dev_ctx_, ins, outs, attrs, default_attrs);
}
}
......
......@@ -150,13 +150,17 @@ class PreparedOp {
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::OperatorWithKernel::OpKernelFunc& func,
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
platform::DeviceContext* dev_ctx);
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
const phi::ArgumentMappingFn* arg_map_fn,
const phi::KernelSignature* default_kernel_signature,
phi::KernelSignature&& kernel_signature,
const phi::Kernel& phi_kernel, platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
......@@ -206,8 +210,10 @@ class PreparedOp {
// we may polish the implementation here
bool run_phi_kernel_{false};
bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_;
const phi::Kernel& pt_kernel_;
const phi::ArgumentMappingFn* arg_map_fn_;
const phi::KernelSignature* default_kernel_signature_;
phi::KernelSignature kernel_signature_;
const phi::Kernel& phi_kernel_;
};
const inline framework::Attribute& GetAttr(
......@@ -226,21 +232,23 @@ const inline framework::Attribute& GetAttr(
}
template <typename VarType>
void BuildDygraphPhiKernelContext(
const framework::KernelSignature& pt_kernel_signature,
const phi::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) {
void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
const phi::Kernel& phi_kernel,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
platform::DeviceContext* dev_ctx,
phi::KernelContext* kernel_ctx) {
kernel_ctx->SetDeviceContext(dev_ctx);
const auto& input_names = pt_kernel_signature.input_names;
const auto& attr_names = pt_kernel_signature.attr_names;
const auto& output_names = pt_kernel_signature.output_names;
const auto& input_names = kernel_signature.input_names;
const auto& attr_names = kernel_signature.attr_names;
const auto& output_names = kernel_signature.output_names;
auto& input_defs = pt_kernel.args_def().input_defs();
auto& output_defs = pt_kernel.args_def().output_defs();
auto& attr_defs = pt_kernel.args_def().attribute_defs();
auto& input_defs = phi_kernel.args_def().input_defs();
auto& output_defs = phi_kernel.args_def().output_defs();
auto& attr_defs = phi_kernel.args_def().attribute_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
......@@ -286,7 +294,7 @@ void BuildDygraphPhiKernelContext(
"Can not find input variable '%s' for %s OP, please check whether "
"the name setting in OpArgumentMapping is consistent with that in "
"OpMaker.",
input_names[i], pt_kernel_signature.name));
input_names[i], kernel_signature.name));
}
}
......@@ -568,11 +576,11 @@ void BuildDygraphPhiKernelContext(
}
template <typename VarType>
void PreparePhiData(const phi::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature,
void PreparePhiData(const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_signature,
const NameVarMap<VarType>& ins) {
const auto& input_names = pt_kernel_signature.input_names;
auto& input_defs = pt_kernel.args_def().input_defs();
const auto& input_names = kernel_signature.input_names;
auto& input_defs = phi_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册