diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index 417a136c60b618d4418d072f31d12d6d2e175027..90724403d4bc7ee8c509c9a887e64b4946e18fb9 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -14,11 +14,14 @@ #pragma once -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/complex_functors.h" -#include "paddle/fluid/platform/for_range.h" + +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/cpu/conj_kernel.h" +#include "paddle/pten/kernels/cuda/conj_kernel.h" namespace paddle { namespace operators { @@ -30,16 +33,14 @@ class ConjKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { const Tensor* x = context.Input("X"); Tensor* out = context.Output("Out"); + out->mutable_data(context.GetPlace(), size_t(x->numel() * sizeof(T))); - auto numel = x->numel(); - auto* x_data = x->data(); - auto* out_data = out->mutable_data(context.GetPlace(), - size_t(x->numel() * sizeof(T))); + auto& dev_ctx = context.device_context(); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); - auto& dev_ctx = context.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - math::ConjFunctor functor(x_data, numel, out_data); - for_range(functor); + // call new kernel + pten::Conj(dev_ctx, *pt_x.get(), pt_out.get()); } }; diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 4ad99724fd6224f377488c268aa38d70e1b973a7..2066ce955cafe3f3727848ab0b8a73b0826c6851 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -20,6 +20,7 @@ #include #include "paddle/fluid/operators/conj_op.h" +#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/spectral_helper.h" #include "paddle/fluid/operators/spectral_op.h" #include "paddle/fluid/operators/transpose_op.h" diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt index 0c5e0c9e6ec538dbd4fad590fe74f46ff2d4013f..cda991913db959c0dbe118c7dc9fcbf3bac8fd53 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -25,10 +25,10 @@ add_subdirectory(tests) # make an unity target for compile deps set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context) set(PTEN_DEPS ${PTEN_DEPS} scale_kernel_eigen full_kernel_eigen) -set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu manipulation_cpu) +set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu manipulation_cpu conj_kernel_cpu) set(PTEN_DEPS ${PTEN_DEPS} nary unary binary) if(WITH_GPU OR WITH_ROCM) - set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda manipulation_cuda) + set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda manipulation_cuda conj_kernel_cuda) endif() if(WITH_XPU) set(PTEN_DEPS ${PTEN_DEPS} manipulation_xpu) diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index 7b60ff12cf2ec1f729870d18afa6e81564166ec4..ebae064c336897c2aca58617b3d0dcc5d2292eb5 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -105,4 +105,8 @@ using subtract_kernel = void (*)(const DeviceContext&, int, DenseTensor*); +using conj_kernel = void (*)(const DeviceContext&, + const DenseTensor&, + DenseTensor*); + } // namespace pten diff --git a/paddle/pten/api/lib/kernel_declare.h b/paddle/pten/api/lib/kernel_declare.h index 0f4f82b9d7c51f72c59e14b5f8bac287ea850f7f..a4dd3af6f0d3de40352a21f6fd66edd92308c765 100644 --- a/paddle/pten/api/lib/kernel_declare.h +++ b/paddle/pten/api/lib/kernel_declare.h @@ -25,12 +25,14 @@ PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(conj, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT); PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT); PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT); PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT); +PT_DECLARE_KERNEL(conj, CUDA, ALL_LAYOUT); #endif #ifdef PADDLE_WITH_XPU diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index 0dfa17234e3bfb24f1b7dfd56c6d0b44890ac0be..8295c5765411dc62db94b8c772bb64879278435b 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -17,7 +17,9 @@ limitations under the License. */ // See Note: [ How do we organize the kernel directory ] #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/cpu/conj_kernel.h" #include "paddle/pten/kernels/cpu/math.h" +#include "paddle/pten/kernels/cuda/conj_kernel.h" #include "paddle/pten/kernels/cuda/math.h" #include "paddle/pten/kernels/scale_kernel.h" @@ -139,4 +141,16 @@ DenseTensor Multiply(const ContextT& dev_ctx, Multiply(dev_ctx, x, y, axis, &dense_out); return dense_out; } + +template +DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) { + auto out_meta = UnchangedInferMeta(x.meta()); + pten::DenseTensor dense_out( + pten::make_intrusive( + dev_ctx.GetPlace()), + std::move(out_meta)); + Conj(dev_ctx, x, &dense_out); + return dense_out; +} + } // namespace pten diff --git a/paddle/pten/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt index f45d511602d71a175c5f917cd955e5be93a8f431..7a32fab2674c34f6cb7d7218661139977fa2fc1c 100644 --- a/paddle/pten/kernels/cpu/CMakeLists.txt +++ b/paddle/pten/kernels/cpu/CMakeLists.txt @@ -2,3 +2,4 @@ cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory) cc_library(utils_cpu SRCS utils.cc DEPS dense_tensor kernel_context kernel_factory memory convert_utils) cc_library(manipulation_cpu SRCS manipulation.cc DEPS dense_tensor kernel_context kernel_factory utils_cpu unary) +cc_library(conj_kernel_cpu SRCS conj_kernel.cc DEPS dense_tensor kernel_context kernel_factory) diff --git a/paddle/pten/kernels/cpu/conj_kernel.cc b/paddle/pten/kernels/cpu/conj_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..f10d9f761eaed6cdcc12db9bf33846499e3b5c44 --- /dev/null +++ b/paddle/pten/kernels/cpu/conj_kernel.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/cpu/conj_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/hybird/math/conj_impl.h" + +namespace pten { + +template +void Conj(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out) { + ConjImpl(dev_ctx, x, out); +} + +} // namespace pten + +PT_REGISTER_KERNEL(conj, + CPU, + ALL_LAYOUT, + pten::Conj, + paddle::platform::complex, + paddle::platform::complex, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/conj_kernel.h b/paddle/pten/kernels/cpu/conj_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..49dad8f5b2df6a5f8d5f0d0386d7e81f63956515 --- /dev/null +++ b/paddle/pten/kernels/cpu/conj_kernel.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void Conj(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out); + +} // namespace pten diff --git a/paddle/pten/kernels/cuda/CMakeLists.txt b/paddle/pten/kernels/cuda/CMakeLists.txt index b608e18f3da529fc4f672223c80e93c4be8ab88e..48b6dc1623442e240666e82355789d905ad0fa7c 100644 --- a/paddle/pten/kernels/cuda/CMakeLists.txt +++ b/paddle/pten/kernels/cuda/CMakeLists.txt @@ -3,9 +3,11 @@ if(WITH_GPU) nv_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) nv_library(utils_cuda SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils) nv_library(manipulation_cuda SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_cuda unary) + nv_library(conj_kernel_cuda SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) elseif(WITH_ROCM) hip_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_cuda) hip_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) hip_library(utils_cuda SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils) hip_library(manipulation_cuda SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_cuda unary) + hip_library(conj_kernel_cuda SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) endif() diff --git a/paddle/pten/kernels/cuda/conj_kernel.cu b/paddle/pten/kernels/cuda/conj_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f3d2296f562a0c18667fed1e71610e54ce35bf3d --- /dev/null +++ b/paddle/pten/kernels/cuda/conj_kernel.cu @@ -0,0 +1,39 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/cuda/conj_kernel.h" + +#include "paddle/pten/backends/cuda/cuda_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/hybird/math/conj_impl.h" + +namespace pten { + +template +void Conj(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out) { + ConjImpl(dev_ctx, x, out); +} + +} // namespace pten + +PT_REGISTER_KERNEL(conj, + CUDA, + ALL_LAYOUT, + pten::Conj, + paddle::platform::complex, + paddle::platform::complex, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cuda/conj_kernel.h b/paddle/pten/kernels/cuda/conj_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8ed0049d877650ff93d723c4a8425a6834183052 --- /dev/null +++ b/paddle/pten/kernels/cuda/conj_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// CUDA and HIP use same api +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#include "paddle/pten/backends/cuda/cuda_context.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void Conj(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out); + +} // namespace pten + +#endif diff --git a/paddle/pten/kernels/hybird/math/conj_impl.h b/paddle/pten/kernels/hybird/math/conj_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..84ad0b1a6ce95a08728c49318ca011e2f18bf904 --- /dev/null +++ b/paddle/pten/kernels/hybird/math/conj_impl.h @@ -0,0 +1,34 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/fluid/platform/for_range.h" + +namespace pten { + +template +void ConjImpl(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = out->mutable_data(); + + paddle::platform::ForRange for_range(dev_ctx, numel); + paddle::operators::math::ConjFunctor functor(x_data, numel, out_data); + for_range(functor); +} + +} // namespace pten diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index a230b6a4181875dff7b625c3277a23088341f4a2..2c494043e2760ad408fed2123b6e37625ecc8e12 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -22,3 +22,4 @@ cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils scale_kernel_eigen) +cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_conj_api.cc b/paddle/pten/tests/api/test_conj_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..928f8e414fda0f7216dd05e3270347d1a7a2ec98 --- /dev/null +++ b/paddle/pten/tests/api/test_conj_api.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/api/include/api.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace paddle { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +// TODO(chenweihang): Remove this test after the API is used in the dygraph +TEST(API, conj) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::COMPLEX64, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x->mutable_data(); + + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = paddle::complex64(i * 10 + j, i * 10 + j); + } + } + + paddle::experimental::Tensor x(dense_x); + + // 2. test API + auto out = paddle::experimental::conj(x); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.dims()[1], 10); + ASSERT_EQ(out.numel(), 30); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::COMPLEX64); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto actual_result = dense_out->data(); + + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = paddle::complex64(i * 10 + j, i * 10 + j); + ASSERT_NEAR(actual_result[i * 10 + j].real, 1.0 * (i * 10 + j), 1e-6f); + ASSERT_NEAR(actual_result[i * 10 + j].imag, -1.0 * (i * 10 + j), 1e-6f); + } + } +} + +} // namespace tests +} // namespace paddle diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index e2571570129d1bc1731d24c182f1919830d2bcfa..3a626aad2deb5dd8a9f3a09f3f5417b9394a36b0 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -8,3 +8,4 @@ cc_test(test_cast_dev_api SRCS test_cast_dev_api.cc DEPS pten pten_api_utils) cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils) cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils) cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_conj_dev_api.cc b/paddle/pten/tests/kernels/test_conj_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..0438a8f4f462bd4150d60dd8bd439de6290af2e1 --- /dev/null +++ b/paddle/pten/tests/kernels/test_conj_dev_api.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/include/math.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +TEST(DEV_API, conj) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x(alloc, + pten::DenseTensorMeta(pten::DataType::COMPLEX64, + framework::make_ddim({3, 4}), + pten::DataLayout::NCHW)); + + auto* dense_x_data = dense_x.mutable_data(); + for (size_t i = 0; i < 12; ++i) { + dense_x_data[i] = paddle::complex64(i * 1.0, i * 1.0); + } + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto out = pten::Conj( + *(static_cast(dev_ctx)), dense_x); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.numel(), 12); + ASSERT_EQ(out.meta().dtype, pten::DataType::COMPLEX64); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); + + auto actual_result = out.data(); + + for (size_t i = 0; i < 12; ++i) { + ASSERT_NEAR(i * 1.0, actual_result[i].real, 1e-6f); + ASSERT_NEAR(i * -1.0, actual_result[i].imag, 1e-6f); + } +} + +} // namespace tests +} // namespace pten diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 5a4ebb0179ca8164cefd7b8d442742d9e68a3e37..0c410d9b66fe99224bb002870a386c7edda7a700 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -154,3 +154,11 @@ # backend : [place, x] # layout : [] # InferMeta : UnchangedInferMeta(x) + +- api : conj + args : (const Tensor& x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : conj