diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 6deebe93dcc629f494bbf0d98584a8dac2f42e97..d7a2a42ca7dc751f8a6834ef4b3e53e2e0467523 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -402,11 +402,11 @@ std::vector CompatInferMetaContext::MutableOutputBetween( CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, const std::string& op_type) { // 1. get kernel args - auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); + auto* arg_map_fn = ctx->GetPhiArgumentMappingFn(); InferShapeArgumentMappingContext arg_map_context(*ctx); - KernelSignature signature = - arg_map_fn ? (*arg_map_fn)(arg_map_context) - : phi::DefaultKernelSignatureMap::Instance().Get(op_type); + phi::KernelSignature signature = arg_map_fn + ? (*arg_map_fn)(arg_map_context) + : *ctx->GetPhiDefaultKernelSignature(); VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; // 2. build infermeta context diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 3c2395d4320a17edb80fc2308f0bb3e554d470ed..0164c4530764906e02dc1197eff0b6162a763305 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -393,6 +393,16 @@ void InterpretercoreInferShapeContext::SetOutputsDim( SetDims(vars, dims); } +const phi::ArgumentMappingFn* +InterpretercoreInferShapeContext::GetPhiArgumentMappingFn() const { + return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type()); +} + +const phi::KernelSignature* +InterpretercoreInferShapeContext::GetPhiDefaultKernelSignature() const { + return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type()); +} + void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) { can_skip_lod_ = skip; } diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 28b9f6f0130f5b2fbd209d6a46ee95be544a5877..83eaf9514a1368d7f38e5adfbeab662b2f5b8aca 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -111,6 +111,10 @@ class InterpretercoreInferShapeContext : public InferShapeContext { void SetOutputsDim(const std::string& name, const std::vector& dims) override; + const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override; + + const phi::KernelSignature* GetPhiDefaultKernelSignature() const override; + void SetSkipLoD(bool skip); protected: diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index d27bf0e150f9785916265556b59d285999344a81..4ef1d3a83a2678bafdd1a722e36503938f44f761 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -271,6 +271,14 @@ class CompileTimeInferShapeContext : public InferShapeContext { SetDims(names, dims); } + const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override { + return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type()); + } + + const phi::KernelSignature *GetPhiDefaultKernelSignature() const override { + return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type()); + } + protected: std::vector GetVarTypes( const std::vector &names) const { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 945b8a89848b115a855294d6c6c5b75bfcd03a16..140103b10592fdfdee95a2ba8d03d12d7880aa5a 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1005,6 +1005,14 @@ class RuntimeInferShapeContext : public InferShapeContext { SetDims(vars, dims); } + const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override { + return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type()); + } + + const phi::KernelSignature* GetPhiDefaultKernelSignature() const override { + return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type()); + } + protected: DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL( @@ -1277,16 +1285,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope, phi::KernelKey pt_kernel_key; std::string pt_kernel_name; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { - if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { - pt_kernel_signature_.reset( - new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx)))); - VLOG(6) << *pt_kernel_signature_.get(); + if (kernel_signature_ == nullptr || pt_kernel_ == nullptr) { + kernel_signature_.reset(new phi::KernelSignature( + std::move(GetExpectedPhiKernelArgs(exe_ctx)))); + VLOG(6) << *kernel_signature_.get(); kernel_type_.reset( new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); dev_ctx = pool.Get(kernel_type_->place_); - pt_kernel_name = pt_kernel_signature_->name; + pt_kernel_name = kernel_signature_->name; pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); pt_kernel_.reset( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( @@ -1301,7 +1309,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, << "` not found."; } } else { - pt_kernel_name = pt_kernel_signature_->name; + pt_kernel_name = kernel_signature_->name; // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // But the default library_type is Plain, so we need to modify the // library_type here, otherwise it can't work. @@ -1447,8 +1455,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, phi::KernelContext pt_kernel_context; // Do data transform before building KernelContext // TODO(zhiqiu): support TransferInplaceVarsBack - PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_, - runtime_ctx); + PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx); BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); (*pt_kernel_)(&pt_kernel_context); } else { @@ -1543,14 +1550,14 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( phi::KernelKey OperatorWithKernel::ChoosePhiKernel( const ExecutionContext& ctx) const { - pt_kernel_signature_.reset( - new KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); - VLOG(6) << *pt_kernel_signature_.get(); + kernel_signature_.reset( + new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); + VLOG(6) << *kernel_signature_.get(); kernel_type_.reset( new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); - auto pt_kernel_name = pt_kernel_signature_->name; + auto pt_kernel_name = kernel_signature_->name; auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( pt_kernel_name, pt_kernel_key))); @@ -2151,7 +2158,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( tensor.layout()); } -KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( +phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( const ExecutionContext& ctx) const { ExecutionArgumentMappingContext arg_mapping_ctx(ctx); if (arg_map_fn_ == nullptr) { @@ -2159,8 +2166,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( if (arg_map_fn) { arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn)); } else { - auto func = - [this](const phi::ArgumentMappingContext& ctx) -> KernelSignature { + auto func = [this]( + const phi::ArgumentMappingContext& ctx) -> phi::KernelSignature { return phi::DefaultKernelSignatureMap::Instance().Get(type_); }; arg_map_fn_.reset(new phi::ArgumentMappingFn(func)); @@ -2171,7 +2178,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( Scope* OperatorWithKernel::PreparePhiData( const Scope& scope, const phi::Kernel& pt_kernel, - const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const { + const phi::KernelSignature& pt_kernel_signature, + RuntimeContext* ctx) const { const auto& input_names = pt_kernel_signature.input_names; auto input_defs = pt_kernel.args_def().input_defs(); PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), @@ -2269,9 +2277,9 @@ void OperatorWithKernel::BuildPhiKernelContext( phi::KernelContext* pt_kernel_context) const { pt_kernel_context->SetDeviceContext(dev_ctx); - auto& input_names = pt_kernel_signature_->input_names; - auto& attr_names = pt_kernel_signature_->attr_names; - auto& output_names = pt_kernel_signature_->output_names; + auto& input_names = kernel_signature_->input_names; + auto& attr_names = kernel_signature_->attr_names; + auto& output_names = kernel_signature_->output_names; auto input_defs = pt_kernel_->args_def().input_defs(); auto attr_defs = pt_kernel_->args_def().attribute_defs(); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index dd21be12f4abf5bb559241ba71e8f1871327df0a..70e9f5c1b1457e913dd280cf40406209a238381e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -632,7 +632,7 @@ class OperatorWithKernel : public OperatorBase { phi::KernelContext* pt_kernel_context) const; phi::KernelSignature* PhiKernelSignature() const { - return pt_kernel_signature_.get(); + return kernel_signature_.get(); } phi::Kernel* PhiKernel() const { return pt_kernel_.get(); } @@ -704,7 +704,7 @@ class OperatorWithKernel : public OperatorBase { // we may polish the implementation here mutable bool run_phi_kernel_ = false; mutable bool run_kp_kernel = false; - mutable std::unique_ptr pt_kernel_signature_; + mutable std::unique_ptr kernel_signature_; mutable std::unique_ptr pt_kernel_; mutable std::unique_ptr arg_map_fn_; }; diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 75bab0594758b3013eca0dee82201b5615e3e183..fe7c56827612cafd843eac8eccfd1e902d39950d 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -45,7 +45,7 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { const paddle::SmallVector& GetOutputArgsNames() override; const paddle::SmallVector& GetAttrsArgsNames() override; - KernelSignature GetKernelSignature(); + phi::KernelSignature GetKernelSignature(); private: DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto); @@ -221,10 +221,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { return attr_names_; } -KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { - return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(), - GetInputArgsNames(), GetAttrsArgsNames(), - GetOutputArgsNames()); +phi::KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { + return phi::KernelSignature( + phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(), + GetAttrsArgsNames(), GetOutputArgsNames()); } std::once_flag kernel_sig_map_init_flag; diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index 392a3f9b06b3c11232b5804acd3acefb6a06c59b..a99abbf0cebbbf3648cb2d61e32b02a86e206d4e 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -40,8 +40,6 @@ limitations under the License. */ namespace paddle { namespace framework { -using KernelSignature = phi::KernelSignature; - /* Kernel Key translate */ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index bf9731bafce6405421602967317260b247e0698d..4600213596e62d886d48acddc630309d7819c54e 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -113,6 +113,10 @@ class InferShapeContext { virtual paddle::SmallVector GetOutputVarPtrs(const std::string &name) const = 0; + virtual const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const = 0; + + virtual const phi::KernelSignature *GetPhiDefaultKernelSignature() const = 0; + protected: virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 5b63334c9ea99d0fc6f52339e8fdfcf8c789ee79..8a5d942e059c024c5a5a0c51e74b01d4931f1ba4 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -37,13 +37,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const NameVarMap* in, const NameVarMap* out, const framework::AttributeMap* attr, const framework::AttributeMap* default_attr, const std::string op_type, - const framework::OpKernelType* op_kernel_type = nullptr) + const framework::OpKernelType* op_kernel_type = nullptr, + const phi::ArgumentMappingFn* arg_map_fn = nullptr, + const phi::KernelSignature* default_kernel_signature = nullptr) : var_map_in_(in), var_map_out_(out), attrs_(attr), default_attrs_(default_attr), op_type_(op_type), - op_kernel_type_(op_kernel_type) {} + op_kernel_type_(op_kernel_type), + arg_map_fn_(arg_map_fn), + default_kernel_signature_(default_kernel_signature) {} bool HasInput(const std::string& name) const override { // has only one input @@ -377,6 +381,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext { "SetLoDLevel function not support in dygraph mode")); } + const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override { + return arg_map_fn_; + } + + const phi::KernelSignature* GetPhiDefaultKernelSignature() const override { + return default_kernel_signature_; + } + protected: DDim GetDim(framework::Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet( @@ -438,6 +450,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const framework::AttributeMap* default_attrs_; const std::string op_type_; const framework::OpKernelType* op_kernel_type_; + // arg_map_fn_ and default_kernel_signature_ may be nullptr + const phi::ArgumentMappingFn* arg_map_fn_; + const phi::KernelSignature* default_kernel_signature_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index fdeda8aa9701a8bfc78c98e63a5e09af4de345ef..6c056605faa48f10a2e3816cd57e2dd08053d0f0 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -107,19 +107,25 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, + const phi::ArgumentMappingFn* arg_map_fn, + const phi::KernelSignature* default_kernel_signature, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), kernel_type_(kernel_type), func_(func), dev_ctx_(dev_ctx), - pt_kernel_(empty_kernel) {} + arg_map_fn_(arg_map_fn), + default_kernel_signature_(default_kernel_signature), + phi_kernel_(empty_kernel) {} PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, - framework::KernelSignature&& kernel_signature, - const phi::Kernel& pt_kernel, + const phi::ArgumentMappingFn* arg_map_fn, + const phi::KernelSignature* default_kernel_signature, + phi::KernelSignature&& kernel_signature, + const phi::Kernel& phi_kernel, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), @@ -127,8 +133,10 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, func_(nullptr), dev_ctx_(dev_ctx), run_phi_kernel_(true), - pt_kernel_signature_(std::move(kernel_signature)), - pt_kernel_(pt_kernel) {} + arg_map_fn_(arg_map_fn), + default_kernel_signature_(default_kernel_signature), + kernel_signature_(std::move(kernel_signature)), + phi_kernel_(phi_kernel) {} template PreparedOp PrepareImpl(const NameVarMap& ins, @@ -161,7 +169,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs); auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); - framework::KernelSignature pt_kernel_signature; + const phi::KernelSignature* default_kernel_signature = nullptr; + phi::KernelSignature kernel_signature; phi::KernelKey pt_kernel_key; std::string pt_kernel_name; #if defined(PADDLE_WITH_XPU) @@ -179,20 +188,20 @@ PreparedOp PrepareImpl(const NameVarMap& ins, phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type()); if (arg_map_fn) { has_phi_kernel = true; - pt_kernel_signature = (*arg_map_fn)( + kernel_signature = (*arg_map_fn)( framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); } else { - const auto* kernel_sig = + default_kernel_signature = phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type()); - if (kernel_sig) { + if (default_kernel_signature) { has_phi_kernel = true; - pt_kernel_signature = *kernel_sig; + kernel_signature = *default_kernel_signature; } } if (has_phi_kernel) { - VLOG(6) << pt_kernel_signature; - pt_kernel_name = pt_kernel_signature.name; + VLOG(6) << kernel_signature; + pt_kernel_name = kernel_signature.name; // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // But the default library_type is Plain, so we need to modify the // library_type here, otherwise it can't work. @@ -230,24 +239,25 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); - auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel( + auto& phi_kernel = phi::KernelFactory::Instance().SelectKernel( pt_kernel_name, pt_kernel_key); - if (pt_kernel.IsValid() + if (phi_kernel.IsValid() #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) && !is_xpu_unsupport #endif ) { VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name << " | kernel key: " << pt_kernel_key - << " | kernel: " << pt_kernel; + << " | kernel: " << phi_kernel; if (expected_kernel_key.place_ != place) { dev_ctx = pool.Get(expected_kernel_key.place_); } - return PreparedOp(op, empty_ctx, expected_kernel_key, - std::move(pt_kernel_signature), pt_kernel, dev_ctx); + return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn, + default_kernel_signature, std::move(kernel_signature), + phi_kernel, dev_ctx); } else { VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name << "` not found."; @@ -295,9 +305,9 @@ PreparedOp PrepareImpl(const NameVarMap& ins, << " | kernel key: " << pt_cpu_kernel_key << " | kernel: " << pt_cpu_kernel; auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); - return PreparedOp(op, empty_ctx, expected_kernel_key, - std::move(pt_kernel_signature), pt_cpu_kernel, - cpu_ctx); + return PreparedOp(op, empty_ctx, expected_kernel_key, arg_map_fn, + default_kernel_signature, std::move(kernel_signature), + pt_cpu_kernel, cpu_ctx); } } } @@ -389,7 +399,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, } return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second, - dev_ctx); + arg_map_fn, default_kernel_signature, dev_ctx); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, @@ -425,6 +435,8 @@ static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, + const phi::ArgumentMappingFn* arg_map_fn, + const phi::KernelSignature* default_kernel_signature, platform::DeviceContext* dev_ctx, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { @@ -436,7 +448,8 @@ static void PreparedOpRunImpl( platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); DygraphInferShapeContext infer_shape_ctx( - &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type); + &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type, + arg_map_fn, default_kernel_signature); op.Info().infer_shape_(&infer_shape_ctx); } @@ -483,17 +496,19 @@ template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, const framework::OpKernelType& kernel_type, - const framework::KernelSignature& pt_kernel_signature, - const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx, - const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs, + const phi::ArgumentMappingFn* arg_map_fn, + const phi::KernelSignature* default_kernel_signature, + const phi::KernelSignature& kernel_signature, const phi::Kernel& phi_kernel, + platform::DeviceContext* dev_ctx, const NameVarMap& ins, + const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { { platform::RecordEvent record_event(op.Type() + "::infer_shape", platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); DygraphInferShapeContext infer_shape_ctx( - &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type); + &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type, + arg_map_fn, default_kernel_signature); op.Info().infer_shape_(&infer_shape_ctx); } @@ -502,14 +517,14 @@ static void PreparedOpRunPtImpl( platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); - PreparePhiData(pt_kernel, pt_kernel_signature, ins); + PreparePhiData(phi_kernel, kernel_signature, ins); phi::KernelContext pt_kernel_context; - BuildDygraphPhiKernelContext(pt_kernel_signature, pt_kernel, ins, + BuildDygraphPhiKernelContext(kernel_signature, phi_kernel, ins, outs, attrs, default_attrs, dev_ctx, &pt_kernel_context); - pt_kernel(&pt_kernel_context); + phi_kernel(&pt_kernel_context); } if (FLAGS_check_nan_inf) { @@ -535,12 +550,14 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { - PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, - pt_kernel_, dev_ctx_, ins, outs, attrs, + PreparedOpRunPtImpl(op_, kernel_type_, arg_map_fn_, + default_kernel_signature_, kernel_signature_, + phi_kernel_, dev_ctx_, ins, outs, attrs, default_attrs); } else { - PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, - outs, attrs, default_attrs); + PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, arg_map_fn_, + default_kernel_signature_, dev_ctx_, ins, outs, + attrs, default_attrs); } } @@ -550,11 +567,13 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl( - op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, - outs, attrs, default_attrs); + op_, kernel_type_, arg_map_fn_, default_kernel_signature_, + kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs, + default_attrs); } else { - PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, - ins, outs, attrs, default_attrs); + PreparedOpRunImpl( + op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_, + dev_ctx_, ins, outs, attrs, default_attrs); } } @@ -564,12 +583,13 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& default_attrs) { if (run_phi_kernel_) { PreparedOpRunPtImpl( - op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, - outs, attrs, default_attrs); + op_, kernel_type_, arg_map_fn_, default_kernel_signature_, + kernel_signature_, phi_kernel_, dev_ctx_, ins, outs, attrs, + default_attrs); } else { - PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, - dev_ctx_, ins, outs, attrs, - default_attrs); + PreparedOpRunImpl( + op_, ctx_, kernel_type_, func_, arg_map_fn_, default_kernel_signature_, + dev_ctx_, ins, outs, attrs, default_attrs); } } diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 0e75775e9178390ca150a7d58e28ff55d2095fdf..dedb6a382efa6f0be2d6de9d07c3cd4580d0d453 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -150,13 +150,17 @@ class PreparedOp { const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, + const phi::ArgumentMappingFn* arg_map_fn, + const phi::KernelSignature* default_kernel_signature, platform::DeviceContext* dev_ctx); PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, - framework::KernelSignature&& kernel_signature, - const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); + const phi::ArgumentMappingFn* arg_map_fn, + const phi::KernelSignature* default_kernel_signature, + phi::KernelSignature&& kernel_signature, + const phi::Kernel& phi_kernel, platform::DeviceContext* dev_ctx); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, @@ -206,8 +210,10 @@ class PreparedOp { // we may polish the implementation here bool run_phi_kernel_{false}; bool run_kp_kernel_{false}; - framework::KernelSignature pt_kernel_signature_; - const phi::Kernel& pt_kernel_; + const phi::ArgumentMappingFn* arg_map_fn_; + const phi::KernelSignature* default_kernel_signature_; + phi::KernelSignature kernel_signature_; + const phi::Kernel& phi_kernel_; }; const inline framework::Attribute& GetAttr( @@ -226,21 +232,23 @@ const inline framework::Attribute& GetAttr( } template -void BuildDygraphPhiKernelContext( - const framework::KernelSignature& pt_kernel_signature, - const phi::Kernel& pt_kernel, const NameVarMap& ins, - const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& default_attrs, - platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) { +void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, + const phi::Kernel& phi_kernel, + const NameVarMap& ins, + const NameVarMap& outs, + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, + platform::DeviceContext* dev_ctx, + phi::KernelContext* kernel_ctx) { kernel_ctx->SetDeviceContext(dev_ctx); - const auto& input_names = pt_kernel_signature.input_names; - const auto& attr_names = pt_kernel_signature.attr_names; - const auto& output_names = pt_kernel_signature.output_names; + const auto& input_names = kernel_signature.input_names; + const auto& attr_names = kernel_signature.attr_names; + const auto& output_names = kernel_signature.output_names; - auto& input_defs = pt_kernel.args_def().input_defs(); - auto& output_defs = pt_kernel.args_def().output_defs(); - auto& attr_defs = pt_kernel.args_def().attribute_defs(); + auto& input_defs = phi_kernel.args_def().input_defs(); + auto& output_defs = phi_kernel.args_def().output_defs(); + auto& attr_defs = phi_kernel.args_def().attribute_defs(); PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), platform::errors::InvalidArgument( @@ -286,7 +294,7 @@ void BuildDygraphPhiKernelContext( "Can not find input variable '%s' for %s OP, please check whether " "the name setting in OpArgumentMapping is consistent with that in " "OpMaker.", - input_names[i], pt_kernel_signature.name)); + input_names[i], kernel_signature.name)); } } @@ -568,11 +576,11 @@ void BuildDygraphPhiKernelContext( } template -void PreparePhiData(const phi::Kernel& pt_kernel, - const framework::KernelSignature& pt_kernel_signature, +void PreparePhiData(const phi::Kernel& phi_kernel, + const phi::KernelSignature& kernel_signature, const NameVarMap& ins) { - const auto& input_names = pt_kernel_signature.input_names; - auto& input_defs = pt_kernel.args_def().input_defs(); + const auto& input_names = kernel_signature.input_names; + auto& input_defs = phi_kernel.args_def().input_defs(); PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), platform::errors::InvalidArgument(