diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index 4a625534909864fcf9153d160c7b66f91f82b256..aae36cf455dfee028b18050bdf431ee4601c479e 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/common/scalar_array.h"
 #include "paddle/phi/core/compat/arg_map_context.h"
 #include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/compat/op_utils.h"
@@ -54,7 +55,12 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
   }
 
   size_t InputSize(const std::string& name) const override {
-    return ctx_.Inputs(name).size();
+    if (ctx_.HasInputs(name)) {
+      return ctx_.Inputs(name).size();
+    } else if (ctx_.HasInput(name)) {
+      return 1;
+    }
+    return 0;
   }
 
   size_t OutputSize(const std::string& name) const override {
@@ -288,6 +294,16 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   auto& attr_names = std::get<1>(signature.args);
   auto& output_names = std::get<2>(signature.args);
 
+  auto kernels_map =
+      phi::KernelFactory::Instance().SelectKernelMap(signature.name);
+  if (kernels_map.size() == 0) {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("Not find `%s` kernels when construct "
+                                        "InferMetaContext.",
+                                        signature.name));
+  }
+  auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
+
   // TODO(chenweihang): support multiple inputs and outputs later
   phi::InferMetaContext infer_meta_context;
   for (auto& in_name : input_names) {
     if (ctx->HasInput(in_name)) {
@@ -299,9 +315,70 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
     }
   }
 
+  for (auto& out_name : output_names) {
+    if (ctx->HasOutput(out_name)) {
+      infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
+          ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
+    } else {
+      infer_meta_context.EmplaceBackOutput({nullptr});
+    }
+  }
+
   auto attr_reader = ctx->Attrs();
-  for (auto& attr_name : attr_names) {
-    if (ctx->HasAttr(attr_name)) {
+  for (size_t i = 0; i < attr_names.size(); ++i) {
+    auto attr_name = attr_names[i];
+    if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
+      // When the attr is a vector_tensor or tensor, transform it to ScalarArray
+      if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
+        const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
+        if (ctx->IsRuntime()) {
+          // If in runtime, get the tensor's value for the ScalarArray
+          // and push it into attrs
+          std::vector<Variable*> vars;
+          vars.reserve(infershape_inputs.size());
+          for (size_t i = 0; i < infershape_inputs.size(); i++) {
+            vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
+          }
+          if (infershape_inputs.size() != 1) {
+            infer_meta_context.EmplaceBackAttr(
+                std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
+          } else {
+            infer_meta_context.EmplaceBackAttr(
+                std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
+          }
+        } else {
+          // If not in runtime, set a default value (-1) for the ScalarArray
+          int64_t num_ele = 1;
+          std::vector<VarDesc*> vars;
+          vars.reserve(infershape_inputs.size());
+          for (size_t i = 0; i < infershape_inputs.size(); i++) {
+            vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
+          }
+          for (auto& var : vars) {
+            const auto& tensor_dims = var->GetShape();
+            for (size_t i = 0; i < tensor_dims.size(); ++i) {
+              num_ele *= tensor_dims[i];
+            }
+          }
+          phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
+          tensor_attr.SetFromTensor(true);
+          infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
+        }
+      } else if (ctx->HasAttr(attr_name)) {
+        auto& attr = attr_reader.GetAttr(attr_name);
+        if (std::type_index(attr.type()) ==
+            std::type_index(typeid(std::vector<int32_t>))) {
+          infer_meta_context.EmplaceBackAttr(std::move(
+              phi::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to ScalarArray when "
+              "construct KernelContext.",
+              attr_name));
+        }
+      }
+
+    } else if (ctx->HasAttr(attr_name)) {
+      // Emplace Back Attr according to the type of attr.
       auto& attr = attr_reader.GetAttr(attr_name);
       if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
         infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
@@ -345,17 +422,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
           "Unsupported attribute type is received when call "
           "InferShapeFunctor."));
       }
-    } else {
-      // do nothing
-    }
-  }
-
-  for (auto& out_name : output_names) {
-    if (ctx->HasOutput(out_name)) {
-      infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
-          ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
-    } else {
-      infer_meta_context.EmplaceBackOutput({nullptr});
     }
   }
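For intuition on the ScalarArray branch added above: when an op's shape comes from a ShapeTensor (or a list of shape tensors), the values are only known at runtime, so compile-time InferShape fabricates a placeholder ScalarArray of -1s and tags it with SetFromTensor(true). Below is a minimal standalone sketch of that pattern; IntArray and MakeCompileTimeShape are illustrative stand-ins, not Paddle APIs.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Illustrative stand-in for phi::ScalarArray: an integer list that remembers
// whether it was materialized from tensor data or from a plain attribute.
class IntArray {
 public:
  explicit IntArray(std::vector<int64_t> data) : data_(std::move(data)) {}
  void SetFromTensor(bool v) { from_tensor_ = v; }
  bool FromTensor() const { return from_tensor_; }
  const std::vector<int64_t>& GetData() const { return data_; }

 private:
  std::vector<int64_t> data_;
  bool from_tensor_{false};
};

// At compile time the tensor's values are unreadable, so fabricate
// placeholders (-1) of the right length, as the diff above does.
IntArray MakeCompileTimeShape(int64_t num_elements) {
  IntArray shape(std::vector<int64_t>(num_elements, -1));
  shape.SetFromTensor(true);
  return shape;
}

int main() {
  IntArray shape = MakeCompileTimeShape(3);
  // A consumer (e.g. an InferMeta function) checks FromTensor() and, outside
  // of runtime, skips any check that depends on the placeholder values.
  std::cout << "from_tensor=" << shape.FromTensor()
            << " size=" << shape.GetData().size() << "\n";
  return 0;
}

This is exactly why ReshapeInferMeta further down early-outs when !config.is_runtime && shape.FromTensor().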
diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/paddle/fluid/framework/infershape_utils_test.cc
index 2554031a91859b972f3afbb7d2527eacac499568..592e787109d18c45eb872fb720954ed29b073ea4 100644
--- a/paddle/fluid/framework/infershape_utils_test.cc
+++ b/paddle/fluid/framework/infershape_utils_test.cc
@@ -23,8 +23,11 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/compat/op_utils.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 namespace paddle {
 namespace framework {
@@ -93,6 +96,17 @@ phi::KernelSignature InferShapeUtilsTestOpArgumentMapping(
       {});
 }
 
+template <typename T, typename Context>
+void InferShapeUtilsTestKernel(
+    const Context& dev_ctx, const phi::DenseTensor& x, bool attr1, int attr2,
+    int64_t attr3, float attr4, const std::string& attr5,
+    const std::vector<bool>& attr6, const std::vector<int>& attr7,
+    const std::vector<int64_t>& attr8, const std::vector<float>& attr9,
+    const std::vector<double>& attr10, const std::vector<std::string>& attr11,
+    phi::DenseTensor* out) {
+  VLOG(6) << "Come into InferShapeUtilsTestKernel";
+}
+
 }  // namespace framework
 }  // namespace paddle
 
@@ -104,6 +118,9 @@ REGISTER_OPERATOR(infer_shape_utils_test,
                   paddle::framework::InferShapeUtilsTestOpMaker,
                   InferShapeUtilsTestInferShapeFunctor);
 
+PT_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT,
+                   paddle::framework::InferShapeUtilsTestKernel, int) {}
+
 TEST(InferShapeUtilsTest, ALL) {
   paddle::framework::ProgramDesc prog;
   paddle::framework::proto::BlockDesc proto_block;
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 5620988545a0f1274176cd888821746bc9e722c7..ddb598f575f6737f7c7d4336eeee866b12c12fb1 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/pten_utils.h"
 
@@ -21,8 +22,11 @@ limitations under the License. */
 #include "paddle/phi/api/lib/utils/tensor_utils.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/scalar_array.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/kernels/reshape_grad_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"
+
 namespace paddle {
 namespace framework {
 class InferShapeContext;
@@ -472,22 +476,6 @@ class Reshape2Op : public ReshapeOp {
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
       : ReshapeOp(type, inputs, outputs, attrs) {}
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
-                      platform::errors::InvalidArgument(
-                          "Output(XShape) of ReshapeOp should not be null."));
-    const auto &x_dims = ctx->GetInputDim("X");
-    std::vector<int64_t> xshape_dims(x_dims.size() + 1);
-    xshape_dims[0] = 0;
-    for (int i = 0; i < x_dims.size(); ++i) {
-      xshape_dims[i + 1] = x_dims[i];
-    }
-    ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
-    ctx->ShareLoD("X", /*->*/ "XShape");
-
-    ReshapeOp::InferShape(ctx);
-  }
 };
 
 class Reshape2OpMaker : public ReshapeOpMaker {
@@ -647,10 +635,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel);
+
+DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
+                            PT_INFER_META(phi::ReshapeWithXShapeInferMeta));
+
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2GradMaker<paddle::imperative::OpBase>,
-                  ops::ReshapeOpInplaceInferer);
+                  ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
                   ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
diff --git a/paddle/phi/api/lib/utils/tensor_utils.cc b/paddle/phi/api/lib/utils/tensor_utils.cc
index 6231922fdbafac55c928dae3669b0c57cb71dd06..fc56d201fe3ccc736fdef834e69426e5f0384bf9 100644
--- a/paddle/phi/api/lib/utils/tensor_utils.cc
+++ b/paddle/phi/api/lib/utils/tensor_utils.cc
@@ -131,7 +131,7 @@ phi::ScalarArray MakePtenScalarArrayFromVarList(
   }
 
   phi::ScalarArray result{vector_data};
-  result.setInitByTensor(true);
+  result.SetFromTensor(true);
   return result;
 }
diff --git a/paddle/phi/common/scalar.h b/paddle/phi/common/scalar.h
index 092e05e95979a0f802882ce4e73233f9ff36c80c..1da77a0fa196413436030fc2864514cc222af6f8 100644
--- a/paddle/phi/common/scalar.h
+++ b/paddle/phi/common/scalar.h
@@ -25,7 +25,7 @@ namespace experimental {
 template <typename T>
 class ScalarBase {
 public:
-  bool IsInitByTensor() const { return is_init_by_tensor_; }
+  bool FromTensor() const { return is_from_tensor_; }
   // Constructor support implicit
   ScalarBase(double val) : dtype_(DataType::FLOAT64) {  // NOLINT
     data_.f64 = val;
@@ -104,7 +104,7 @@ class ScalarBase {
 
   // The Tensor must have one dim
   ScalarBase(const T& tensor) : dtype_(tensor.dtype()) {  // NOLINT
-    is_init_by_tensor_ = true;
+    is_from_tensor_ = true;
     PD_CHECK(
         tensor.numel() == 1,
         "The Scalar only supports Tensor with 1 element, but now Tensor has `",
@@ -196,7 +196,7 @@ class ScalarBase {
   friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
 
 private:
-  bool is_init_by_tensor_{false};
+  bool is_from_tensor_{false};
   DataType dtype_;
   union data {
     bool b;
diff --git a/paddle/phi/common/scalar_array.h b/paddle/phi/common/scalar_array.h
index 522228ba99e0b5273a7b939484c863077950d626..39284095961a727d7d0052a589b543df31bd6ebc 100644
--- a/paddle/phi/common/scalar_array.h
+++ b/paddle/phi/common/scalar_array.h
@@ -43,13 +43,13 @@ class ScalarArrayBase {
     AssignData(date_value, n);
   }
 
-  bool IsInitByTensor() const { return is_init_by_tensor_; }
+  bool FromTensor() const { return is_from_tensor_; }
 
-  void setInitByTensor(bool val) { is_init_by_tensor_ = val; }
+  void SetFromTensor(bool val) { is_from_tensor_ = val; }
 
   // The Tensor must have one dim
   ScalarArrayBase(const T& tensor) {  // NOLINT
-    is_init_by_tensor_ = true;
+    is_from_tensor_ = true;
     size_t n = tensor.numel();
     array_.reserve(n);
     switch (tensor.dtype()) {
@@ -71,7 +71,7 @@ class ScalarArrayBase {
 
   // The Tensor in vec must have only one element
   ScalarArrayBase(const std::vector<T>& tensor_list) {  // NOLINT
-    is_init_by_tensor_ = true;
+    is_from_tensor_ = true;
 
     for (size_t i = 0; i < tensor_list.size(); ++i) {
       DataType data_type = tensor_list[i].dtype();
@@ -117,7 +117,7 @@ class ScalarArrayBase {
   // TODO(zhangyunfei) Replace std::vector with a more efficient container
   // structure.
   std::vector<int64_t> array_;
-  bool is_init_by_tensor_{false};
+  bool is_from_tensor_{false};
 };
 
 using ScalarArray =
diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index a9c064f1b8896b6c782c29a7595a28aba1223784..8c7d096eab0916d984819cfe85810a90cd29e631 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -241,6 +241,10 @@ struct KernelImpl {
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);
 
   /* Output Helpers */
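A note on the PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE lines added above (the element types shown are inferred from the test kernel's signature, since the original text lost its template arguments): each specialization teaches the kernel-call machinery to pop one attribute of that exact type off the context while recursively unpacking a kernel's parameter list. A heavily reduced sketch of the idea, using illustrative names rather than phi's real machinery:

#include <any>
#include <iostream>
#include <string>
#include <vector>

// Attribute storage the "kernel" reads from.
struct KernelContext {
  std::vector<std::any> attrs;
};

// Recursive helper: peel one typed attribute per step, accumulate the
// concrete values, then invoke the kernel. Each supported attribute type
// corresponds (loosely) to one PT_SPECIALIZE_... specialization.
template <typename... Tail>
struct Helper;

template <typename Head, typename... Tail>
struct Helper<Head, Tail...> {
  template <typename Fn, typename... Prev>
  static void Call(const KernelContext& ctx, Fn fn, const Prev&... prev) {
    const Head& value = std::any_cast<const Head&>(ctx.attrs[sizeof...(Prev)]);
    Helper<Tail...>::Call(ctx, fn, prev..., value);
  }
};

template <>
struct Helper<> {
  template <typename Fn, typename... Prev>
  static void Call(const KernelContext&, Fn fn, const Prev&... prev) {
    fn(prev...);
  }
};

void MyKernel(const std::vector<bool>& flags, const std::string& name) {
  std::cout << name << " got " << flags.size() << " flags\n";
}

int main() {
  KernelContext ctx;
  ctx.attrs.emplace_back(std::vector<bool>{true, false});
  ctx.attrs.emplace_back(std::string("demo"));
  Helper<std::vector<bool>, std::string>::Call(ctx, MyKernel);
  return 0;
}

Without a specialization for, say, std::vector<bool>, a kernel taking that attribute type fails to compile; the new lines prevent exactly that for the test kernel registered earlier.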
" "But received 'shape' is empty.")); auto x_dims = x.dims(); @@ -234,8 +234,42 @@ void InferMetaFromVecValue(const MetaTensor& x, void ReshapeInferMeta(const MetaTensor& x, const ScalarArray& shape, - MetaTensor* out) { - InferMetaFromVecValue(x, shape.GetData(), out); + MetaTensor* out, + MetaConfig config) { + auto& shape_data = shape.GetData(); + PADDLE_ENFORCE_NOT_NULL(out, + phi::errors::InvalidArgument( + "Output(Out) of ReshapeOp should not be null.")); + if (!config.is_runtime && shape.FromTensor()) { + out->set_dims(phi::make_ddim(shape_data)); + out->share_lod(x); + return; + } + PADDLE_ENFORCE_GT(shape_data.size(), + 0, + phi::errors::InvalidArgument( + "The shape's size in ReshapeOp can't be zero.")); + InferMetaFromVecValue(x, shape_data, out); +} + +void ReshapeWithXShapeInferMeta(const MetaTensor& x, + const ScalarArray& shape, + MetaTensor* xshape, + MetaTensor* out, + MetaConfig config) { + PADDLE_ENFORCE_NOT_NULL( + xshape, + phi::errors::InvalidArgument( + "Output(XShape) of ReshapeOp should not be null.")); + const auto& x_dims = x.dims(); + std::vector xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < x_dims.size(); ++i) { + xshape_dims[i + 1] = x_dims[i]; + } + xshape->set_dims(phi::make_ddim(xshape_dims)); + xshape->share_lod(x); + ReshapeInferMeta(x, shape, out, config); } /* Why not use ReduceInferMeta directly? diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 560ce0d2d4c489fb4537b426c6ca45a1407a2853..2ab425d42cd33ec49adf704a54e85e6f1714e19c 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -54,7 +54,14 @@ void InferMetaFromVecValue(const MetaTensor& x, void ReshapeInferMeta(const MetaTensor& x, const ScalarArray& shape, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); + +void ReshapeWithXShapeInferMeta(const MetaTensor& x, + const ScalarArray& shape, + MetaTensor* xshape, + MetaTensor* out, + MetaConfig config = MetaConfig()); void ReduceInferMetaBase(const MetaTensor& x, const std::vector& axis, diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc index 4df1e6e1629c02f32534854a9447141d14f46e1f..d02909f007da462089903d0f0764e2cf86231ede 100644 --- a/paddle/phi/kernels/cpu/split_kernel.cc +++ b/paddle/phi/kernels/cpu/split_kernel.cc @@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx, const Scalar& axis_scalar, std::vector outs) { // need to infershape output - if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { + if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) { std::vector out_metas; for (size_t i = 0; i < outs.size(); ++i) { out_metas.push_back(outs[i]); diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu index 9d26b7361ff631e84ce38cfad07a4f79d06d53f5..919b0a7d4f9664c2df4c3f9e0c77200344911de6 100644 --- a/paddle/phi/kernels/gpu/split_kernel.cu +++ b/paddle/phi/kernels/gpu/split_kernel.cu @@ -28,7 +28,7 @@ void SplitKernel(const Context& dev_ctx, const Scalar& axis_scalar, std::vector outs) { // need to infershape output - if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { + if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) { std::vector out_metas; for (size_t i = 0; i < outs.size(); ++i) { out_metas.push_back(outs[i]); diff --git a/paddle/phi/kernels/reshape_kernel.cc b/paddle/phi/kernels/reshape_kernel.cc index 0a6aeb030e28d9482514f45cdf54228daf45b8f8..68d9130850191029c111fcfe42589af5962b60b3 100644 
diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc
index 4df1e6e1629c02f32534854a9447141d14f46e1f..d02909f007da462089903d0f0764e2cf86231ede 100644
--- a/paddle/phi/kernels/cpu/split_kernel.cc
+++ b/paddle/phi/kernels/cpu/split_kernel.cc
@@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
   // need to infershape output
-  if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
     std::vector<MetaTensor> out_metas;
     for (size_t i = 0; i < outs.size(); ++i) {
       out_metas.push_back(outs[i]);
diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu
index 9d26b7361ff631e84ce38cfad07a4f79d06d53f5..919b0a7d4f9664c2df4c3f9e0c77200344911de6 100644
--- a/paddle/phi/kernels/gpu/split_kernel.cu
+++ b/paddle/phi/kernels/gpu/split_kernel.cu
@@ -28,7 +28,7 @@ void SplitKernel(const Context& dev_ctx,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
   // need to infershape output
-  if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
     std::vector<MetaTensor> out_metas;
     for (size_t i = 0; i < outs.size(); ++i) {
       out_metas.push_back(outs[i]);
diff --git a/paddle/phi/kernels/reshape_kernel.cc b/paddle/phi/kernels/reshape_kernel.cc
index 0a6aeb030e28d9482514f45cdf54228daf45b8f8..68d9130850191029c111fcfe42589af5962b60b3 100644
--- a/paddle/phi/kernels/reshape_kernel.cc
+++ b/paddle/phi/kernels/reshape_kernel.cc
@@ -47,7 +47,6 @@ void ReshapeWithXShape(const Context& dev_ctx,
                        const ScalarArray& shape,
                        DenseTensor* xshape,
                        DenseTensor* out) {
-  funcs::SetXShape(x, xshape);
   ReshapeKernel(dev_ctx, x, shape, out);
 }
 
diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc
index 353d364e0ce0b30c1372152fab9b59b5b0d80649..8e8b7592f909adaa3de7e4357b3adf9d812704c2 100644
--- a/paddle/phi/ops/compat/reshape_sig.cc
+++ b/paddle/phi/ops/compat/reshape_sig.cc
@@ -17,13 +17,19 @@ limitations under the License. */
 namespace phi {
 
 KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  if (ctx.InputSize("ShapeTensor") > 0) {
-    return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
-  } else if (ctx.HasInput("Shape")) {
-    return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
-  } else {
-    return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+  if (ctx.HasOutput("XShape")) {
+    if (ctx.InputSize("ShapeTensor") > 0) {
+      return KernelSignature(
+          "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"XShape", "Out"});
+    } else if (ctx.HasInput("Shape")) {
+      return KernelSignature(
+          "reshape_with_xshape", {"X"}, {"Shape"}, {"XShape", "Out"});
+    } else {
+      return KernelSignature(
+          "reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
+    }
   }
+  return KernelSignature("unregistered", {}, {}, {});
 }
 
 KernelSignature ReshapeGradOpArgumentMapping(
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py
index b23c7d9b493d0ce812503d6cdb89292e4cd6d04e..0522df3a9219d58c8a912f8b4554443899c2973a 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py
@@ -91,14 +91,10 @@ class TRTReshapeTest2(TRTReshapeTest):
         with fluid.program_guard(self.main_program, self.startup_program):
             data = fluid.data(
                 name='data', shape=self.data_shape, dtype='float32')
-            actual_reshape = fluid.data(
-                name='actual_reshape', shape=[4], dtype='int32')
-            reshape_out = fluid.layers.reshape(
-                x=data, shape=self.reshape, actual_shape=actual_reshape)
+            reshape_out = fluid.layers.reshape(x=data, shape=self.reshape)
             out = fluid.layers.batch_norm(reshape_out, is_test=True)
         self.feeds = {
-            'data': np.random.random(self.data_shape).astype('float32'),
-            'actual_reshape': np.array([2, 0, -1, 6]).astype('int32')
+            'data': np.random.random(self.data_shape).astype('float32')
         }
         self.enable_trt = True
         self.trt_parameters = TRTReshapeTest.TensorRTParam(
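Finally, the priority order encoded in ReshapeOpArgumentMapping above (a runtime ShapeTensor list beats a single Shape tensor, which beats the static shape attribute) is easy to restate in isolation. A toy model of just that selection logic; Ctx is an illustrative stand-in for phi::ArgumentMappingContext, not the real type:

#include <cstddef>
#include <iostream>
#include <string>

// Toy model of ReshapeOpArgumentMapping's selection order.
struct Ctx {
  std::size_t shape_tensor_size = 0;  // entries in the ShapeTensor list
  bool has_shape_input = false;       // is a single Shape tensor present?
};

std::string PickShapeArg(const Ctx& ctx) {
  if (ctx.shape_tensor_size > 0) return "ShapeTensor";
  if (ctx.has_shape_input) return "Shape";
  return "shape";  // fall back to the static attribute
}

int main() {
  std::cout << PickShapeArg({4, false}) << '\n';  // ShapeTensor
  std::cout << PickShapeArg({0, true}) << '\n';   // Shape
  std::cout << PickShapeArg({0, false}) << '\n';  // shape
  return 0;
}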