Unverified · Commit b94fe95a authored by Infinity_lee, committed by GitHub

【Hackathon No58】fix atan2 (#51185)

Parent b647c2f0
@@ -22,4 +22,5 @@ PD_REGISTER_KERNEL(atan2_grad,
                    phi::Atan2GradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(atan2,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
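The two registrations above are what make bf16 atan2 reachable from Python: once phi::dtype::bfloat16 appears in both the forward and the grad kernel's dtype list, paddle.atan2 on bfloat16 GPU tensors should dispatch to them. A minimal usage sketch, not part of this commit, assuming a CUDA build where cast and sum also support bfloat16:

import paddle

# hypothetical smoke test for the newly registered dtype
x1 = paddle.uniform([15, 17], min=-1.0, max=-0.1).astype('bfloat16')
x2 = paddle.uniform([15, 17], min=0.1, max=1.0).astype('bfloat16')
x1.stop_gradient = False
x2.stop_gradient = False

out = paddle.atan2(x1, x2)  # forward pass: bf16 atan2 kernel
out.sum().backward()        # backward pass: bf16 atan2_grad kernel
print(out.dtype, x1.grad.dtype)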
@@ -15,6 +15,7 @@
 #pragma once
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/device_context.h"
+#include "paddle/phi/kernels/atan2_kernel.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 import paddle.fluid.core as core
@@ -129,6 +129,35 @@ class TestAtan2API(unittest.TestCase):
             run(place)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestAtan2BF16OP(OpTest):
+    def setUp(self):
+        self.op_type = 'atan2'
+        self.python_api = paddle.atan2
+        self.dtype = np.uint16
+        x1 = np.random.uniform(-1, -0.1, [15, 17]).astype('float32')
+        x2 = np.random.uniform(0.1, 1, [15, 17]).astype('float32')
+        out = np.arctan2(x1, x2)
+        self.inputs = {
+            'X1': convert_float_to_uint16(x1),
+            'X2': convert_float_to_uint16(x2),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X1', 'X2'], 'Out')
+
+
 class TestAtan2Error(unittest.TestCase):
     def test_mismatch(self):
         paddle.enable_static()
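NumPy has no native bfloat16, so the test stores bf16 data as np.uint16 and relies on convert_float_to_uint16 from eager_op_test to pack each float32 value into 16 bits. A rough standalone sketch of the same idea, using plain truncation (the real helper's rounding behaviour may differ):

import numpy as np

def float32_to_bf16_bits(x):
    # keep the high 16 bits of the IEEE-754 float32 pattern: sign, exponent, 7 mantissa bits
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> np.uint32(16)).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # widen back to float32 by zero-filling the low 16 bits
    return (bits.astype(np.uint32) << np.uint32(16)).view(np.float32)

a = np.array([1.0, -0.5, 3.14159], dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(a)))  # close to a, at bf16 precision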