From 0f8888360ea4b001d7da88be920fba7d9673c66e Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 6 Dec 2019 09:14:50 -0600 Subject: [PATCH] Polish op registry codes (#21561) * polish infer shape registry, test=develop * modify some operators registry, test=develop --- paddle/fluid/framework/details/op_registry.h | 50 +++++++++++++ .../framework/no_need_buffer_vars_inference.h | 14 ++++ .../no_need_buffer_vars_inference_test.cc | 26 +++++++ paddle/fluid/framework/op_desc.cc | 50 ------------- paddle/fluid/operators/cast_op.cc | 24 +++---- .../fluid/operators/controlflow/compare_op.cc | 16 ++--- .../fluid/operators/controlflow/logical_op.cc | 70 ++++++++++--------- paddle/fluid/operators/expand_as_op.cc | 12 +++- .../fluid/operators/fused/conv_fusion_op.cc | 11 +-- .../operators/ngraph/ngraph_engine_op.cc | 3 +- paddle/fluid/operators/random_crop_op.cc | 43 ++++++------ paddle/fluid/operators/save_op.cc | 7 +- .../sequence_ops/sequence_concat_op.cc | 10 +-- .../operators/tensorrt/tensorrt_engine_op.cc | 2 +- 14 files changed, 195 insertions(+), 143 deletions(-) diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 9bfd47d488..ece0dc4e77 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -155,17 +155,41 @@ class OperatorRegistrarRecursive { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ(info->creator_, nullptr, + platform::errors::AlreadyExists( + "OpCreator of %s has been registered", op_type)); info->creator_ = [](const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs) { return new T(type, inputs, outputs, attrs); }; + + if (std::is_base_of::value) { + PADDLE_ENFORCE_EQ( + info->infer_shape_, nullptr, + platform::errors::AlreadyExists( + "Duplicate InferShapeFN of %s has been registered", op_type)); + + auto* op = + dynamic_cast(info->creator_("", {}, {}, {})); + PADDLE_ENFORCE_NOT_NULL(op, platform::errors::InvalidArgument( + "%s should have kernels", op_type)); + info->infer_shape_ = [op](InferShapeContext* ctx) { + op->InferShape(ctx); + }; + } } }; template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ(info->proto_, nullptr, + platform::errors::AlreadyExists( + "OpProto of %s has been registered", op_type)); + PADDLE_ENFORCE_EQ(info->checker_, nullptr, + platform::errors::AlreadyExists( + "OpAttrChecker of %s has been registered", op_type)); info->proto_ = new proto::OpProto; info->checker_ = new OpAttrChecker(); T maker; @@ -181,6 +205,11 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ( + info->grad_op_maker_, nullptr, + platform::errors::AlreadyExists( + "GradOpDescMaker of %s has been registered", op_type)); + info->grad_op_maker_ = []( const OpDesc& fwd_op, const std::unordered_set& no_grad_set, @@ -199,6 +228,11 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ( + info->dygraph_grad_op_maker_, nullptr, + platform::errors::AlreadyExists( + "GradOpBaseMaker of %s has been registered", op_type)); + info->dygraph_grad_op_maker_ = []( const imperative::OpBase* fw_op_base, const imperative::NameVarBaseMap& var_base_map_in, @@ -212,6 +246,10 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ( + info->infer_var_type_, nullptr, + platform::errors::AlreadyExists( + "VarTypeInference of %s has been registered", op_type)); info->infer_var_type_ = [](InferVarTypeContext* context) { T inference; inference(context); @@ -222,6 +260,10 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ( + info->infer_shape_, nullptr, + platform::errors::AlreadyExists( + "Duplicate InferShapeFN of %s has been registered", op_type)); info->infer_shape_ = [](InferShapeContext* ctx) { T inference; inference(ctx); @@ -232,6 +274,10 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ( + info->infer_inplace_, nullptr, + platform::errors::AlreadyExists( + "InplaceOpInference of %s has been registered", op_type)); info->infer_inplace_ = [](bool use_cuda) { T infer; return infer(use_cuda); @@ -242,6 +288,10 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { + PADDLE_ENFORCE_EQ( + info->infer_no_need_buffer_vars_, nullptr, + platform::errors::AlreadyExists( + "NoNeedBufferVarsInference of %s has been registered", op_type)); info->infer_no_need_buffer_vars_.Reset(std::make_shared()); } }; diff --git a/paddle/fluid/framework/no_need_buffer_vars_inference.h b/paddle/fluid/framework/no_need_buffer_vars_inference.h index c1640f6ccb..843d178cf1 100644 --- a/paddle/fluid/framework/no_need_buffer_vars_inference.h +++ b/paddle/fluid/framework/no_need_buffer_vars_inference.h @@ -124,9 +124,23 @@ class InferNoNeedBufferVarsFN { inferer_ = inferer; } + inline bool operator==(std::nullptr_t) const { return inferer_ == nullptr; } + + inline bool operator!=(std::nullptr_t) const { return inferer_ != nullptr; } + private: std::shared_ptr inferer_; }; +static inline bool operator==(std::nullptr_t, + const InferNoNeedBufferVarsFN &other) { + return other == nullptr; +} + +static inline bool operator!=(std::nullptr_t, + const InferNoNeedBufferVarsFN &other) { + return other != nullptr; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/no_need_buffer_vars_inference_test.cc b/paddle/fluid/framework/no_need_buffer_vars_inference_test.cc index f480dceaa1..98196f9cb6 100644 --- a/paddle/fluid/framework/no_need_buffer_vars_inference_test.cc +++ b/paddle/fluid/framework/no_need_buffer_vars_inference_test.cc @@ -48,5 +48,31 @@ TEST(test_no_need_buffer_vars_inference, test_dygraph) { ASSERT_TRUE(boost::get(ctx.GetAttr("is_test"))); } +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(TestNoNeedBufferVarsInferer, "X1", "X2"); + +TEST(test_no_need_buffer_vars_inference, test_nullptr_comparation) { + InferNoNeedBufferVarsFN infer_fn; + ASSERT_FALSE(static_cast(infer_fn)); + ASSERT_TRUE(!infer_fn); + ASSERT_TRUE(infer_fn == nullptr); + ASSERT_TRUE(nullptr == infer_fn); + ASSERT_FALSE(infer_fn != nullptr); + ASSERT_FALSE(nullptr != infer_fn); + + infer_fn.Reset(std::make_shared()); + ASSERT_TRUE(static_cast(infer_fn)); + ASSERT_FALSE(!infer_fn); + ASSERT_FALSE(infer_fn == nullptr); + ASSERT_FALSE(nullptr == infer_fn); + ASSERT_TRUE(infer_fn != nullptr); + ASSERT_TRUE(nullptr != infer_fn); + + auto no_need_slots = + infer_fn(VariableNameMap{}, VariableNameMap{}, AttributeMap{}); + ASSERT_EQ(no_need_slots.size(), 2UL); + ASSERT_EQ(no_need_slots.count("X1"), 1UL); + ASSERT_EQ(no_need_slots.count("X2"), 1UL); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index c657d7a2bc..69823126ad 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -653,55 +653,6 @@ void OpDesc::Flush() { } } -static std::once_flag init_infer_shape_funcs; - -/** - * NOTE(paddle-dev): Very tricky code here. Maybe we should find a - * better way to register compile-time infershape method gentlely. - * - * Normally, we can register a class derived from InferShapeBase, so that - * we can set the field of `infer_shape_` inside OpInfo when registering op. - * - * However, there is another way we can set the field of `infer_shape_` inside - * OpInfo. Usually, we overload InferShape method of OperatorWithKernel. After - * running the following method InitInferShapeFuncs, `infer_shape_` would be set - * to be the InferShape method of OperatorWithKernel. That is to say, we borrow - * the run-time InferShape method of OperatorWithKernel to be the compile-time - * InferShape method. - * - * However, during compiling time, we may not know inputs, outputs and attrs of - * run-time OperatorWithKernel. So the following code creates a fake - * OperatorWithKernel object. That is why the field info_ of OperatorBase - * would be null. - */ -static void InitInferShapeFuncs() { - std::call_once(init_infer_shape_funcs, [] { - auto &map = OpInfoMap::Instance(); - auto &info_map = *map.mutable_map(); - - for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) { - auto op_type = kern_pair.first; - auto it = info_map.find(op_type); - PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered", - op_type); - auto &op_info = it->second; - if (op_info.infer_shape_) { // infer_shape has been registered. - continue; - } - - auto op = dynamic_cast(op_info.Creator()( - "", VariableNameMap{}, VariableNameMap{}, AttributeMap{})); - - PADDLE_ENFORCE_NOT_NULL( - op, "InferShapeBase is not registered to Operator %s", op_type); - - op_info.infer_shape_ = [op](InferShapeContext *ctx) { - op->InferShape(ctx); - }; - } - }); -} - void OpDesc::CheckAttrs() { PADDLE_ENFORCE(!Type().empty(), "CheckAttr() can not be called before type is setted."); @@ -718,7 +669,6 @@ void OpDesc::CheckAttrs() { void OpDesc::InferShape(const BlockDesc &block) const { try { VLOG(3) << "CompileTime infer shape on " << Type(); - InitInferShapeFuncs(); auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_; PADDLE_ENFORCE(static_cast(infer_shape), "%s's infer_shape has not been registered", this->Type()); diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index c7a6b50b2f..e0bd49e7d2 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -38,17 +38,6 @@ the input dtype, but it's fine if you do so. } }; -class CastOpInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *context) const override { - PADDLE_ENFORCE(context->HasInput("X"), "The input of cast op must be set"); - PADDLE_ENFORCE(context->HasOutput("Out"), - "The output of cast op must be set"); - context->SetOutputDim("Out", context->GetInputDim("X")); - context->ShareLoD("X", "Out"); - } -}; - template class CastOpGradMaker : public framework::SingleGradOpMaker { public: @@ -71,6 +60,17 @@ class CastOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: + void InferShape(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE_EQ( + context->HasInput("X"), true, + platform::errors::NotFound("The input(X) of cast op must be set")); + PADDLE_ENFORCE_EQ( + context->HasOutput("Out"), true, + platform::errors::NotFound("The output of cast op must be set")); + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); @@ -88,7 +88,7 @@ using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker, ops::CastOpGradMaker, - ops::CastOpInferShape, ops::CastOpProtoMaker); + ops::CastOpProtoMaker); REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index 7b359ff0be..be71319cbb 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -73,9 +73,12 @@ calculated by $%s$ }; template -class CompareOpInferShape : public framework::InferShapeBase { +class CompareOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext* context) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { OpComment comment; PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X", comment.type); @@ -89,13 +92,7 @@ class CompareOpInferShape : public framework::InferShapeBase { context->SetOutputDim("Out", context->GetInputDim("X")); context->ShareLoD("X", "Out"); } -}; -class CompareOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); @@ -118,9 +115,8 @@ class CompareOp : public framework::OperatorWithKernel { char _##op_type##Comment::type[]{#op_type}; \ char _##op_type##Comment::equation[]{_equation}; \ REGISTER_OPERATOR( \ - op_type, ::paddle::operators::CompareOp, \ + op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \ ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ - ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker, \ ::paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/controlflow/logical_op.cc b/paddle/fluid/operators/controlflow/logical_op.cc index 95c16ebc94..4b3cdf6818 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cc +++ b/paddle/fluid/operators/controlflow/logical_op.cc @@ -57,10 +57,44 @@ Each element of Out is calculated by %s } }; +class LogicalOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + // LogicalOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + +template +class UnaryLogicalOp : public LogicalOp { + public: + using LogicalOp::LogicalOp; + + protected: + void InferShape(framework::InferShapeContext *context) const override { + OpComment comment; + PADDLE_ENFORCE_EQ( + context->HasInput("X"), true, + platform::errors::NotFound("Input(X) of %s operator must not be null", + comment.type)); + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } +}; + template -class BinaryLogicalOpInferShape : public framework::InferShapeBase { +class BinaryLogicalOp : public LogicalOp { public: - void operator()(framework::InferShapeContext *context) const override { + using LogicalOp::LogicalOp; + + protected: + void InferShape(framework::InferShapeContext *context) const override { OpComment comment; PADDLE_ENFORCE_EQ(context->HasInput("X"), true, "Input(X) of %s operator must not be null", comment.type); @@ -84,32 +118,6 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase { } }; -template -class UnaryLogicalOpInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *context) const override { - OpComment comment; - PADDLE_ENFORCE_EQ(context->HasInput("X"), true, - "Input(X) of %s operator must not be null", comment.type); - context->SetOutputDim("Out", context->GetInputDim("X")); - context->ShareLoD("X", "Out"); - } -}; - -class LogicalOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); - // LogicalOp kernel's device type is decided by input tensor place - kt.place_ = ctx.Input("X")->place(); - return kt; - } -}; - } // namespace operators } // namespace paddle @@ -121,9 +129,8 @@ class LogicalOp : public framework::OperatorWithKernel { char _##op_type##Comment::type[]{#op_type}; \ char _##op_type##Comment::equation[]{_equation}; \ REGISTER_OPERATOR( \ - op_type, ::paddle::operators::LogicalOp, \ + op_type, ::paddle::operators::BinaryLogicalOp<_##op_type##Comment>, \ ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \ - ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker, \ ::paddle::framework::EmptyGradOpMaker); @@ -135,9 +142,8 @@ class LogicalOp : public framework::OperatorWithKernel { char _##op_type##Comment::type[]{#op_type}; \ char _##op_type##Comment::equation[]{_equation}; \ REGISTER_OPERATOR( \ - op_type, ::paddle::operators::LogicalOp, \ + op_type, ::paddle::operators::UnaryLogicalOp<_##op_type##Comment>, \ ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \ - ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker, \ ::paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/expand_as_op.cc b/paddle/fluid/operators/expand_as_op.cc index 002623c1b1..e72a16cb5f 100644 --- a/paddle/fluid/operators/expand_as_op.cc +++ b/paddle/fluid/operators/expand_as_op.cc @@ -88,6 +88,13 @@ class ExpandAsGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(x_grad_name, x_dims); } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } }; template @@ -108,7 +115,7 @@ class ExpandAsGradOpMaker : public framework::SingleGradOpMaker { } }; -// DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X"); +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandAsGradNoNeedBufVarsInferer, "X"); } // namespace operators } // namespace paddle @@ -117,7 +124,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(expand_as, ops::ExpandAsOp, ops::ExpandAsOpMaker, ops::ExpandAsGradOpMaker, ops::ExpandAsGradOpMaker); -REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp); +REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp, + ops::ExpandAsGradNoNeedBufVarsInferer); REGISTER_OP_CPU_KERNEL( expand_as, ops::ExpandAsKernel, ops::ExpandAsKernel, diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cc b/paddle/fluid/operators/fused/conv_fusion_op.cc index f4da7ec4dd..b53d7d1865 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cc +++ b/paddle/fluid/operators/fused/conv_fusion_op.cc @@ -59,9 +59,12 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker { } }; -class Conv2DFusionOpInferShape : public framework::InferShapeBase { +class Conv2DFusionOp : public operators::ConvOp { public: - void operator()(framework::InferShapeContext* ctx) const override { + using operators::ConvOp::ConvOp; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input(Input) of ConvOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true, @@ -175,7 +178,7 @@ class Conv2DFusionOpInferShape : public framework::InferShapeBase { namespace ops = paddle::operators; REGISTER_OPERATOR( - conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker, - ops::Conv2DFusionOpInferShape, ops::ConvOpInferVarType, + conv2d_fusion, ops::Conv2DFusionOp, ops::Conv2DFusionOpMaker, + ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/ngraph/ngraph_engine_op.cc b/paddle/fluid/operators/ngraph/ngraph_engine_op.cc index 479c95ba08..621f1a3d8c 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine_op.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine_op.cc @@ -45,8 +45,7 @@ class NgraphEngineInferVarType : public framework::VarTypeInference { namespace ops = paddle::operators; -REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker, - ops::NgraphEngineOpMaker); +REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker); REGISTER_OP_CPU_KERNEL( ngraph_engine, ops::NgraphEngineKernel); diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index acb6ca84ca..f12ea0275d 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -20,6 +20,29 @@ class RandomCropOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + auto shape = ctx->Attrs().Get>("shape"); + auto x_dim = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GT( + x_dim.size(), static_cast(shape.size()), + platform::errors::InvalidArgument( + "Rank of Input(X) must be equal to length of Attr(shape)")); + auto out_dim = framework::vectorize(x_dim); + for (size_t i = 1; i <= shape.size(); ++i) { + size_t x_i = x_dim.size() - i; + size_t shape_i = shape.size() - i; + if (ctx->IsRuntime() || (x_dim[x_i] > 0 && shape[shape_i] > 0)) { + PADDLE_ENFORCE_GE( + x_dim[x_i], shape[shape_i], + platform::errors::InvalidArgument( + "Size of Input(X) must be larger than Attr(shape)")); + } + out_dim[x_i] = shape[shape_i]; + } + ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); + } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( @@ -51,25 +74,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class RandomCropOpInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override { - auto shape = ctx->Attrs().Get>("shape"); - auto x_dim = ctx->GetInputDim("X"); - PADDLE_ENFORCE_GT(x_dim.size(), static_cast(shape.size())); - auto out_dim = framework::vectorize(x_dim); - for (size_t i = 1; i <= shape.size(); ++i) { - size_t x_i = x_dim.size() - i; - size_t shape_i = shape.size() - i; - if (ctx->IsRuntime() || (x_dim[x_i] > 0 && shape[shape_i] > 0)) { - PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]); - } - out_dim[x_i] = shape[shape_i]; - } - ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); - } -}; - } // namespace operators } // namespace paddle @@ -77,7 +81,6 @@ namespace ops = paddle::operators; namespace f = paddle::framework; REGISTER_OPERATOR( random_crop, ops::RandomCropOp, ops::RandomCropOpMaker, - ops::RandomCropOpInferShape, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 73bac5c2fd..09b171fe90 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -77,18 +77,13 @@ class SaveOpVarTypeInference : public framework::VarTypeInference { } }; -class SaveOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override {} -}; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker, - ops::SaveOpVarTypeInference, ops::SaveOpShapeInference); + ops::SaveOpVarTypeInference); REGISTER_OP_CPU_KERNEL( save, ops::SaveOpKernel, diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index 319fdec06a..ac513dee29 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -35,9 +35,12 @@ class SeqConcatOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class SeqConcatShapeInferer : public framework::InferShapeBase { +class SequenceConcatOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *context) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInputs("X"), "Input(X) of Sequence Concat Op should not be null."); PADDLE_ENFORCE(context->HasOutput("Out"), @@ -117,8 +120,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SeqConcatGradNoNeedBufferVarsInference, namespace op = paddle::operators; -REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel, - op::SeqConcatOpMaker, op::SeqConcatShapeInferer, +REGISTER_OPERATOR(sequence_concat, op::SequenceConcatOp, op::SeqConcatOpMaker, op::SeqConcatGradOpMaker, op::SeqConcatGradOpMaker); template diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc index 6cf3e65e00..708fccf971 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc @@ -55,6 +55,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference { namespace ops = paddle::operators; REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, - ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker); + ops::TensorRTEngineOpMaker); #endif // PADDLE_WITH_CUDA -- GitLab