From c7371b7b204bfaa6f53f305836d2fe80b5002b87 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Mon, 11 Jan 2021 16:06:59 +0800
Subject: [PATCH] type promotion for grad (#30177)

* type promotion for grad

* add type promotion for div op
---
 .../elementwise/elementwise_div_op.h          | 16 +++-
 .../operators/elementwise/elementwise_op.h    | 26 +++++++
 paddle/fluid/operators/kron_op.cc             | 13 ++++
 paddle/fluid/operators/matmul_v2_op.cc        | 21 ++++++
 .../unittests/test_elementwise_div_op.py      | 15 ++++
 .../unittests/test_elementwise_mul_op.py      | 15 ++++
 .../unittests/test_elementwise_sub_op.py      | 74 +++++++++++++++++++
 .../fluid/tests/unittests/test_kron_op.py     | 14 ++++
 .../tests/unittests/test_matmul_v2_op.py      | 15 ++++
 9 files changed, 208 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index d824014713d..b6f6151e133 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 
@@ -203,7 +204,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
 
 #ifdef PADDLE_WITH_MKLDNN
     if (this->CanMKLDNNBeUsed(ctx)) {
@@ -214,6 +215,19 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote input types when the expected kernel type is complex
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 template <typename DeviceContext, typename T>
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index 7f692d61649..be10376f611 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -289,6 +289,19 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote input types when the expected kernel type is complex
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
@@ -326,6 +339,19 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote input types when the expected kernel type is complex
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 class ElementwiseOpDoubleGradWithoutDXDY
diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc
index db25d05c6b2..dab9948edc3 100644
--- a/paddle/fluid/operators/kron_op.cc
+++ b/paddle/fluid/operators/kron_op.cc
@@ -134,6 +134,19 @@ class KronGradOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
         ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote input types when the expected kernel type is complex
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc
index 7a3db793184..6fccd3657af 100644
--- a/paddle/fluid/operators/matmul_v2_op.cc
+++ b/paddle/fluid/operators/matmul_v2_op.cc
@@ -150,6 +150,27 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
       context->SetOutputDim(y_grad_name, y_dims);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto out_grad_name = framework::GradVarName("Out");
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
+        ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote input types when the expected kernel type is complex
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
 };
 
 template <typename T>
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
index f93802c47c9..32860a6694a 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -320,6 +320,21 @@ class TestComplexElementwiseDivOp(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp):
+    def init_input_output(self):
+        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.y = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.out = self.x / self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
+            (2, 3, 4, 5), self.dtype)
+        self.grad_x = np.real(self.grad_out / np.conj(self.y))
+        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
index f69fa7084ed..7bace9bc535 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -304,6 +304,21 @@ class TestComplexElementwiseMulOp(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.y = np.random.random(
+            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
+                (2, 3, 4, 5)).astype(self.dtype)
+        self.out = self.x * self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
+            (2, 3, 4, 5), self.dtype)
+        self.grad_x = np.real(self.grad_out * np.conj(self.y))
+        self.grad_y = self.grad_out * np.conj(self.x)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
index 6434807c551..c5372d5b758 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
@@ -15,6 +15,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+import paddle
 from op_test import OpTest, skip_check_grad_ci
 
 
@@ -164,5 +165,78 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
         }
 
 
+class TestComplexElementwiseSubOp(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_sub"
+        self.dtype = np.float64
+        self.shape = (2, 3, 4, 5)
+        self.init_input_output()
+        self.init_grad_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
+        }
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_base_dtype(self):
+        self.dtype = np.float64
+
+    def init_input_output(self):
+        self.x = np.random.random(self.shape).astype(
+            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
+        self.y = np.random.random(self.shape).astype(
+            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
+        self.out = self.x - self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
+            self.shape, self.dtype)
+        self.grad_x = self.grad_out
+        self.grad_y = -self.grad_out
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[self.grad_x, self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[self.grad_y],
+            user_defined_grad_outputs=[self.grad_out])
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            user_defined_grads=[self.grad_x],
+            user_defined_grad_outputs=[self.grad_out])
+
+
+class TestRealComplexElementwiseSubOp(TestComplexElementwiseSubOp):
+    def init_input_output(self):
+        self.x = np.random.random(self.shape).astype(self.dtype)
+        self.y = np.random.random(self.shape).astype(
+            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
+        self.out = self.x - self.y
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
+            self.shape, self.dtype)
+        self.grad_x = np.real(self.grad_out)
+        self.grad_y = -self.grad_out
+
+
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_kron_op.py b/python/paddle/fluid/tests/unittests/test_kron_op.py
index 634739596e9..d6db4c2f074 100644
--- a/python/paddle/fluid/tests/unittests/test_kron_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kron_op.py
@@ -186,6 +186,20 @@ class TestComplexKronOp(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestKronOpTypePromotion(TestComplexKronOp):
+    def init_input_output(self):
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(
+            self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
+        self.out = np.kron(self.x, self.y)
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
+            self.out_shape, self.dtype)
+        self.grad_x = self.get_grad_x_by_numpy().real
+        self.grad_y = self.get_grad_y_by_numpy()
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
index f944f84c6c1..761d318d7b8 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
@@ -525,6 +525,21 @@ class TestComplexMatMulOpBroadcast(OpTest):
             user_defined_grad_outputs=[self.grad_out])
 
 
+class TestMatMulTypePromotion(TestComplexMatMulOp):
+    def init_input_output(self):
+        self.x = np.random.random((10, 10)).astype(self.dtype)
+        self.y = np.random.random(
+            (10, 10)).astype(self.dtype) + 1J * np.random.random(
+                (10, 10)).astype(self.dtype)
+        self.out = np.dot(self.x, self.y)
+
+    def init_grad_input_output(self):
+        self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
+            (10, 10), self.dtype)
+        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
+        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
-- 
GitLab
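
The promotion rule that the new GetKernelTypeForVar overrides implement, and that each new test class encodes, can be reproduced outside Paddle with plain numpy. The sketch below is illustrative only (the shapes and variable names are arbitrary, not part of the patch); it mirrors the expected gradients in TestMatMulTypePromotion: the complex upstream gradient flows through the grad kernel at the promoted complex type, while a real-dtype input receives only the real part of its complex gradient.

    import numpy as np

    # real input x, complex input y, complex upstream gradient grad_out
    x = np.random.random((10, 10))
    y = np.random.random((10, 10)) + 1J * np.random.random((10, 10))
    grad_out = np.ones((10, 10)) + 1J * np.ones((10, 10))

    # out = x @ y; Wirtinger-style gradients, as in the unit tests
    grad_x = np.matmul(grad_out, np.conj(y).T).real  # real input keeps only the real part
    grad_y = np.matmul(np.conj(x).T, grad_out)       # complex input keeps the full gradient

    # sanity check: the real input's gradient stays real
    assert grad_x.dtype == np.float64
    assert grad_y.dtype == np.complex128

This is also why MatMulV2OpGrad gains a GetExpectedKernelType that keys off grad(Out) rather than the forward inputs: with a real X and complex Y, only the upstream gradient reliably carries the promoted (complex) type.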