diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index fc0f1416cc13896d24406fe471504d7badad7a61..b257f345eaf36c61eab61e9efa42bf3df6b5faa4 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -28,7 +28,7 @@ limitations under the License. */ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/linalg.h" +#include "paddle/pten/kernels/matmul_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" @@ -384,8 +384,8 @@ class MatMulV2Kernel : public framework::OpKernel { auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out); // call new kernel - pten::Matmul(dev_ctx, *pt_x.get(), *pt_y.get(), trans_x, trans_y, - pt_out.get()); + pten::MatmulKernel(dev_ctx, *pt_x, *pt_y, trans_x, trans_y, + pt_out.get()); } }; diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt index 5a06ef11add72fdfdf9c6f9396d04f816d154ba6..7adfca40319b1ba5c415a5122132d2e7aa504b38 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -28,10 +28,10 @@ get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS) # keep this message for debug, remove it later if needless message(STATUS "All standard pten kernels: ${pten_kernels}") set(PTEN_DEPS ${PTEN_DEPS} ${pten_kernels}) -set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu) +set(PTEN_DEPS ${PTEN_DEPS} math_cpu) set(PTEN_DEPS ${PTEN_DEPS} nary unary binary) if(WITH_GPU OR WITH_ROCM) - set(PTEN_DEPS ${PTEN_DEPS} math_gpu linalg_gpu) + set(PTEN_DEPS ${PTEN_DEPS} math_gpu) endif() cc_library(pten SRCS all.cc DEPS ${PTEN_DEPS}) diff --git a/paddle/pten/api/lib/kernel_declare.h b/paddle/pten/api/lib/kernel_declare.h index 3b7d5ef157bbe30fb224986710688ea220a0a0b4..484063df478aaf4d8245e157fce640967bb6d5d6 100644 --- a/paddle/pten/api/lib/kernel_declare.h +++ b/paddle/pten/api/lib/kernel_declare.h @@ -20,10 +20,8 @@ limitations under the License. */ // the kernel declare statement is automatically generated according to the // file name of the kernel, and this header file will be removed -PT_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); PT_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT); #endif diff --git a/paddle/pten/include/linalg.h b/paddle/pten/include/linalg.h index 34b0183778125596d79cbd6e2249944be5b025e5..22f287468e673d63eb77399b4fa864ea46fa4989 100644 --- a/paddle/pten/include/linalg.h +++ b/paddle/pten/include/linalg.h @@ -17,9 +17,7 @@ // See Note: [ How do we organize the kernel directory ] #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/include/infermeta.h" -#include "paddle/pten/kernels/cpu/linalg.h" #include "paddle/pten/kernels/dot_kernel.h" -#include "paddle/pten/kernels/gpu/linalg.h" namespace pten { diff --git a/paddle/pten/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt index 24ce40c2451dff70f5fa6d23003473cb251711d7..9bf3df598e4c03a38452fd8d0666bf10242bb7de 100644 --- a/paddle/pten/kernels/cpu/CMakeLists.txt +++ b/paddle/pten/kernels/cpu/CMakeLists.txt @@ -1,2 +1 @@ cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas pten_transpose_cpu cast_kernel) -cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory) diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc deleted file mode 100644 index 0b58b36c596465248a4c0dad59f0fff31b63eb8e..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/cpu/linalg.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/kernels/cpu/linalg.h" - -#include "paddle/pten/core/kernel_registry.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/platform/complex.h" - -#include "paddle/pten/kernels/hybird/math/matmul_func.h" - -namespace pten { - -template -void Matmul(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), - 0, - paddle::platform::errors::InvalidArgument( - "The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), - 0, - paddle::platform::errors::InvalidArgument( - "The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); - math::MatMulFunction( - dev_ctx, x, y, out, transpose_x, transpose_y); -} - -} // namespace pten - -using complex64 = ::paddle::platform::complex; -using complex128 = ::paddle::platform::complex; - -PT_REGISTER_KERNEL(matmul, - CPU, - ALL_LAYOUT, - pten::Matmul, - float, - double, - complex64, - complex128) {} diff --git a/paddle/pten/kernels/cpu/linalg.h b/paddle/pten/kernels/cpu/linalg.h deleted file mode 100644 index d9fc391996e19893d0ca71ee502f1a65695592b1..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/cpu/linalg.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pten/backends/cpu/cpu_context.h" -#include "paddle/pten/core/dense_tensor.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/device_context.h" - -namespace pten { - -template -void Matmul(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out); - -} // namespace pten diff --git a/paddle/pten/kernels/cpu/matmul_kernel.cc b/paddle/pten/kernels/cpu/matmul_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..edba402ec1d842a97132de1c611a5785137ac52b --- /dev/null +++ b/paddle/pten/kernels/cpu/matmul_kernel.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul, + CPU, + ALL_LAYOUT, + pten::MatmulKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/CMakeLists.txt b/paddle/pten/kernels/gpu/CMakeLists.txt index b35dd3ce303d70fc7971dbf85d39bd8625898b23..51c666947b2f2cbc36a56584b97b4ad471ffb7a9 100644 --- a/paddle/pten/kernels/gpu/CMakeLists.txt +++ b/paddle/pten/kernels/gpu/CMakeLists.txt @@ -1,7 +1,5 @@ if(WITH_GPU) nv_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel copy_kernel) - nv_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) elseif(WITH_ROCM) hip_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel copy_kernel) - hip_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) endif() diff --git a/paddle/pten/kernels/gpu/linalg.cu b/paddle/pten/kernels/gpu/linalg.cu deleted file mode 100644 index e4a69b28e6158b767e915bb809d6f6cc3712c684..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/gpu/linalg.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/pten/kernels/gpu/linalg.h" - -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/hybird/math/matmul_func.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/complex.h" - -namespace pten { - -template -void Matmul(const GPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), - 0, - paddle::platform::errors::InvalidArgument( - "The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), - 0, - paddle::platform::errors::InvalidArgument( - "The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); - math::MatMulFunction( - dev_ctx, x, y, out, transpose_x, transpose_y); -} - -} // namespace pten - -using float16 = paddle::platform::float16; -using complex64 = ::paddle::platform::complex; -using complex128 = ::paddle::platform::complex; - -PT_REGISTER_KERNEL(matmul, - GPU, - ALL_LAYOUT, - pten::Matmul, - float, - double, - float16, - complex64, - complex128) {} diff --git a/paddle/pten/kernels/gpu/linalg.h b/paddle/pten/kernels/gpu/linalg.h deleted file mode 100644 index a0f7c0c0aae229acae882c06aa3505e8762e4217..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/gpu/linalg.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -// CUDA and HIP use same api -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - -#include "paddle/pten/backends/gpu/gpu_context.h" -#include "paddle/pten/core/dense_tensor.h" - -namespace pten { - -template -void Matmul(const GPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out); - -} // namespace pten - -#endif diff --git a/paddle/pten/kernels/gpu/matmul_kernel.cu b/paddle/pten/kernels/gpu/matmul_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..debda455818a952769d55ef0ad0d2143e9670242 --- /dev/null +++ b/paddle/pten/kernels/gpu/matmul_kernel.cu @@ -0,0 +1,31 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_kernel.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul, + GPU, + ALL_LAYOUT, + pten::MatmulKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/hybird/math/matmul_func.h b/paddle/pten/kernels/impl/matmul_kernel_impl.h similarity index 93% rename from paddle/pten/kernels/hybird/math/matmul_func.h rename to paddle/pten/kernels/impl/matmul_kernel_impl.h index 8aa8750aba4180b401fdbf639e956572dc25de17..e50b2f0641a46a8940ae6eb69fec24e125e0abb9 100644 --- a/paddle/pten/kernels/hybird/math/matmul_func.h +++ b/paddle/pten/kernels/impl/matmul_kernel_impl.h @@ -20,7 +20,6 @@ limitations under the License. */ #include "paddle/pten/core/dense_tensor.h" namespace pten { -namespace math { static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims, @@ -86,8 +85,8 @@ static void IndexIncreaseFromDims(const int ndim, } } -template -void MatMulFunction(const DeviceContext& dev_ctx, +template +void MatMulFunction(const Context& context, const DenseTensor& X, const DenseTensor& Y, const std::vector& x_dims, @@ -103,7 +102,7 @@ void MatMulFunction(const DeviceContext& dev_ctx, const T* x_data = X.data(); const T* y_data = Y.data(); - auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto blas = paddle::operators::math::GetBlas(context); if (x_ndim == 1 && y_ndim == 1) { const int M = X.numel(); @@ -471,8 +470,8 @@ void MatMulFunction(const DeviceContext& dev_ctx, } } -template -void MatMulFunction(const DeviceContext& dev_ctx, +template +void MatMulFunction(const Context& context, const DenseTensor& X, const DenseTensor& Y, DenseTensor* Out, @@ -481,9 +480,28 @@ void MatMulFunction(const DeviceContext& dev_ctx, bool flag = false) { const std::vector x_dims = vectorize(X.dims()); const std::vector y_dims = vectorize(Y.dims()); - MatMulFunction( - dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); + MatMulFunction( + context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); +} + +template +void MatmulKernel(const Context& context, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) { + PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but reviced dims size is 0. ")); + MatMulFunction(context, x, y, out, transpose_x, transpose_y); } -} // namespace math } // namespace pten diff --git a/paddle/pten/kernels/matmul_kernel.h b/paddle/pten/kernels/matmul_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..5a1766330b7781a169d8082eaf41ab317e9ff888 --- /dev/null +++ b/paddle/pten/kernels/matmul_kernel.h @@ -0,0 +1,46 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/include/infermeta.h" + +namespace pten { + +template +void MatmulKernel(const Context& context, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out); + +template +DenseTensor Matmul(const Context& context, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y) { + auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y); + DenseTensor dense_out( + pten::make_intrusive( + context.GetPlace()), + std::move(out_meta)); + MatmulKernel(context, x, y, transpose_x, transpose_y, &dense_out); + return dense_out; +} + +} // namespace pten diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index 5f14554a0cceeb3a7abce4f0f128b9c9c5254d64..6f70f2ca2c895a042c1aad0a43b4ff70966f256a 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -2,6 +2,7 @@ cc_test(test_copy_dev_api SRCS test_copy_dev_api.cc DEPS pten pten_api_utils) cc_test(test_dot_dev_api SRCS test_dot_dev_api.cc DEPS pten pten_api_utils) cc_test(test_creation_dev_api SRCS test_creation_dev_api.cc DEPS pten pten_api_utils) cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_matmul_dev_api SRCS test_matmul_dev_api.cc DEPS pten pten_api_utils) cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils) cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils) cc_test(test_cast_dev_api SRCS test_cast_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_matmul_dev_api.cc b/paddle/pten/tests/kernels/test_matmul_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ac3d19554581f31d95bb13155beb30d1786c8c9 --- /dev/null +++ b/paddle/pten/tests/kernels/test_matmul_dev_api.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/kernels/matmul_kernel.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +TEST(DEV_API, dot) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + DenseTensor dense_x(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + + auto* dense_x_data = dense_x.mutable_data(); + + DenseTensor dense_y(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(); + + for (size_t i = 0; i < 9; ++i) { + dense_x_data[i] = 1.0; + dense_y_data[i] = 2.0; + } + std::vector sum(9, 6.0); + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto out = Matmul( + *(static_cast(ctx)), dense_x, dense_y, false, false); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.dims()[1], 3); + ASSERT_EQ(out.numel(), 9); + ASSERT_EQ(out.dtype(), DataType::FLOAT32); + ASSERT_EQ(out.layout(), DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + for (size_t i = 0; i < 9; i++) { + ASSERT_NEAR(sum[i], out.data()[i], 1e-6f); + } +} + +} // namespace tests +} // namespace pten