未验证 提交 879e913b 编写于 作者: C chentianyu03 提交者: GitHub

Make transpose, trace, kron, reshape, sum op support complex type (#29321)

* add complex64 and complex128 type; add +-*/@ and slice opreator for complex types

* add test cases for complex elementwise, matmul and getitem unittest

* add test cases for complex types

* add test cases for complex matmul unittest

* kron, reshape, transpose support complex types

* sum and trace op support complex types

* add test case of sum and trace op

* fix the bug of imag part of complex not initialized

* format file

* format code style

* kron support type promotion; modify test cases
上级 66fd1c00
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/operators/kron_op.h" #include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
...@@ -51,8 +53,22 @@ class KronOp : public framework::OperatorWithKernel { ...@@ -51,8 +53,22 @@ class KronOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote inputs’s types when contains complex input
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
} }
}; };
...@@ -154,7 +170,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -154,7 +170,11 @@ REGISTER_OP_CPU_KERNEL(
ops::KronKernel<paddle::platform::CPUDeviceContext, ops::KronKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::KronKernel<paddle::platform::CPUDeviceContext, int>, ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::KronKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::KronKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OPERATOR(kron_grad, ops::KronGradOp); REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
...@@ -163,4 +183,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -163,4 +183,8 @@ REGISTER_OP_CPU_KERNEL(
ops::KronGradKernel<paddle::platform::CPUDeviceContext, ops::KronGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>, ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/kron_op.h" #include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
...@@ -22,7 +24,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -22,7 +24,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::KronKernel<paddle::platform::CUDADeviceContext, ops::KronKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::KronKernel<paddle::platform::CUDADeviceContext, int>, ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::KronKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::KronKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>, kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -30,4 +36,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -30,4 +36,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::KronGradKernel<paddle::platform::CUDADeviceContext, ops::KronGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>, paddle::platform::float16>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>, ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
...@@ -115,6 +115,12 @@ REGISTER_OP_CPU_KERNEL( ...@@ -115,6 +115,12 @@ REGISTER_OP_CPU_KERNEL(
ops::SumFunctor>, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>, ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t, ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128,
ops::SumFunctor>); ops::SumFunctor>);
template <typename T> template <typename T>
...@@ -125,4 +131,6 @@ using CPUReduceSumGradKernel = ...@@ -125,4 +131,6 @@ using CPUReduceSumGradKernel =
REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>, REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>,
CPUReduceSumGradKernel<double>, CPUReduceSumGradKernel<double>,
CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int>,
CPUReduceSumGradKernel<int64_t>); CPUReduceSumGradKernel<int64_t>,
CPUReduceSumGradKernel<paddle::platform::complex64>,
CPUReduceSumGradKernel<paddle::platform::complex128>);
...@@ -72,4 +72,6 @@ class ReduceSumKernel : public framework::OpKernel<T> { ...@@ -72,4 +72,6 @@ class ReduceSumKernel : public framework::OpKernel<T> {
REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>, REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>,
ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>, ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
ops::ReduceSumKernel<int64_t>); ops::ReduceSumKernel<int64_t>,
ops::ReduceSumKernel<paddle::platform::complex64>,
ops::ReduceSumKernel<paddle::platform::complex128>);
...@@ -618,25 +618,25 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, ...@@ -618,25 +618,25 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInferer, ops::ReshapeDoubleGradInplaceInferer,
ops::ReshapeDoubleGradOpNoNeedBufferVarInferer); ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(
ops::ReshapeKernel, int8_t, ops::ReshapeKernel, reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t,
uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel,
ops::ReshapeKernel, int64_t, ops::ReshapeKernel, int64_t, ops::ReshapeKernel, bool, ops::ReshapeKernel,
bool, ops::ReshapeKernel, paddle::platform::bfloat16, ops::ReshapeKernel, paddle::platform::complex64,
paddle::platform::bfloat16, ops::ReshapeKernel); ops::ReshapeKernel, paddle::platform::complex128, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, REGISTER_OP_CPU_KERNEL_FUNCTOR(
double, ops::ReshapeGradKernel, int, reshape2_grad, float, ops::ReshapeGradKernel, double,
ops::ReshapeGradKernel, uint8_t, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool,
ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, paddle::platform::complex64, ops::ReshapeGradKernel,
ops::ReshapeGradKernel); paddle::platform::complex128, ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float, REGISTER_OP_CPU_KERNEL_FUNCTOR(
ops::ReshapeDoubleGradKernel, double, reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool,
ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, paddle::platform::complex64,
ops::ReshapeDoubleGradKernel, bool, ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
ops::ReshapeDoubleGradKernel); ops::ReshapeDoubleGradKernel);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -656,34 +656,38 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ...@@ -656,34 +656,38 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
uint8_t, ops::ReshapeKernel, int64_t, uint8_t, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel, plat::float16, ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel, bool, ops::ReshapeKernel); ops::ReshapeKernel, bool, ops::ReshapeKernel,
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, plat::complex64, ops::ReshapeKernel,
double, ops::ReshapeGradKernel, int, plat::complex128, ops::ReshapeKernel);
ops::ReshapeGradKernel, uint8_t, REGISTER_OP_CUDA_KERNEL_FUNCTOR(
ops::ReshapeGradKernel, int64_t, reshape2_grad, float, ops::ReshapeGradKernel, double,
ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel); ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex64,
ops::ReshapeGradKernel, plat::complex128, ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, double, REGISTER_OP_CUDA_KERNEL_FUNCTOR(
ops::ReshapeDoubleGradKernel, int, reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel,
ops::ReshapeDoubleGradKernel, plat::float16, plat::float16, ops::ReshapeDoubleGradKernel, bool,
ops::ReshapeDoubleGradKernel, bool, ops::ReshapeDoubleGradKernel, plat::complex64, ops::ReshapeDoubleGradKernel,
ops::ReshapeDoubleGradKernel); plat::complex128, ops::ReshapeDoubleGradKernel);
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel, ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel, plat::float16, int64_t, ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel, bool, ops::ReshapeKernel); ops::ReshapeKernel, bool, ops::ReshapeKernel,
plat::complex64, ops::ReshapeKernel,
plat::complex128, ops::ReshapeKernel);
REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int, double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, bool,
ops::ReshapeGradKernel, plat::complex64,
ops::ReshapeGradKernel, plat::complex128,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
#endif #endif
...@@ -163,9 +163,17 @@ REGISTER_OP_CPU_KERNEL( ...@@ -163,9 +163,17 @@ REGISTER_OP_CPU_KERNEL(
trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>, trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, float>, ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, double>, ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TraceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::TraceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>, trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
...@@ -60,11 +60,19 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -60,11 +60,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
platform::float16>, platform::float16>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>, ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>); ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>, trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
platform::float16>, platform::float16>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>); ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
...@@ -321,11 +321,19 @@ REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); ...@@ -321,11 +321,19 @@ REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>, transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>); ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>); ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker<paddle::framework::OpDesc>, ops::Transpose2GradMaker<paddle::framework::OpDesc>,
...@@ -336,10 +344,18 @@ REGISTER_OP_CPU_KERNEL( ...@@ -336,10 +344,18 @@ REGISTER_OP_CPU_KERNEL(
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>, transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>, ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>); ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>, ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>, ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>); ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
...@@ -730,14 +730,21 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -730,14 +730,21 @@ REGISTER_OP_CUDA_KERNEL(
transpose, transpose,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
plat::float16>); paddle::platform::complex64>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
plat::float16>); plat::float16>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2, transpose2,
...@@ -745,8 +752,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -745,8 +752,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
plat::float16>); paddle::platform::complex64>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
...@@ -754,4 +764,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -754,4 +764,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
plat::float16>); plat::float16>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
...@@ -124,6 +124,7 @@ struct PADDLE_ALIGN(8) complex64 { ...@@ -124,6 +124,7 @@ struct PADDLE_ALIGN(8) complex64 {
HOSTDEVICE inline complex64& operator=(int32_t val) { HOSTDEVICE inline complex64& operator=(int32_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
......
...@@ -27,42 +27,68 @@ class ComplexKronTestCase(unittest.TestCase): ...@@ -27,42 +27,68 @@ class ComplexKronTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.ref_result = np.kron(self.x, self.y) self.ref_result = np.kron(self.x, self.y)
self._places = [paddle.CPUPlace()]
if fluid.is_compiled_with_cuda():
self._places.append(paddle.CUDAPlace(0))
def runTest(self): def runTest(self):
place = fluid.CPUPlace() for place in self._places:
self.test_identity(place) self.test_complex_api(place)
self.test_basic_api(place)
if fluid.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
self.test_identity(place)
def test_identity(self, place): def test_complex_api(self, place):
with dg.guard(place): with dg.guard(place):
x_var = dg.to_variable(self.x) x_var = dg.to_variable(self.x)
y_var = dg.to_variable(self.y) y_var = dg.to_variable(self.y)
out_var = paddle.complex.kron(x_var, y_var) out_var = paddle.complex.kron(x_var, y_var)
np.testing.assert_allclose(out_var.numpy(), self.ref_result) self.assertTrue(np.allclose(out_var.numpy(), self.ref_result))
def test_basic_api(self, place):
with dg.guard(place):
x_var = paddle.Tensor(
value=self.x,
place=place,
persistable=False,
zero_copy=None,
stop_gradient=True)
y_var = paddle.Tensor(
value=self.y,
place=place,
persistable=False,
zero_copy=None,
stop_gradient=True)
out_var = tensor.math.kron(x_var, y_var)
self.assertTrue(np.allclose(out_var.numpy(), self.ref_result))
def load_tests(loader, standard_tests, pattern): def load_tests(loader, standard_tests, pattern):
suite = unittest.TestSuite() suite = unittest.TestSuite()
for dtype in ["float32", "float64"]:
suite.addTest( suite.addTest(
ComplexKronTestCase( ComplexKronTestCase(
x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2), x=np.random.randn(2, 2).astype(dtype) + 1j * np.random.randn(
y=np.random.randn(3, 3) + 1j * np.random.randn(3, 3))) 2, 2).astype(dtype),
y=np.random.randn(3, 3).astype(dtype) + 1j * np.random.randn(
3, 3).astype(dtype)))
suite.addTest( suite.addTest(
ComplexKronTestCase( ComplexKronTestCase(
x=np.random.randn(2, 2), x=np.random.randn(2, 2).astype(dtype),
y=np.random.randn(3, 3) + 1j * np.random.randn(3, 3))) y=np.random.randn(3, 3).astype(dtype) + 1j * np.random.randn(
3, 3).astype(dtype)))
suite.addTest( suite.addTest(
ComplexKronTestCase( ComplexKronTestCase(
x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2), x=np.random.randn(2, 2).astype(dtype) + 1j * np.random.randn(
y=np.random.randn(3, 3))) 2, 2).astype(dtype),
y=np.random.randn(3, 3).astype(dtype)))
suite.addTest( suite.addTest(
ComplexKronTestCase( ComplexKronTestCase(
x=np.random.randn(2, 2) + 1j * np.random.randn(2, 2), x=np.random.randn(2, 2).astype(dtype) + 1j * np.random.randn(
y=np.random.randn(2, 2, 3))) 2, 2).astype(dtype),
y=np.random.randn(2, 2, 3).astype(dtype)))
return suite return suite
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
from paddle import complex as cpx from paddle import complex as cpx
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import numpy as np import numpy as np
...@@ -20,30 +21,71 @@ import unittest ...@@ -20,30 +21,71 @@ import unittest
class TestComplexReshape(unittest.TestCase): class TestComplexReshape(unittest.TestCase):
def setUp(self):
self._dtypes = ["float32", "float64"]
self._places = [paddle.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
self._places.append(paddle.CUDAPlace(0))
def test_case1(self): def test_case1(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4) for dtype in self._dtypes:
x_np = np.random.randn(
2, 3, 4).astype(dtype) + 1j * np.random.randn(2, 3,
4).astype(dtype)
shape = (2, -1) shape = (2, -1)
for place in self._places:
place = fluid.CPUPlace()
with dg.guard(place): with dg.guard(place):
x_var = dg.to_variable(x_np) x_var = dg.to_variable(x_np)
y_var = cpx.reshape(x_var, shape) y_var = cpx.reshape(x_var, shape)
y_np = y_var.numpy() y_np = y_var.numpy()
np.testing.assert_allclose(np.reshape(x_np, shape), y_np) np.testing.assert_allclose(np.reshape(x_np, shape), y_np)
def test_case2(self): def test_case2(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4) for dtype in self._dtypes:
x_np = np.random.randn(
2, 3, 4).astype(dtype) + 1j * np.random.randn(2, 3,
4).astype(dtype)
shape = (0, -1) shape = (0, -1)
shape_ = (2, 12) shape_ = (2, 12)
for place in self._places:
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
with dg.guard(place): with dg.guard(place):
x_var = dg.to_variable(x_np) x_var = dg.to_variable(x_np)
y_var = cpx.reshape(x_var, shape, inplace=True) y_var = cpx.reshape(x_var, shape, inplace=True)
y_np = y_var.numpy() y_np = y_var.numpy()
np.testing.assert_allclose(np.reshape(x_np, shape_), y_np)
def test_case3(self):
for dtype in self._dtypes:
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
shape = (2, -1)
for place in self._places:
with dg.guard(place):
x_var = paddle.Tensor(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
stop_gradient=True)
y_var = fluid.layers.reshape(x_var, shape)
y_np = y_var.numpy()
np.testing.assert_allclose(np.reshape(x_np, shape), y_np)
def test_case4(self):
for dtype in self._dtypes:
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
shape = (0, -1)
shape_ = (2, 12)
for place in self._places:
with dg.guard(place):
x_var = paddle.Tensor(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
stop_gradient=True)
y_var = fluid.layers.reshape(x_var, shape)
y_np = y_var.numpy()
np.testing.assert_allclose(np.reshape(x_np, shape_), y_np) np.testing.assert_allclose(np.reshape(x_np, shape_), y_np)
......
...@@ -14,22 +14,25 @@ ...@@ -14,22 +14,25 @@
import unittest import unittest
import numpy as np import numpy as np
import paddle
from numpy.random import random as rand from numpy.random import random as rand
from paddle import complex as cpx from paddle import complex as cpx
from paddle import tensor
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
class TestComplexSumLayer(unittest.TestCase): class TestComplexSumLayer(unittest.TestCase):
def setUp(self): def setUp(self):
self._dtype = "float64" self._dtypes = ["float32", "float64"]
self._places = [fluid.CPUPlace()] self._places = [paddle.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0)) self._places.append(paddle.CUDAPlace(0))
def test_complex_x(self): def test_complex_x(self):
input = rand([2, 10, 10]).astype(self._dtype) + 1j * rand( for dtype in self._dtypes:
[2, 10, 10]).astype(self._dtype) input = rand([2, 10, 10]).astype(dtype) + 1j * rand(
[2, 10, 10]).astype(dtype)
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
var_x = dg.to_variable(input) var_x = dg.to_variable(input)
...@@ -37,6 +40,22 @@ class TestComplexSumLayer(unittest.TestCase): ...@@ -37,6 +40,22 @@ class TestComplexSumLayer(unittest.TestCase):
target = np.sum(input, axis=(1, 2)) target = np.sum(input, axis=(1, 2))
self.assertTrue(np.allclose(result, target)) self.assertTrue(np.allclose(result, target))
def test_complex_basic_api(self):
for dtype in self._dtypes:
input = rand([2, 10, 10]).astype(dtype) + 1j * rand(
[2, 10, 10]).astype(dtype)
for place in self._places:
with dg.guard(place):
var_x = paddle.Tensor(
value=input,
place=place,
persistable=False,
zero_copy=None,
stop_gradient=True)
result = tensor.sum(var_x, axis=[1, 2]).numpy()
target = np.sum(input, axis=(1, 2))
self.assertTrue(np.allclose(result, target))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -14,26 +14,47 @@ ...@@ -14,26 +14,47 @@
import unittest import unittest
import numpy as np import numpy as np
import paddle
from numpy.random import random as rand from numpy.random import random as rand
from paddle import complex as cpx from paddle import complex as cpx
from paddle import tensor
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
class TestComplexTraceLayer(unittest.TestCase): class TestComplexTraceLayer(unittest.TestCase):
def setUp(self): def setUp(self):
self._dtype = "float64" self._dtypes = ["float32", "float64"]
self._places = [fluid.CPUPlace()] self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0)) self._places.append(fluid.CUDAPlace(0))
def test_complex_x(self): def test_complex_api(self):
input = rand([2, 20, 2, 3]).astype(self._dtype) + 1j * rand( for dtype in self._dtypes:
[2, 20, 2, 3]).astype(self._dtype) input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand(
[2, 20, 2, 3]).astype(dtype)
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
var_x = dg.to_variable(input) var_x = dg.to_variable(input)
result = cpx.trace(var_x, offset=1, axis1=0, axis2=2).numpy() result = cpx.trace(
var_x, offset=1, axis1=0, axis2=2).numpy()
target = np.trace(input, offset=1, axis1=0, axis2=2)
self.assertTrue(np.allclose(result, target))
def test_basic_api(self):
for dtype in self._dtypes:
input = rand([2, 20, 2, 3]).astype(dtype) + 1j * rand(
[2, 20, 2, 3]).astype(dtype)
for place in self._places:
with dg.guard(place):
var_x = paddle.Tensor(
value=input,
place=place,
persistable=False,
zero_copy=None,
stop_gradient=True)
result = tensor.trace(
var_x, offset=1, axis1=0, axis2=2).numpy()
target = np.trace(input, offset=1, axis1=0, axis2=2) target = np.trace(input, offset=1, axis1=0, axis2=2)
self.assertTrue(np.allclose(result, target)) self.assertTrue(np.allclose(result, target))
......
...@@ -21,14 +21,16 @@ import paddle.fluid.dygraph as dg ...@@ -21,14 +21,16 @@ import paddle.fluid.dygraph as dg
class TestComplexTransposeLayer(unittest.TestCase): class TestComplexTransposeLayer(unittest.TestCase):
def setUp(self): def setUp(self):
self._places = [fluid.CPUPlace()] self._dtypes = ["float32", "float64"]
self._places = [paddle.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0)) self._places.append(paddle.CUDAPlace(0))
def test_identity(self): def test_transpose_by_complex_api(self):
for dtype in self._dtypes:
data = np.random.random( data = np.random.random(
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random( (2, 3, 4, 5)).astype(dtype) + 1J * np.random.random(
(2, 3, 4, 5)).astype("float32") (2, 3, 4, 5)).astype(dtype)
perm = [3, 2, 0, 1] perm = [3, 2, 0, 1]
np_trans = np.transpose(data, perm) np_trans = np.transpose(data, perm)
for place in self._places: for place in self._places:
...@@ -37,6 +39,24 @@ class TestComplexTransposeLayer(unittest.TestCase): ...@@ -37,6 +39,24 @@ class TestComplexTransposeLayer(unittest.TestCase):
trans = paddle.complex.transpose(var, perm=perm) trans = paddle.complex.transpose(var, perm=perm)
self.assertTrue(np.allclose(trans.numpy(), np_trans)) self.assertTrue(np.allclose(trans.numpy(), np_trans))
def test_transpose_by_basic_api(self):
for dtype in self._dtypes:
data = np.random.random(
(2, 3, 4, 5)).astype(dtype) + 1J * np.random.random(
(2, 3, 4, 5)).astype(dtype)
perm = [3, 2, 0, 1]
np_trans = np.transpose(data, perm)
for place in self._places:
with dg.guard(place):
var = paddle.Tensor(
value=data,
place=place,
persistable=False,
zero_copy=None,
stop_gradient=True)
trans = paddle.transpose(var, perm=perm)
self.assertTrue(np.allclose(trans.numpy(), np_trans))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册