diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index d824014713d93105718848a2406a7de098f4cfe2..b6f6151e13360441f1517bc9fe75c0dbc6a22249 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
 #include <vector>
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
@@ -203,7 +204,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
 
 #ifdef PADDLE_WITH_MKLDNN
     if (this->CanMKLDNNBeUsed(ctx)) {
@@ -214,6 +215,19 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote inputs' types when the op contains a complex input
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index 7f692d61649f873aa6e5e4337cb4db8b30fbb451..be10376f6111579377586a04a2cd8212cdcbd2e3 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -289,6 +289,19 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote inputs' types when the op contains a complex input
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
@@ -326,6 +339,19 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote inputs' types when the op contains a complex input
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 class ElementwiseOpDoubleGradWithoutDXDY
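The `GetKernelTypeForVar` overrides above keep each input in its own dtype whenever the expected kernel type is complex, so a real-valued operand is promoted inside the kernel rather than force-cast up front. A minimal numpy sketch (not the Paddle API) of the promotion semantics the new tests below exercise:

```python
import numpy as np

x = np.random.random((2, 3)).astype(np.float64)               # real input
y = np.random.random((2, 3)) + 1j * np.random.random((2, 3))  # complex128 input

out = x / y  # x is promoted to complex128 on the fly, as the kernels now allow
assert out.dtype == np.complex128
```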
diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc
index db25d05c6b24346db3ac5c14a0fd0eacaf913c28..dab9948edc3592e8c1635c5bb62b7dfbd09dd1e1 100644
--- a/paddle/fluid/operators/kron_op.cc
+++ b/paddle/fluid/operators/kron_op.cc
@@ -134,6 +134,19 @@ class KronGradOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
         ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote inputs' types when the op contains a complex input
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc
index 7a3db793184d4839e779ef8c9ffcc72f09062cb1..6fccd3657af77eced2d11e97b96c865f6ab92e43 100644
--- a/paddle/fluid/operators/matmul_v2_op.cc
+++ b/paddle/fluid/operators/matmul_v2_op.cc
@@ -150,6 +150,27 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
       context->SetOutputDim(y_grad_name, y_dims);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto out_grad_name = framework::GradVarName("Out");
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
+        ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote inputs' types when the op contains a complex input
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 template <typename T>
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
index f93802c47c99a6beae55b5aac7f4c93332500586..32860a6694a893d494edacc4115e156e59ff4c15 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -320,6 +320,21 @@ class TestComplexElementwiseDivOp(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp):
+    def init_input_output(self):
+        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.y = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.out = self.x / self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
+            (2, 3, 4, 5), self.dtype)
+        self.grad_x = np.real(self.grad_out / np.conj(self.y))
+        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
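Assuming the `OpTest` convention that the checked scalar loss is `L = Re(sum(conj(grad_out) * Out))` (my inference from how `user_defined_grads` and `user_defined_grad_outputs` are consumed; the diff does not state it), the real-input gradient `grad_x = np.real(grad_out / np.conj(y))` used by `TestRealComplexElementwiseDivOp` can be spot-checked against a central finite difference in plain numpy:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(4)                        # real input
y = rng.random(4) + 1j * rng.random(4)   # complex input
g = np.ones(4) + 1j * np.ones(4)         # grad_out

def loss(xv):
    # scalar loss whose gradient w.r.t. Out is exactly g (assumed convention)
    return np.real(np.sum(np.conj(g) * (xv / y)))

eps = 1e-6
numeric = np.array([(loss(x + eps * e) - loss(x - eps * e)) / (2 * eps)
                    for e in np.eye(4)])
analytic = np.real(g / np.conj(y))       # grad_x formula from the test
assert np.allclose(numeric, analytic, atol=1e-6)
```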
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
index f69fa7084edb193515a626e23b91cda6b1122829..7bace9bc535243194e2ed9ca82db49e6d1b4f2f4 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -304,6 +304,21 @@ class TestComplexElementwiseMulOp(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.y = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.out = self.x * self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
+            (2, 3, 4, 5), self.dtype)
+        self.grad_x = np.real(self.grad_out * np.conj(self.y))
+        self.grad_y = self.grad_out * np.conj(self.x)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
index 6434807c551105fbdbfa630b61eeac04b24c167c..c5372d5b758a8b2878c0a3070f32fa5db8efa117 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
@@ -15,6 +15,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+import paddle
 from op_test import OpTest, skip_check_grad_ci
 
 
@@ -164,5 +165,78 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
         }
 
 
+class TestComplexElementwiseSubOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_sub"
+        self.init_base_dtype()
+        self.shape = (2, 3, 4, 5)
+        self.init_input_output()
+        self.init_grad_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.float64
+
+    def init_input_output(self):
+        self.x = np.random.random(self.shape).astype(
+            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
+        self.y = np.random.random(self.shape).astype(
+            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
+        self.out = self.x - self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
+            self.shape, self.dtype)
+        self.grad_x = self.grad_out
+        self.grad_y = -self.grad_out
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[self.grad_x, self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            user_defined_grads=[self.grad_x],
+            user_defined_grad_outputs=[self.grad_out])
+
+
+class TestRealComplexElementwiseSubOp(TestComplexElementwiseSubOp):
+    def init_input_output(self):
+        self.x = np.random.random(self.shape).astype(self.dtype)
+        self.y = np.random.random(self.shape).astype(
+            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
+        self.out = self.x - self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
+            self.shape, self.dtype)
+        self.grad_x = np.real(self.grad_out)
+        self.grad_y = -self.grad_out
+
+
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
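The subtraction case shows the projection rule at its simplest: under the same assumed loss convention, a real `x` receives only the real part of the upstream gradient, while the complex `y` keeps the full complex gradient. A tiny numpy illustration:

```python
import numpy as np

g = np.ones((2, 2)) + 1j * np.ones((2, 2))  # grad_out of Out = x - y

grad_x = np.real(g)   # x is real: project back to x's real dtype
grad_y = -g           # y is complex: full complex gradient
assert grad_x.dtype == np.float64 and grad_y.dtype == np.complex128
```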
diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py
index 634739596e985e549137cf6450731dad7734655c..d6db4c2f074a9018266bdf3a9d38b14d878dc9cf 100644
--- a/python/paddle/fluid/tests/unittests/test_kron_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kron_op.py
@@ -186,6 +186,20 @@ class TestComplexKronOp(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestKronOpTypePromotion(TestComplexKronOp):
+    def init_input_output(self):
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(
+            self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
+        self.out = np.kron(self.x, self.y)
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
+            self.out_shape, self.dtype)
+        self.grad_x = self.get_grad_x_by_numpy().real
+        self.grad_y = self.get_grad_y_by_numpy()
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
index f944f84c6c113530e4b5c602e4b4c0ed9bb6d9b9..761d318d7b8a3d43897f31bf635884b582fcec1d 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
@@ -525,6 +525,21 @@ class TestComplexMatMulOpBroadcast(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestMatMulTypePromotion(TestComplexMatMulOp):
+    def init_input_output(self):
+        self.x = np.random.random((10, 10)).astype(self.dtype)
+        self.y = np.random.random(
+            (10, 10)).astype(self.dtype) + 1J * np.random.random(
+                (10, 10)).astype(self.dtype)
+        self.out = np.dot(self.x, self.y)
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
+            (10, 10), self.dtype)
+        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
+        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
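For matmul, `grad_x = Re(grad_out @ conj(y).T)` and `grad_y = conj(x).T @ grad_out` are the standard Wirtinger-style matrix gradients, with the real part projecting back onto the real operand. A hedged numpy spot-check of one entry (same assumed loss convention as above, not the Paddle API):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((3, 3))                            # real operand
y = rng.random((3, 3)) + 1j * rng.random((3, 3))  # complex operand
g = np.ones((3, 3)) + 1j * np.ones((3, 3))        # grad_out

def loss(xv):
    # scalar loss whose gradient w.r.t. Out is exactly g (assumed convention)
    return np.real(np.sum(np.conj(g) * np.matmul(xv, y)))

grad_x = np.matmul(g, np.conj(y).T).real  # analytic grad used by the test

# central finite difference on a single entry of the real operand
eps, (i, j) = 1e-6, (1, 2)
e = np.zeros_like(x)
e[i, j] = eps
numeric = (loss(x + e) - loss(x - e)) / (2 * eps)
assert np.isclose(numeric, grad_x[i, j], atol=1e-6)
```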