diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 57fb68e80427afa56372bebb31ff5822135858b6..7232a707916dd5f0795c04cff8137c5e88132d42 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -381,6 +381,10 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, std::type_index(typeid(std::vector))) { infer_meta_context.EmplaceBackAttr(std::move( phi::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr(std::move( + phi::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); } else if (std::type_index(attr.type()) == std::type_index(typeid(int))) { infer_meta_context.EmplaceBackAttr( diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index e23342ebb5dc7639d68500964bfdfbd099d077cd..6baa504562e76fdc2a62f885c4d1c1b1a5629a8e 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/empty_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/infermeta/nullary.h" + namespace paddle { namespace operators { @@ -51,46 +53,6 @@ class EmptyOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* context) const override { - OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "empty"); - - if (context->HasInput("ShapeTensor")) { - auto shape_dims = context->GetInputDim("ShapeTensor"); - int num_ele = 1; - for (int i = 0; i < shape_dims.size(); ++i) { - num_ele *= shape_dims[i]; - } - auto vec_dims = std::vector(num_ele, -1); - context->SetOutputDim("Out", phi::make_ddim(vec_dims)); - } else if (context->HasInputs("ShapeTensorList")) { - std::vector out_dims; - auto dims_list = context->GetInputsDim("ShapeTensorList"); - for (size_t i = 0; i < dims_list.size(); ++i) { - auto& dims = dims_list[i]; - PADDLE_ENFORCE_EQ(dims, phi::make_ddim({1}), - platform::errors::InvalidArgument( - "The shape of Tensor in list must be [1]. " - "But received the shape is [%s]", - dims)); - - out_dims.push_back(-1); - } - - context->SetOutputDim("Out", phi::make_ddim(out_dims)); - } else { - auto& shape = context->Attrs().Get>("shape"); - for (size_t i = 0; i < shape.size(); ++i) { - PADDLE_ENFORCE_GE( - shape[i], 0, - platform::errors::InvalidArgument( - "Each value of attribute 'shape' is expected to be no less " - "than 0. But recieved: shape[%u] = %d; shape = [%s].", - i, shape[i], phi::make_ddim(shape))); - } - context->SetOutputDim("Out", phi::make_ddim(shape)); - } - } - protected: framework::OpKernelType GetKernelTypeForVar( const std::string& var_name, const framework::Tensor& tensor, @@ -126,14 +88,11 @@ class EmptyOpVarTypeInference : public framework::VarTypeInference { namespace ops = paddle::operators; namespace plat = paddle::platform; +DELCARE_INFER_SHAPE_FUNCTOR(empty, EmptyInferShapeFunctor, + PT_INFER_META(phi::CreateInferMeta)); + REGISTER_OPERATOR( empty, ops::EmptyOp, ops::EmptyOpMaker, ops::EmptyOpVarTypeInference, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -REGISTER_OP_CPU_KERNEL(empty, ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel); + paddle::framework::EmptyGradOpMaker, + EmptyInferShapeFunctor); diff --git a/paddle/fluid/operators/empty_op.cu.cc b/paddle/fluid/operators/empty_op.cu.cc deleted file mode 100644 index 22799e507aeff7940274f729b174f50bfd9132a5..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/empty_op.cu.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/empty_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - empty, ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel, - ops::EmptyKernel); diff --git a/paddle/fluid/operators/empty_op.h b/paddle/fluid/operators/empty_op.h deleted file mode 100644 index cb466fffcd7c7358b6e84c18b7895a17b2eaa907..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/empty_op.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/utils.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class EmptyKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto dtype = static_cast( - context.Attr("dtype")); - - Tensor *out_tensor = context.Output("Out"); - - auto shape = GetShape(context); - out_tensor->Resize(shape); - - out_tensor->mutable_data(context.GetPlace(), - framework::TransToPhiDataType(dtype)); - } -}; - -} // namespace operators -} // namespace paddle