Unverified commit c7371b7b, authored by chentianyu03, committed via GitHub

type promotion for grad (#30177)

* type promotion for grad

* add type promotion for div op
Parent commit 6d14659f
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
@@ -203,7 +204,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx)) {
@@ -214,6 +215,19 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote inputs’s types when contains complex input
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    } else {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
  }
};

template <typename DeviceContext, typename T>
...
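The same GetKernelTypeForVar override is added to each of the grad operators touched in this commit. As an illustrative aside (a minimal Python analogue, not the Paddle C++ API), the per-variable rule it encodes is: when the kernel's expected data type is complex, each input keeps its own data type, so a real-valued input is apparently not transformed to complex before dispatch and the elementwise/matmul kernel handles the real-to-complex promotion itself; otherwise the previous behaviour (use the expected type) is kept.

```python
import numpy as np

def kernel_type_for_var(expected_dtype, tensor_dtype):
    """Illustrative analogue of the GetKernelTypeForVar override above."""
    if np.issubdtype(expected_dtype, np.complexfloating):
        # expected kernel type is complex: keep the variable's own dtype,
        # so no cast/transform is inserted for a real-valued input
        return tensor_dtype
    # otherwise fall back to the expected kernel dtype, as before
    return expected_dtype

# a float64 input paired with a complex128 kernel type stays float64 here
assert kernel_type_for_var(np.complex128, np.float64) == np.float64
assert kernel_type_for_var(np.float64, np.float64) == np.float64
```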
@@ -289,6 +289,19 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote inputs’s types when contains complex input
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    } else {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
@@ -326,6 +339,19 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote inputs’s types when contains complex input
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    } else {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpDoubleGradWithoutDXDY
...
@@ -134,6 +134,19 @@ class KronGradOp : public framework::OperatorWithKernel {
        OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
        ctx.GetPlace());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote inputs’s types when contains complex input
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    } else {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
  }
};

template <typename T>
...
@@ -150,6 +150,27 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
      context->SetOutputDim(y_grad_name, y_dims);
    }
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
        ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote inputs’s types when contains complex input
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    } else {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
  }
};

template <typename T>
...
@@ -320,6 +320,21 @@ class TestComplexElementwiseDivOp(OpTest):
            user_defined_grad_outputs=[self.grad_out])


class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp):
    def init_input_output(self):
        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
        self.y = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.out = self.x / self.y

    def init_grad_input_output(self):
        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
            (2, 3, 4, 5), self.dtype)
        self.grad_x = np.real(self.grad_out / np.conj(self.y))
        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
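As an aside (not part of the patch), the analytic gradients used in TestRealComplexElementwiseDivOp can be sanity-checked numerically. Assuming the VJP convention that the scalar being differentiated is Re(sum(conj(grad_out) * out)), a finite-difference check along the real input x and along the real/imaginary parts of the complex input y reproduces the formulas above:

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (2, 3)
x = rng.random(shape)                                   # real input
y = 0.5 + rng.random(shape) + 1j * (0.5 + rng.random(shape))  # complex, kept away from 0
grad_out = np.ones(shape) + 1j * np.ones(shape)

def loss(x_, y_):
    # assumed pairing of grad_out with out = x / y
    return np.real(np.sum(np.conj(grad_out) * (x_ / y_)))

# analytic gradients, as encoded in the test
grad_x = np.real(grad_out / np.conj(y))
grad_y = -grad_out * np.conj(x / y / y)

eps = 1e-6
num_grad_x = np.zeros(shape)
num_grad_y = np.zeros(shape, dtype=complex)
for idx in np.ndindex(shape):
    d = np.zeros(shape)
    d[idx] = eps
    # central difference for the real input x
    num_grad_x[idx] = (loss(x + d, y) - loss(x - d, y)) / (2 * eps)
    # Wirtinger-style gradient for the complex input y: real and imaginary parts
    d_re = (loss(x, y + d) - loss(x, y - d)) / (2 * eps)
    d_im = (loss(x, y + 1j * d) - loss(x, y - 1j * d)) / (2 * eps)
    num_grad_y[idx] = d_re + 1j * d_im

assert np.allclose(grad_x, num_grad_x, atol=1e-5)
assert np.allclose(grad_y, num_grad_y, atol=1e-5)
```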
@@ -304,6 +304,21 @@ class TestComplexElementwiseMulOp(OpTest):
            user_defined_grad_outputs=[self.grad_out])


class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp):
    def init_input_output(self):
        self.x = np.random.random((2, 3, 4, 5)).astype(self.dtype)
        self.y = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.out = self.x * self.y

    def init_grad_input_output(self):
        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
            (2, 3, 4, 5), self.dtype)
        self.grad_x = np.real(self.grad_out * np.conj(self.y))
        self.grad_y = self.grad_out * np.conj(self.x)


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
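Similarly (again only as an aside, not part of the patch), the real-x gradient in TestRealComplexElementwiseMulOp follows from the elementwise identity Re(conj(g) * y) == Re(g * conj(y)): under the same assumed convention, differentiating Re(sum(conj(grad_out) * x * y)) with respect to the real input x gives Re(conj(grad_out) * y), which equals the Re(grad_out * conj(y)) used above.

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.random((2, 3)) + 1j * rng.random((2, 3))   # stands in for grad_out
y = rng.random((2, 3)) + 1j * rng.random((2, 3))   # complex operand

# the two expressions agree elementwise, so d/dx Re(sum(conj(g) * x * y))
# can be written as Re(g * conj(y)) for a real-valued x
assert np.allclose(np.real(np.conj(g) * y), np.real(g * np.conj(y)))
```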
@@ -15,6 +15,7 @@
from __future__ import print_function

import unittest
import numpy as np
import paddle
from op_test import OpTest, skip_check_grad_ci
@@ -164,5 +165,78 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
        }


class TestComplexElementwiseSubOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_sub"
        self.dtype = np.float64
        self.shape = (2, 3, 4, 5)
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.y = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.out = self.x - self.y

    def init_grad_input_output(self):
        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
            self.shape, self.dtype)
        self.grad_x = self.grad_out
        self.grad_y = -self.grad_out

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set=set("X"),
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set=set('Y'),
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])


class TestRealComplexElementwiseSubOp(TestComplexElementwiseSubOp):
    def init_input_output(self):
        self.x = np.random.random(self.shape).astype(self.dtype)
        self.y = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.out = self.x - self.y

    def init_grad_input_output(self):
        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
            self.shape, self.dtype)
        self.grad_x = np.real(self.grad_out)
        self.grad_y = -self.grad_out


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
@@ -186,6 +186,20 @@ class TestComplexKronOp(OpTest):
            user_defined_grad_outputs=[self.grad_out])


class TestKronOpTypePromotion(TestComplexKronOp):
    def init_input_output(self):
        self.x = np.random.random(self.x_shape).astype(self.dtype)
        self.y = np.random.random(self.y_shape).astype(
            self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
        self.out = np.kron(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
            self.out_shape, self.dtype)
        self.grad_x = self.get_grad_x_by_numpy().real
        self.grad_y = self.get_grad_y_by_numpy()


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
@@ -525,6 +525,21 @@ class TestComplexMatMulOpBroadcast(OpTest):
            user_defined_grad_outputs=[self.grad_out])


class TestMatMulTypePromotion(TestComplexMatMulOp):
    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random(
            (10, 10)).astype(self.dtype) + 1J * np.random.random(
                (10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
            (10, 10), self.dtype)
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
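For TestMatMulTypePromotion the same pattern appears in matrix form: the full complex VJP with respect to x would be grad_out @ conj(y).T, and because x is real only its real part is kept, while the y gradient conj(x).T @ grad_out reduces to x.T @ grad_out. A small numpy sketch (illustrative only, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((10, 10))                                # real input
y = rng.random((10, 10)) + 1j * rng.random((10, 10))    # complex input
grad_out = np.ones((10, 10)) + 1j * np.ones((10, 10))

full_vjp_x = np.matmul(grad_out, np.conj(y).T)   # VJP as if x were complex
grad_x = full_vjp_x.real                         # projected back to x's real dtype
grad_y = np.matmul(np.conj(x).T, grad_out)

# with a real x, conj(x) == x, so grad_y is simply x.T @ grad_out,
# and the promoted gradients keep the expected dtypes per input
assert np.allclose(grad_y, x.T @ grad_out)
assert grad_x.dtype == np.float64 and grad_y.dtype == np.complex128
```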