From 79b49c2008eb065c01313d1f78215ae999115256 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Tue, 16 Nov 2021 15:19:22 +0800
Subject: [PATCH] Add API and unit test for reshape (#37232)

* reshape kernel refactor

* fix compile bugs when run ci

* support xpu for reshape

* fix bugs when run unittest in kunlun ci

* fix compile bugs when run kunlun

* perfect code according to suggestion

* add api and unit test for reshape
---
 paddle/fluid/framework/operator.cc           | 11 ++-
 paddle/fluid/imperative/prepared_operator.cc | 10 ++-
 paddle/fluid/operators/reshape_op.cc         |  4 +-
 paddle/pten/api/include/manipulation.h       |  1 +
 paddle/pten/api/lib/manipulation.cc          | 34 +++++++++
 paddle/pten/core/kernel_utils.h              |  1 -
 paddle/pten/include/manipulation.h           |  2 +-
 paddle/pten/infermeta/unary.cc               |  4 +-
 paddle/pten/infermeta/unary.h                |  2 +-
 paddle/pten/kernels/cpu/manipulation.cc      | 10 +--
 paddle/pten/kernels/cpu/manipulation.h       |  4 +-
 paddle/pten/kernels/cuda/manipulation.cu     | 10 +--
 paddle/pten/kernels/cuda/manipulation.h      |  4 +-
 paddle/pten/kernels/xpu/manipulation.cc      |  7 +-
 paddle/pten/kernels/xpu/manipulation.h       |  2 +-
 paddle/pten/tests/api/CMakeLists.txt         |  1 +
 paddle/pten/tests/api/test_reshape_api.cc    | 70 +++++++++++++++++++
 paddle/pten/tests/kernels/CMakeLists.txt     |  1 +
 .../tests/kernels/test_reshape_dev_api.cc    | 67 ++++++++++++++++++
 19 files changed, 218 insertions(+), 27 deletions(-)
 create mode 100644 paddle/pten/tests/api/test_reshape_api.cc
 create mode 100644 paddle/pten/tests/kernels/test_reshape_dev_api.cc

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 2eb054be497..8d275f8f1b7 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1884,9 +1884,14 @@ void OperatorWithKernel::BuildPtenKernelContext(
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
         pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else if (attr_defs[i].type_index ==
-                 std::type_index(typeid(std::vector<int64_t>))) {
-        pt_kernel_context_->EmplaceBackAttr(
-            BOOST_GET_CONST(std::vector<int64_t>, attr));
+                 std::type_index(typeid(std::vector<int64_t>)) &&
+                 std::type_index(attr.type()) ==
+                     std::type_index(typeid(std::vector<int>))) {
+        // Emplace Back Attr according to the type of Pten_Kernel args.
+        const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
+        const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
+                                                     vector_int_attr.end());
+        pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` when construct "
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 658272b7c0d..9815983cc10 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -373,8 +373,14 @@ static void BuildDygraphPtenKernelContext(
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
         kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else if (attr_defs[i].type_index ==
-                 std::type_index(typeid(std::vector<int64_t>))) {
-        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int64_t>, attr));
+                 std::type_index(typeid(std::vector<int64_t>)) &&
+                 std::type_index(attr.type()) ==
+                     std::type_index(typeid(std::vector<int>))) {
+        // Emplace Back Attr according to the type of Pten_Kernel args.
+        const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
+        const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
+                                                     vector_int_attr.end());
+        kernel_ctx->EmplaceBackAttr(vector_int64_attr);
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` when construct "
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index e3104fa0650..1a8725bd988 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -459,7 +459,9 @@ class ReshapeKernel {
       }
 #endif
     } else {
-      auto &shape_vec = ctx.Attr<std::vector<int>>("shape");
+      auto &shape_attr = ctx.Attr<std::vector<int>>("shape");
+      const std::vector<int64_t> shape_vec(shape_attr.begin(),
+                                           shape_attr.end());
       if (platform::is_cpu_place(ctx.GetPlace())) {
         auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
         pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
diff --git a/paddle/pten/api/include/manipulation.h b/paddle/pten/api/include/manipulation.h
index 8c3bf5ae94d..e09e113732a 100644
--- a/paddle/pten/api/include/manipulation.h
+++ b/paddle/pten/api/include/manipulation.h
@@ -21,5 +21,6 @@ namespace experimental {
 
 PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis);
 
+PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape);
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/pten/api/lib/manipulation.cc b/paddle/pten/api/lib/manipulation.cc
index 4b80b2010cc..3d9dba0458b 100644
--- a/paddle/pten/api/lib/manipulation.cc
+++ b/paddle/pten/api/lib/manipulation.cc
@@ -59,6 +59,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) {
 
   return out;
 }
+
+PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape) {
+  // 1. Get kernel signature and kernel
+  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
+  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
+  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
+      "reshape2", kernel_key);
+
+  // 2. Get Device Context
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto kernel_context = pten::KernelContext(dev_ctx);
+
+  // 3. Auto data transform
+  auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
+  kernel_context.EmplaceBackInput(dense_x);
+  kernel_context.EmplaceBackAttr(shape);
+
+  // 4. InferShape
+  auto out_meta = InferShapeFromVecValue(dense_x->meta(), shape);
+
+  // 5. Prepare outputs
+  Tensor out;
+  const auto allocator = std::make_shared<DefaultAllocator>(
+      pten::TransToFluidPlace(kernel_key.backend()));
+  auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
+  kernel_context.EmplaceBackOutput(dense_out);
+  out.set_impl(dense_out);
+
+  // 6. Call kernel
+  kernel(&kernel_context);
+
+  return out;
+}
+
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h
index c464519cb97..23143c06244 100644
--- a/paddle/pten/core/kernel_utils.h
+++ b/paddle/pten/core/kernel_utils.h
@@ -208,7 +208,6 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
-  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
 
   /* Output Helpers */
 
diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h
index d779b772b05..4900c78e63a 100644
--- a/paddle/pten/include/manipulation.h
+++ b/paddle/pten/include/manipulation.h
@@ -40,7 +40,7 @@ DenseTensor Flatten(const ContextT& dev_ctx,
 template <typename T, typename ContextT>
 DenseTensor Reshape(const ContextT& dev_ctx,
                     const DenseTensor& x,
-                    const std::vector<int>& shape) {
+                    const std::vector<int64_t>& shape) {
   auto out_meta = InferShapeFromVecValue(x.meta(), shape);
   const auto allocator =
       std::make_shared<paddle::experimental::DefaultAllocator>(
diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc
index e2f9e5fccc6..5099984886c 100644
--- a/paddle/pten/infermeta/unary.cc
+++ b/paddle/pten/infermeta/unary.cc
@@ -83,7 +83,7 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
 }
 
 static paddle::framework::DDim ValidateShape(
-    const std::vector<int> shape, const paddle::framework::DDim& in_dims) {
+    const std::vector<int64_t> shape, const paddle::framework::DDim& in_dims) {
   const int64_t in_size = paddle::framework::product(in_dims);
   auto in_dims_vec = paddle::framework::vectorize(in_dims);
   bool all_positive = std::all_of(in_dims_vec.cbegin(),
@@ -203,7 +203,7 @@ static paddle::framework::DDim ValidateShape(
 }
 
 DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
-                                       const std::vector<int>& shape) {
+                                       const std::vector<int64_t>& shape) {
   PADDLE_ENFORCE_EQ(!shape.empty(),
                     true,
                     paddle::platform::errors::InvalidArgument(
diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h
index cf88f0060e8..4e22c9bf2d8 100644
--- a/paddle/pten/infermeta/unary.h
+++ b/paddle/pten/infermeta/unary.h
@@ -46,5 +46,5 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
                                    DataLayout layout);
 
 DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
-                                       const std::vector<int>& shape);
+                                       const std::vector<int64_t>& shape);
 }  // namespace pten
diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc
index fcff674f36b..99699ea91ee 100644
--- a/paddle/pten/kernels/cpu/manipulation.cc
+++ b/paddle/pten/kernels/cpu/manipulation.cc
@@ -46,7 +46,7 @@ void FlattenWithXShape(const CPUContext& dev_ctx,
 
 void ReshapeFromVectorVal(const CPUContext& dev_ctx,
                           const DenseTensor& x,
-                          const std::vector<int>& shape,
+                          const std::vector<int64_t>& shape,
                           DenseTensor* out) {
   auto out_meta = InferShapeFromVecValue(x.meta(), shape);
   if (&x == out) {
@@ -59,7 +59,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx,
 
 void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
                                     const DenseTensor& x,
-                                    const std::vector<int>& shape,
+                                    const std::vector<int64_t>& shape,
                                     DenseTensor* xshape,
                                     DenseTensor* out) {
   ReshapeFromVectorVal(dev_ctx, x, shape, out);
@@ -71,8 +71,10 @@ void ReshapeFromDT(const CPUContext& dev_ctx,
                    const DenseTensor& shape,
                    DenseTensor* out) {
   auto* shape_data = shape.data<int>();
-  auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
+  auto vector_shape =
+      std::vector<int64_t>(shape_data, shape_data + shape.numel());
   ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
+  out->set_lod(x.lod());
 }
 
 void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
@@ -88,7 +90,7 @@ void ReshapeFromVectorDT(const CPUContext& dev_ctx,
                          const DenseTensor& x,
                          const std::vector<DenseTensor>& shape,
                          DenseTensor* out) {
-  std::vector<int> vector_shape;
+  std::vector<int64_t> vector_shape;
   for (auto& tensor : shape) {
     PADDLE_ENFORCE_EQ(
         tensor.dims(),
diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h
index c0747749451..435139e1fdf 100644
--- a/paddle/pten/kernels/cpu/manipulation.h
+++ b/paddle/pten/kernels/cpu/manipulation.h
@@ -36,7 +36,7 @@ void ReshapeFromDT(const CPUContext& dev_ctx,
 
 void ReshapeFromVectorVal(const CPUContext& dev_ctx,
                           const DenseTensor& x,
-                          const std::vector<int>& shape,
+                          const std::vector<int64_t>& shape,
                           DenseTensor* out);
 
 void ReshapeFromVectorDT(const CPUContext& dev_ctx,
@@ -52,7 +52,7 @@ void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
 
 void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
                                     const DenseTensor& x,
-                                    const std::vector<int>& shape,
+                                    const std::vector<int64_t>& shape,
                                     DenseTensor* xshape,
                                     DenseTensor* out);
 
diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu
index 47451226c76..b84694c0a9f 100644
--- a/paddle/pten/kernels/cuda/manipulation.cu
+++ b/paddle/pten/kernels/cuda/manipulation.cu
@@ -46,7 +46,7 @@ void FlattenWithXShape(const CUDAContext& dev_ctx,
 
 void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
                           const DenseTensor& x,
-                          const std::vector<int>& shape,
+                          const std::vector<int64_t>& shape,
                           DenseTensor* out) {
   auto out_meta = InferShapeFromVecValue(x.meta(), shape);
   if (&x == out) {
@@ -60,7 +60,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
 
 void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
                                     const DenseTensor& x,
-                                    const std::vector<int>& shape,
+                                    const std::vector<int64_t>& shape,
                                     DenseTensor* xshape,
                                     DenseTensor* out) {
   ReshapeFromVectorVal(dev_ctx, x, shape, out);
@@ -72,8 +72,10 @@ void ReshapeFromDT(const CUDAContext& dev_ctx,
                    const DenseTensor& shape,
                    DenseTensor* out) {
   auto* shape_data = shape.data<int>();
-  auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
+  auto vector_shape =
+      std::vector<int64_t>(shape_data, shape_data + shape.numel());
   ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
+  out->set_lod(x.lod());
 }
 
 void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
@@ -89,7 +91,7 @@ void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
                          const DenseTensor& x,
                          const std::vector<DenseTensor>& shape,
                          DenseTensor* out) {
-  std::vector<int> vector_shape;
+  std::vector<int64_t> vector_shape;
   for (auto& tensor : shape) {
     PADDLE_ENFORCE_EQ(
         tensor.dims(),
diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h
index 6a071d6e49d..40be7670baa 100644
--- a/paddle/pten/kernels/cuda/manipulation.h
+++ b/paddle/pten/kernels/cuda/manipulation.h
@@ -40,7 +40,7 @@ void ReshapeFromDT(const CUDAContext& dev_ctx,
 
 void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
                           const DenseTensor& x,
-                          const std::vector<int>& shape,
+                          const std::vector<int64_t>& shape,
                           DenseTensor* out);
 
 void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
@@ -56,7 +56,7 @@ void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
 
 void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
                                     const DenseTensor& x,
-                                    const std::vector<int>& shape,
+                                    const std::vector<int64_t>& shape,
                                     DenseTensor* xshape,
                                     DenseTensor* out);
 
diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc
index 6ce143e5e39..f19ebf35a02 100644
--- a/paddle/pten/kernels/xpu/manipulation.cc
+++ b/paddle/pten/kernels/xpu/manipulation.cc
@@ -53,7 +53,7 @@ void FlattenWithXShape(const XPUContext& dev_ctx,
 
 void ReshapeFromVectorVal(const XPUContext& dev_ctx,
                           const DenseTensor& x,
-                          const std::vector<int>& shape,
+                          const std::vector<int64_t>& shape,
                           DenseTensor* out) {
   auto out_meta = InferShapeFromVecValue(x.meta(), shape);
   if (&x == out) {
@@ -69,7 +69,8 @@ void ReshapeFromDT(const XPUContext& dev_ctx,
                    const DenseTensor& shape,
                    DenseTensor* out) {
   auto* shape_data = shape.data<int>();
-  auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel());
+  auto vector_shape =
+      std::vector<int64_t>(shape_data, shape_data + shape.numel());
   ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
 }
 
@@ -77,7 +78,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx,
                          const DenseTensor& x,
                          const std::vector<DenseTensor>& shape,
                          DenseTensor* out) {
-  std::vector<int> vector_shape;
+  std::vector<int64_t> vector_shape;
   for (auto& tensor : shape) {
     PADDLE_ENFORCE_EQ(
         tensor.dims(),
diff --git a/paddle/pten/kernels/xpu/manipulation.h b/paddle/pten/kernels/xpu/manipulation.h
index 61a9536f8cc..b519a23a500 100644
--- a/paddle/pten/kernels/xpu/manipulation.h
+++ b/paddle/pten/kernels/xpu/manipulation.h
@@ -40,7 +40,7 @@ void ReshapeFromDT(const XPUContext& dev_ctx,
 
 void ReshapeFromVectorVal(const XPUContext& dev_ctx,
                           const DenseTensor& x,
-                          const std::vector<int>& shape,
+                          const std::vector<int64_t>& shape,
                           DenseTensor* out);
 
 void ReshapeFromVectorDT(const XPUContext& dev_ctx,
diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt
index 0e6349ef085..5bc5f0ace88 100644
--- a/paddle/pten/tests/api/CMakeLists.txt
+++ b/paddle/pten/tests/api/CMakeLists.txt
@@ -14,3 +14,4 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_a
 cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils)
 cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils)
 cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils)
+cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils)
diff --git a/paddle/pten/tests/api/test_reshape_api.cc b/paddle/pten/tests/api/test_reshape_api.cc
new file mode 100644
index 00000000000..b1dd4c827ff
--- /dev/null
+++ b/paddle/pten/tests/api/test_reshape_api.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/pten/api/include/manipulation.h"
+
+#include "paddle/pten/api/lib/utils/allocator.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+PT_DECLARE_MODULE(ManipulationCPU);
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PT_DECLARE_MODULE(ManipulationCUDA);
+#endif
+
+namespace framework = paddle::framework;
+using DDim = paddle::framework::DDim;
+
+// TODO(chenweihang): Remove this test after the API is used in the dygraph
+TEST(API, reshape) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  auto dense_x = std::make_shared<pten::DenseTensor>(
+      alloc,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 2, 2, 3}),
+                            pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x->mutable_data<float>();
+
+  for (int i = 0; i < dense_x->numel(); i++) {
+    dense_x_data[i] = i;
+  }
+
+  paddle::experimental::Tensor x(dense_x);
+  std::vector<int64_t> shape{12, 3};
+  // 2. test API
+  auto out = paddle::experimental::reshape(x, shape);
+  // 3. check result
+  std::vector<int64_t> expect_shape = {12, 3};
+  ASSERT_EQ(out.shape()[0], expect_shape[0]);
+  ASSERT_EQ(out.shape()[1], expect_shape[1]);
+  ASSERT_EQ(out.numel(), 36);
+  ASSERT_EQ(out.is_cpu(), true);
+  ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
+  ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
+  ASSERT_EQ(out.initialized(), true);
+  bool value_equal = true;
+  auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
+  auto* dense_out_data = dense_out->data<float>();
+  for (int i = 0; i < dense_x->numel(); i++) {
+    if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f)
+      value_equal = false;
+  }
+  ASSERT_EQ(value_equal, true);
+}
diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt
index aa9050e3b4e..8a66fd18609 100644
--- a/paddle/pten/tests/kernels/CMakeLists.txt
+++ b/paddle/pten/tests/kernels/CMakeLists.txt
@@ -5,3 +5,4 @@ cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_uti
 cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils)
 cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils)
 cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils)
+cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils)
diff --git a/paddle/pten/tests/kernels/test_reshape_dev_api.cc b/paddle/pten/tests/kernels/test_reshape_dev_api.cc
new file mode 100644
index 00000000000..c06cc8a8a40
--- /dev/null
+++ b/paddle/pten/tests/kernels/test_reshape_dev_api.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/pten/include/manipulation.h"
+
+#include "paddle/pten/api/lib/utils/allocator.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+namespace framework = paddle::framework;
+using DDim = paddle::framework::DDim;
+
+// TODO(chenweihang): Remove this test after the API is used in the dygraph
+TEST(DEV_API, reshape) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(
+      alloc,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 2, 2, 3}),
+                            pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x.mutable_data<float>();
+
+  for (int i = 0; i < dense_x.numel(); i++) {
+    dense_x_data[i] = i;
+  }
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+  std::vector<int64_t> shape{12, 3};
+
+  // 2. test API
+  auto out = pten::Reshape<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
+      dense_x,
+      shape);
+  // 3. check result
+  std::vector<int64_t> expect_shape = {12, 3};
+  ASSERT_EQ(out.dims()[0], expect_shape[0]);
+  ASSERT_EQ(out.dims()[1], expect_shape[1]);
+  ASSERT_EQ(out.numel(), 36);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  bool value_equal = true;
+  auto* dense_out_data = out.data<float>();
+  for (int i = 0; i < dense_x.numel(); i++) {
+    if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f)
+      value_equal = false;
+  }
+  ASSERT_EQ(value_equal, true);
+}
--
GitLab