diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2eb054be49718c1e23e8758b1ad2d987eca5e6f1..8d275f8f1b74157880d8bea7b30feccf2736ab75 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1884,9 +1884,14 @@ void OperatorWithKernel::BuildPtenKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector<int64_t>))) { - pt_kernel_context_->EmplaceBackAttr( - BOOST_GET_CONST(std::vector<int64_t>, attr)); + std::type_index(typeid(std::vector<int64_t>)) && + std::type_index(attr.type()) == + std::type_index(typeid(std::vector<int>))) { + // Emplace Back Attr according to the type of Pten_Kernel args. + const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); + const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), + vector_int_attr.end()); + pt_kernel_context_->EmplaceBackAttr(vector_int64_attr); } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` when construct " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 658272b7c0d1f1b42d7c06e68374a4877d2f4389..9815983cc10836431b8315a27b6fbbd4caacc82a 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -373,8 +373,14 @@ static void BuildDygraphPtenKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector<int64_t>))) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int64_t>, attr)); + std::type_index(typeid(std::vector<int64_t>)) && + std::type_index(attr.type()) == + std::type_index(typeid(std::vector<int>))) { + // Emplace Back Attr according to the type of Pten_Kernel args. 
+ const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); + const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), + vector_int_attr.end()); + kernel_ctx->EmplaceBackAttr(vector_int64_attr); } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` when construct " diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index e3104fa06507eb5e4effd5d24d7a2459786d88a8..1a8725bd9886f83747e976121d71c62ac2cf9ec1 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -459,7 +459,9 @@ class ReshapeKernel { } #endif } else { - auto &shape_vec = ctx.Attr<std::vector<int>>("shape"); + auto &shape_attr = ctx.Attr<std::vector<int>>("shape"); + const std::vector<int64_t> shape_vec(shape_attr.begin(), + shape_attr.end()); if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); diff --git a/paddle/pten/api/include/manipulation.h b/paddle/pten/api/include/manipulation.h index 8c3bf5ae94d95b8c5cfd7711533990bf1992a876..e09e113732a6d8ca5b26e8073403b6366316b440 100644 --- a/paddle/pten/api/include/manipulation.h +++ b/paddle/pten/api/include/manipulation.h @@ -21,5 +21,6 @@ namespace experimental { PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis); +PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape); } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/manipulation.cc b/paddle/pten/api/lib/manipulation.cc index 4b80b2010ccc07f5f0c333638aa1702cb683d064..3d9dba0458bd67b96b7fdd3f322168bafe2d3071 100644 --- a/paddle/pten/api/lib/manipulation.cc +++ b/paddle/pten/api/lib/manipulation.cc @@ -59,6 +59,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { return out; } + +PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape) { + // 1. 
Get kernel signature and kernel + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + "reshape2", kernel_key); + + // 2. Get Device Context + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto kernel_context = pten::KernelContext(dev_ctx); + + // 3. Auto data transform + auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl()); + kernel_context.EmplaceBackInput(dense_x); + kernel_context.EmplaceBackAttr(shape); + + // 4. InferShape + auto out_meta = InferShapeFromVecValue(dense_x->meta(), shape); + + // 5. Prepare outputs + Tensor out; + const auto allocator = std::make_shared<paddle::experimental::DefaultAllocator>( + pten::TransToFluidPlace(kernel_key.backend())); + auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta); + kernel_context.EmplaceBackOutput(dense_out); + out.set_impl(dense_out); + + // 6. Call kernel + kernel(&kernel_context); + + return out; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index c464519cb9773fe70104fc80c3026d420bcf440c..23143c06244ca93b39dc99e7f1eaaae234808255 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -208,7 +208,6 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); - PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); /* Output Helpers */ diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h index d779b772b053a67a379a1faee96557fb247cfa14..4900c78e63adafd4ed38e4e5cdb54f4efe69dee6 100644 --- a/paddle/pten/include/manipulation.h +++ b/paddle/pten/include/manipulation.h @@ -40,7 +40,7 @@ DenseTensor Flatten(const ContextT& dev_ctx, template <typename T, typename ContextT> DenseTensor Reshape(const ContextT& 
dev_ctx, const DenseTensor& x, - const std::vector<int>& shape) { + const std::vector<int64_t>& shape) { auto out_meta = InferShapeFromVecValue(x.meta(), shape); const auto allocator = std::make_shared<paddle::experimental::DefaultAllocator>( diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index e2f9e5fccc6d1e4f28a05e463e8b6c0e787ed07f..5099984886cce5d31b4a26a445f6cc2b01660b61 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -83,7 +83,7 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, } static paddle::framework::DDim ValidateShape( - const std::vector<int> shape, const paddle::framework::DDim& in_dims) { + const std::vector<int64_t> shape, const paddle::framework::DDim& in_dims) { const int64_t in_size = paddle::framework::product(in_dims); auto in_dims_vec = paddle::framework::vectorize(in_dims); bool all_positive = std::all_of(in_dims_vec.cbegin(), @@ -203,7 +203,7 @@ static paddle::framework::DDim ValidateShape( } DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, - const std::vector<int>& shape) { + const std::vector<int64_t>& shape) { PADDLE_ENFORCE_EQ(!shape.empty(), true, paddle::platform::errors::InvalidArgument( diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index cf88f0060e8a317caaaea9d13b15f474a53b2e75..4e22c9bf2d808e8b0de4a70e89743df5e33cf1fc 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -46,5 +46,5 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, DataLayout layout); DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, - const std::vector<int>& shape); + const std::vector<int64_t>& shape); } // namespace pten diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc index fcff674f36b9ca9d2c1649197a740e64d47f9bbd..99699ea91ee2038af46e3f930d66ff4613257cfb 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -46,7 +46,7 @@ void FlattenWithXShape(const CPUContext& 
dev_ctx, void ReshapeFromVectorVal(const CPUContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* out) { auto out_meta = InferShapeFromVecValue(x.meta(), shape); if (&x == out) { @@ -59,7 +59,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* xshape, DenseTensor* out) { ReshapeFromVectorVal(dev_ctx, x, shape, out); @@ -71,8 +71,10 @@ void ReshapeFromDT(const CPUContext& dev_ctx, const DenseTensor& shape, DenseTensor* out) { auto* shape_data = shape.data<int>(); - auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel()); + auto vector_shape = + std::vector<int64_t>(shape_data, shape_data + shape.numel()); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); + out->set_lod(x.lod()); } void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, @@ -88,7 +90,7 @@ void ReshapeFromVectorDT(const CPUContext& dev_ctx, const DenseTensor& x, const std::vector<DenseTensor>& shape, DenseTensor* out) { - std::vector<int> vector_shape; + std::vector<int64_t> vector_shape; for (auto& tensor : shape) { PADDLE_ENFORCE_EQ( tensor.dims(), diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h index c0747749451d0731fce560e6ac0564e96e1e8e57..435139e1fdfaf0214e3f541070b9edde6a8898b6 100644 --- a/paddle/pten/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -36,7 +36,7 @@ void ReshapeFromDT(const CPUContext& dev_ctx, void ReshapeFromVectorVal(const CPUContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* out); void ReshapeFromVectorDT(const CPUContext& dev_ctx, @@ -52,7 +52,7 @@ void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* 
xshape, DenseTensor* out); diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index 47451226c760b05853002fa2475c6d366b059a0d..b84694c0a9f813b643da60ec36e769f92df71663 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -46,7 +46,7 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, void ReshapeFromVectorVal(const CUDAContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* out) { auto out_meta = InferShapeFromVecValue(x.meta(), shape); if (&x == out) { @@ -60,7 +60,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* xshape, DenseTensor* out) { ReshapeFromVectorVal(dev_ctx, x, shape, out); @@ -72,8 +72,10 @@ void ReshapeFromDT(const CUDAContext& dev_ctx, const DenseTensor& shape, DenseTensor* out) { auto* shape_data = shape.data<int>(); - auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel()); + auto vector_shape = + std::vector<int64_t>(shape_data, shape_data + shape.numel()); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); + out->set_lod(x.lod()); } void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, @@ -89,7 +91,7 @@ void ReshapeFromVectorDT(const CUDAContext& dev_ctx, const DenseTensor& x, const std::vector<DenseTensor>& shape, DenseTensor* out) { - std::vector<int> vector_shape; + std::vector<int64_t> vector_shape; for (auto& tensor : shape) { PADDLE_ENFORCE_EQ( tensor.dims(), diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h index 6a071d6e49d6695230f8e9a5f8a23c6e74c661fd..40be7670baa1f0a5474fbcc16493b7a833e27909 100644 --- a/paddle/pten/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -40,7 +40,7 @@ void ReshapeFromDT(const CUDAContext& dev_ctx, void ReshapeFromVectorVal(const CUDAContext& 
dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* out); void ReshapeFromVectorDT(const CUDAContext& dev_ctx, @@ -56,7 +56,7 @@ void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* xshape, DenseTensor* out); diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index 6ce143e5e3950acb934cc7d150f6bd20b8425ebc..f19ebf35a0254a434b94b61c5dda3ec4cd18c85d 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -53,7 +53,7 @@ void FlattenWithXShape(const XPUContext& dev_ctx, void ReshapeFromVectorVal(const XPUContext& dev_ctx, const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* out) { auto out_meta = InferShapeFromVecValue(x.meta(), shape); if (&x == out) { @@ -69,7 +69,8 @@ void ReshapeFromDT(const XPUContext& dev_ctx, const DenseTensor& shape, DenseTensor* out) { auto* shape_data = shape.data<int>(); - auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel()); + auto vector_shape = + std::vector<int64_t>(shape_data, shape_data + shape.numel()); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); } @@ -77,7 +78,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx, const DenseTensor& x, const std::vector<DenseTensor>& shape, DenseTensor* out) { - std::vector<int> vector_shape; + std::vector<int64_t> vector_shape; for (auto& tensor : shape) { PADDLE_ENFORCE_EQ( tensor.dims(), diff --git a/paddle/pten/kernels/xpu/manipulation.h b/paddle/pten/kernels/xpu/manipulation.h index 61a9536f8cc5845ccbc85b9b948451734d147824..b519a23a50038ec1d5405b2b0664dc51ee0a9b02 100644 --- a/paddle/pten/kernels/xpu/manipulation.h +++ b/paddle/pten/kernels/xpu/manipulation.h @@ -40,7 +40,7 @@ void ReshapeFromDT(const XPUContext& dev_ctx, void ReshapeFromVectorVal(const XPUContext& dev_ctx, 
const DenseTensor& x, - const std::vector<int>& shape, + const std::vector<int64_t>& shape, DenseTensor* out); void ReshapeFromVectorDT(const XPUContext& dev_ctx, diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index 0e6349ef085fa1d3248ba2f82888e1d0797f47d2..5bc5f0ace8804afcf3d07e9c76ab3394922cd7c5 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -14,3 +14,4 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_a cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils) +cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_reshape_api.cc b/paddle/pten/tests/api/test_reshape_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1dd4c827ff7709172562142063ed1b8e27eeea0 --- /dev/null +++ b/paddle/pten/tests/api/test_reshape_api.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include <gtest/gtest.h> +#include <memory> + +#include "paddle/pten/api/include/manipulation.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +PT_DECLARE_MODULE(ManipulationCPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_DECLARE_MODULE(ManipulationCUDA); +#endif + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +// TODO(chenweihang): Remove this test after the API is used in the dygraph +TEST(API, reshape) { + // 1. create tensor + const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared<pten::DenseTensor>( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 2, 2, 3}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x->mutable_data<float>(); + + for (int i = 0; i < dense_x->numel(); i++) { + dense_x_data[i] = i; + } + + paddle::experimental::Tensor x(dense_x); + std::vector<int64_t> shape{12, 3}; + // 2. test API + auto out = paddle::experimental::reshape(x, shape); + // 3. 
check result + std::vector<int64_t> expect_shape = {12, 3}; + ASSERT_EQ(out.shape()[0], expect_shape[0]); + ASSERT_EQ(out.shape()[1], expect_shape[1]); + ASSERT_EQ(out.numel(), 36); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + bool value_equal = true; + auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl()); + auto* dense_out_data = dense_out->data<float>(); + for (int i = 0; i < dense_x->numel(); i++) { + if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f) + value_equal = false; + } + ASSERT_EQ(value_equal, true); +} diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index aa9050e3b4ec4e4d8a08d9c10e7a3f1fc0633871..8a66fd186091a1b45482d2fc185f59776407ddd9 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -5,3 +5,4 @@ cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_uti cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils) cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils) cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_reshape_dev_api.cc b/paddle/pten/tests/kernels/test_reshape_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..c06cc8a8a406bd128a1c4a2e4b2470db50b94ad3 --- /dev/null +++ b/paddle/pten/tests/kernels/test_reshape_dev_api.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <gtest/gtest.h> +#include <memory> + +#include "paddle/pten/include/manipulation.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +// TODO(chenweihang): Remove this test after the API is used in the dygraph +TEST(DEV_API, reshape) { + // 1. create tensor + const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 2, 2, 3}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data<float>(); + + for (int i = 0; i < dense_x.numel(); i++) { + dense_x_data[i] = i; + } + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + std::vector<int64_t> shape{12, 3}; + + // 2. test API + auto out = pten::Reshape<float>( + *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), + dense_x, + shape); + // 3. check result + std::vector<int64_t> expect_shape = {12, 3}; + ASSERT_EQ(out.dims()[0], expect_shape[0]); + ASSERT_EQ(out.dims()[1], expect_shape[1]); + ASSERT_EQ(out.numel(), 36); + ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); + + bool value_equal = true; + auto* dense_out_data = out.data<float>(); + for (int i = 0; i < dense_x.numel(); i++) { + if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f) + value_equal = false; + } + ASSERT_EQ(value_equal, true); +}