未验证 提交 79b49c20 编写于 作者: Y YuanRisheng 提交者: GitHub

Add API and unit test for reshape (#37232)

* reshape kernel refactor

* fix compile bugs when run ci

* support xpu for reshape

* fix bugs when run unittest in kunlun ci

* fix compile bugs when run kunlun

* perfect code according to suggestion

* add api and unit test for reshape
上级 6ebc318e
...@@ -1884,9 +1884,14 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1884,9 +1884,14 @@ void OperatorWithKernel::BuildPtenKernelContext(
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>)) &&
std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) { std::type_index(typeid(std::vector<int>))) {
pt_kernel_context_->EmplaceBackAttr( // Emplace Back Attr according to the type of Pten_Kernel args.
BOOST_GET_CONST(std::vector<int>, attr)); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct " "unsupported cast op attribute `%s` when construct "
......
...@@ -373,8 +373,14 @@ static void BuildDygraphPtenKernelContext( ...@@ -373,8 +373,14 @@ static void BuildDygraphPtenKernelContext(
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>)) &&
std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) { std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr)); // Emplace Back Attr according to the type of Pten_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
kernel_ctx->EmplaceBackAttr(vector_int64_attr);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct " "unsupported cast op attribute `%s` when construct "
......
...@@ -459,7 +459,9 @@ class ReshapeKernel { ...@@ -459,7 +459,9 @@ class ReshapeKernel {
} }
#endif #endif
} else { } else {
auto &shape_vec = ctx.Attr<std::vector<int>>("shape"); auto &shape_attr = ctx.Attr<std::vector<int>>("shape");
const std::vector<int64_t> shape_vec(shape_attr.begin(),
shape_attr.end());
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
......
...@@ -21,5 +21,6 @@ namespace experimental { ...@@ -21,5 +21,6 @@ namespace experimental {
PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis); PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis);
PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape);
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
...@@ -59,6 +59,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { ...@@ -59,6 +59,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) {
return out; return out;
} }
PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"reshape2", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
kernel_context.EmplaceBackAttr(shape);
// 4. InferShape
auto out_meta = InferShapeFromVecValue(dense_x->meta(), shape);
// 5. Prepare outputs
Tensor out;
const auto allocator = std::make_shared<DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
......
...@@ -208,7 +208,6 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> { ...@@ -208,7 +208,6 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
/* Output Helpers */ /* Output Helpers */
......
...@@ -40,7 +40,7 @@ DenseTensor Flatten(const ContextT& dev_ctx, ...@@ -40,7 +40,7 @@ DenseTensor Flatten(const ContextT& dev_ctx,
template <typename T, typename ContextT> template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx, DenseTensor Reshape(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape) { const std::vector<int64_t>& shape) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferShapeFromVecValue(x.meta(), shape);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
......
...@@ -83,7 +83,7 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, ...@@ -83,7 +83,7 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
} }
static paddle::framework::DDim ValidateShape( static paddle::framework::DDim ValidateShape(
const std::vector<int> shape, const paddle::framework::DDim& in_dims) { const std::vector<int64_t> shape, const paddle::framework::DDim& in_dims) {
const int64_t in_size = paddle::framework::product(in_dims); const int64_t in_size = paddle::framework::product(in_dims);
auto in_dims_vec = paddle::framework::vectorize(in_dims); auto in_dims_vec = paddle::framework::vectorize(in_dims);
bool all_positive = std::all_of(in_dims_vec.cbegin(), bool all_positive = std::all_of(in_dims_vec.cbegin(),
...@@ -203,7 +203,7 @@ static paddle::framework::DDim ValidateShape( ...@@ -203,7 +203,7 @@ static paddle::framework::DDim ValidateShape(
} }
DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int>& shape) { const std::vector<int64_t>& shape) {
PADDLE_ENFORCE_EQ(!shape.empty(), PADDLE_ENFORCE_EQ(!shape.empty(),
true, true,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
......
...@@ -46,5 +46,5 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, ...@@ -46,5 +46,5 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
DataLayout layout); DataLayout layout);
DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int>& shape); const std::vector<int64_t>& shape);
} // namespace pten } // namespace pten
...@@ -46,7 +46,7 @@ void FlattenWithXShape(const CPUContext& dev_ctx, ...@@ -46,7 +46,7 @@ void FlattenWithXShape(const CPUContext& dev_ctx,
void ReshapeFromVectorVal(const CPUContext& dev_ctx, void ReshapeFromVectorVal(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) { if (&x == out) {
...@@ -59,7 +59,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx, ...@@ -59,7 +59,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx,
void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out) { DenseTensor* out) {
ReshapeFromVectorVal(dev_ctx, x, shape, out); ReshapeFromVectorVal(dev_ctx, x, shape, out);
...@@ -71,8 +71,10 @@ void ReshapeFromDT(const CPUContext& dev_ctx, ...@@ -71,8 +71,10 @@ void ReshapeFromDT(const CPUContext& dev_ctx,
const DenseTensor& shape, const DenseTensor& shape,
DenseTensor* out) { DenseTensor* out) {
auto* shape_data = shape.data<int>(); auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel()); auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
out->set_lod(x.lod());
} }
void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
...@@ -88,7 +90,7 @@ void ReshapeFromVectorDT(const CPUContext& dev_ctx, ...@@ -88,7 +90,7 @@ void ReshapeFromVectorDT(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<DenseTensor>& shape, const std::vector<DenseTensor>& shape,
DenseTensor* out) { DenseTensor* out) {
std::vector<int> vector_shape; std::vector<int64_t> vector_shape;
for (auto& tensor : shape) { for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor.dims(), tensor.dims(),
......
...@@ -36,7 +36,7 @@ void ReshapeFromDT(const CPUContext& dev_ctx, ...@@ -36,7 +36,7 @@ void ReshapeFromDT(const CPUContext& dev_ctx,
void ReshapeFromVectorVal(const CPUContext& dev_ctx, void ReshapeFromVectorVal(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* out); DenseTensor* out);
void ReshapeFromVectorDT(const CPUContext& dev_ctx, void ReshapeFromVectorDT(const CPUContext& dev_ctx,
...@@ -52,7 +52,7 @@ void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, ...@@ -52,7 +52,7 @@ void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out); DenseTensor* out);
......
...@@ -46,7 +46,7 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, ...@@ -46,7 +46,7 @@ void FlattenWithXShape(const CUDAContext& dev_ctx,
void ReshapeFromVectorVal(const CUDAContext& dev_ctx, void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) { if (&x == out) {
...@@ -60,7 +60,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx, ...@@ -60,7 +60,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out) { DenseTensor* out) {
ReshapeFromVectorVal(dev_ctx, x, shape, out); ReshapeFromVectorVal(dev_ctx, x, shape, out);
...@@ -72,8 +72,10 @@ void ReshapeFromDT(const CUDAContext& dev_ctx, ...@@ -72,8 +72,10 @@ void ReshapeFromDT(const CUDAContext& dev_ctx,
const DenseTensor& shape, const DenseTensor& shape,
DenseTensor* out) { DenseTensor* out) {
auto* shape_data = shape.data<int>(); auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel()); auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
out->set_lod(x.lod());
} }
void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
...@@ -89,7 +91,7 @@ void ReshapeFromVectorDT(const CUDAContext& dev_ctx, ...@@ -89,7 +91,7 @@ void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<DenseTensor>& shape, const std::vector<DenseTensor>& shape,
DenseTensor* out) { DenseTensor* out) {
std::vector<int> vector_shape; std::vector<int64_t> vector_shape;
for (auto& tensor : shape) { for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor.dims(), tensor.dims(),
......
...@@ -40,7 +40,7 @@ void ReshapeFromDT(const CUDAContext& dev_ctx, ...@@ -40,7 +40,7 @@ void ReshapeFromDT(const CUDAContext& dev_ctx,
void ReshapeFromVectorVal(const CUDAContext& dev_ctx, void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* out); DenseTensor* out);
void ReshapeFromVectorDT(const CUDAContext& dev_ctx, void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
...@@ -56,7 +56,7 @@ void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, ...@@ -56,7 +56,7 @@ void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out); DenseTensor* out);
......
...@@ -53,7 +53,7 @@ void FlattenWithXShape(const XPUContext& dev_ctx, ...@@ -53,7 +53,7 @@ void FlattenWithXShape(const XPUContext& dev_ctx,
void ReshapeFromVectorVal(const XPUContext& dev_ctx, void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferShapeFromVecValue(x.meta(), shape);
if (&x == out) { if (&x == out) {
...@@ -69,7 +69,8 @@ void ReshapeFromDT(const XPUContext& dev_ctx, ...@@ -69,7 +69,8 @@ void ReshapeFromDT(const XPUContext& dev_ctx,
const DenseTensor& shape, const DenseTensor& shape,
DenseTensor* out) { DenseTensor* out) {
auto* shape_data = shape.data<int>(); auto* shape_data = shape.data<int>();
auto vector_shape = std::vector<int>(shape_data, shape_data + shape.numel()); auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
} }
...@@ -77,7 +78,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx, ...@@ -77,7 +78,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<DenseTensor>& shape, const std::vector<DenseTensor>& shape,
DenseTensor* out) { DenseTensor* out) {
std::vector<int> vector_shape; std::vector<int64_t> vector_shape;
for (auto& tensor : shape) { for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor.dims(), tensor.dims(),
......
...@@ -40,7 +40,7 @@ void ReshapeFromDT(const XPUContext& dev_ctx, ...@@ -40,7 +40,7 @@ void ReshapeFromDT(const XPUContext& dev_ctx,
void ReshapeFromVectorVal(const XPUContext& dev_ctx, void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& shape, const std::vector<int64_t>& shape,
DenseTensor* out); DenseTensor* out);
void ReshapeFromVectorDT(const XPUContext& dev_ctx, void ReshapeFromVectorDT(const XPUContext& dev_ctx,
......
...@@ -14,3 +14,4 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_a ...@@ -14,3 +14,4 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_a
cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/manipulation.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(API, reshape) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 2, 2, 3}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>();
for (int i = 0; i < dense_x->numel(); i++) {
dense_x_data[i] = i;
}
paddle::experimental::Tensor x(dense_x);
std::vector<int64_t> shape{12, 3};
// 2. test API
auto out = paddle::experimental::reshape(x, shape);
// 3. check result
std::vector<int64_t> expect_shape = {12, 3};
ASSERT_EQ(out.shape()[0], expect_shape[0]);
ASSERT_EQ(out.shape()[1], expect_shape[1]);
ASSERT_EQ(out.numel(), 36);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
bool value_equal = true;
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto* dense_out_data = dense_out->data<float>();
for (int i = 0; i < dense_x->numel(); i++) {
if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f)
value_equal = false;
}
ASSERT_EQ(value_equal, true);
}
...@@ -5,3 +5,4 @@ cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_uti ...@@ -5,3 +5,4 @@ cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_uti
cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils) cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils) cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils) cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(DEV_API, reshape) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensor dense_x(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 2, 2, 3}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>();
for (int i = 0; i < dense_x.numel(); i++) {
dense_x_data[i] = i;
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<int64_t> shape{12, 3};
// 2. test API
auto out = pten::Reshape<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
shape);
// 3. check result
std::vector<int64_t> expect_shape = {12, 3};
ASSERT_EQ(out.dims()[0], expect_shape[0]);
ASSERT_EQ(out.dims()[1], expect_shape[1]);
ASSERT_EQ(out.numel(), 36);
ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
bool value_equal = true;
auto* dense_out_data = out.data<float>();
for (int i = 0; i < dense_x.numel(); i++) {
if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f)
value_equal = false;
}
ASSERT_EQ(value_equal, true);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册