Unverified commit b02de1b6, authored by co63oc, committed by GitHub

[Hackathon No.61] Improve the FP16/BF16 unit tests for the uniform_random operator (#52949)

Parent e85fbac8
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/kernels/uniform_inplace_grad_kernel.h"

+#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
@@ -41,4 +42,6 @@ PD_REGISTER_KERNEL(uniform_inplace_grad,
                   ALL_LAYOUT,
                   phi::UniformInplaceGradKernel,
                   float,
-                  double) {}
+                  double,
+                  phi::dtype::float16,
+                  phi::dtype::bfloat16) {}
@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/random.h>

#include "gflags/gflags.h"
+#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
@@ -72,8 +73,12 @@ void UniformInplaceKernel(const Context& ctx,
    funcs::distribution_and_transform<T>(ctx, out, dist, trans);
  } else {
    // Use OP seed
-   auto func =
-       UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
+   auto func = UniformGenerator<T>(static_cast<T>(min),
+                                   static_cast<T>(max),
+                                   seed,
+                                   diag_num,
+                                   diag_step,
+                                   static_cast<T>(diag_val));
    IndexKernel<T, UniformGenerator<T>>(ctx, out, func);
  }
}
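The casts are the substance of this hunk: the op attributes min, max, and diag_val arrive as float, while UniformGenerator<T> stores them as T, so they are converted explicitly once T can be phi::dtype::float16 or bfloat16. One side effect worth knowing is that at half precision the cast rounds the bounds themselves to the nearest representable value. A NumPy stand-in for that endpoint rounding (illustration only, not part of this commit):

# NumPy stand-in: casting a float bound to half precision rounds it.
# phi::dtype::float16 behaves analogously on the C++ side.
import numpy as np

lo = np.float16(-0.3)            # stored endpoint after the cast
print(lo)                        # -0.2998, not exactly -0.3
print(abs(float(lo) - (-0.3)))   # rounding error introduced by the cast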
@@ -85,4 +90,6 @@ PD_REGISTER_KERNEL(uniform_inplace,
                   ALL_LAYOUT,
                   phi::UniformInplaceKernel,
                   float,
-                  double) {}
+                  double,
+                  phi::dtype::float16,
+                  phi::dtype::bfloat16) {}
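With float16 and bfloat16 added to both registrations, the in-place uniform op becomes callable on half-precision GPU tensors. A minimal usage sketch, assuming Tensor.uniform_ is the Python entry point that dispatches to this uniform_inplace kernel (that mapping is an assumption here, not something this diff shows):

# Hypothetical usage sketch; assumes Tensor.uniform_ maps to uniform_inplace.
import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    x = paddle.zeros([1000, 784], dtype='float16')
    x.uniform_(min=-1.0, max=1.0)    # fill in place with U(-1, 1) samples
    print(x.dtype, float(x.min()), float(x.max()))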
@@ -15,9 +15,19 @@
import unittest

import numpy as np
+from eager_op_test import OpTest, convert_uint16_to_float

import paddle
from paddle import fluid
from paddle.fluid import core


+def output_hist(out):
+    hist, _ = np.histogram(out, range=(-1, 1))
+    hist = hist.astype("float32")
+    hist /= float(out.size)
+    prob = 0.1 * np.ones(10)
+    return hist, prob
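output_hist is the uniformity check shared by the new tests: np.histogram splits (-1, 1) into 10 equal bins, so for a uniform sample each bin's empirical frequency should sit near prob = 0.1. A standalone demo against NumPy's own uniform sampler, using output_hist as defined above (atol is loosened slightly versus the FP16 test's 0.001 to keep the demo robust):

# Standalone demo of the histogram check with NumPy's uniform sampler.
import numpy as np

samples = np.random.uniform(-1, 1, size=(1000, 784))
hist, prob = output_hist(samples)
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.002)  # ~0.1 per bin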
class TestUniformRandomInplaceOpDtype(unittest.TestCase):
@@ -44,6 +54,72 @@ class TestUniformRandomInplaceOpDtype(unittest.TestCase):
        test_fp64()
+class TestUniformRandomInplaceFP16Op(OpTest):
+    def setUp(self):
+        self.op_type = "uniform_random_inplace"
+        self.dtype = np.float16
+        self.shape = (1000, 784)
+        x = np.random.random(self.shape).astype(self.dtype)
+        y = np.random.random(self.shape).astype(self.dtype)
+        self.inputs = {"X": x}
+        self.outputs = {"Out": y}
+        self.init_attrs()
+
+    def init_attrs(self):
+        self.output_hist = output_hist
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def verify_output(self, outs):
+        hist, prob = self.output_hist(np.array(outs[0]))
+        np.testing.assert_allclose(hist, prob, rtol=0, atol=0.001)
+
+    # TODO: dynamic-graph checking is turned off (check_dygraph=False) until
+    # self.python_api = paddle.uniform_random_inplace is set.
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_dygraph=False)
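The FP16 test can keep the tight atol=0.001 because float16 rounding perturbs each sample by at most about 2^-11 in relative terms, so almost no sample crosses a bin edge. A NumPy illustration of how little the cast moves the histogram (illustration only, not part of the test):

# How much does float16 rounding move the 10-bin histogram? Barely.
import numpy as np

samples = np.random.uniform(-1, 1, size=(1000, 784))
hist32, _ = np.histogram(samples, range=(-1, 1))
hist16, _ = np.histogram(samples.astype(np.float16), range=(-1, 1))
# Only samples within half a ULP of a bin edge can move bins, so this
# prints a value tiny compared with the test's atol of 0.001.
print(np.abs(hist16 - hist32).max() / samples.size)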
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestUniformRandomInplaceBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "uniform_random_inplace"
+        self.dtype = np.uint16
+        self.shape = (1000, 784)
+        x = np.random.random(self.shape).astype(self.dtype)
+        y = np.random.random(self.shape).astype(self.dtype)
+        self.inputs = {'X': x}
+        self.outputs = {'Out': y}
+        self.init_attrs()
+        self.place = core.CUDAPlace(0)
+
+    def init_attrs(self):
+        self.output_hist = output_hist
+
+    def test_check_output(self):
+        self.check_output_with_place_customized(self.verify_output, self.place)
+
+    def verify_output(self, outs):
+        result = convert_uint16_to_float(np.array(outs[0]))
+        hist, prob = self.output_hist(result)
+        np.testing.assert_allclose(hist, prob, rtol=0, atol=0.002)
+
+    # TODO: dynamic-graph checking is turned off (check_dygraph=False) until
+    # self.python_api = paddle.uniform_random_inplace is set.
+    def test_check_grad(self):
+        grads = [paddle.zeros(self.shape, dtype=self.dtype)]
+        self.check_grad_with_place(
+            self.place,
+            ['X'],
+            'Out',
+            check_dygraph=False,
+            user_defined_grads=grads,
+        )
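Two notes on the BF16 test. First, user_defined_grads is all zeros because the in-place uniform op overwrites Out with samples that do not depend on X, so the analytic input gradient vanishes (consistent with the grad kernel above, which, judging by its full_kernel.h include, fills the gradient with a constant zero). Second, convert_uint16_to_float is needed because bfloat16 outputs are stored as uint16 bit patterns. A minimal sketch of that conversion, assuming bfloat16 is the upper 16 bits of an IEEE-754 float32 (bf16_bits_to_float is a hypothetical helper, not Paddle API):

# Sketch: decode bfloat16-as-uint16 into float32 for inspection.
import numpy as np

def bf16_bits_to_float(u16):
    # Reattach 16 zero mantissa bits to rebuild the float32 bit pattern.
    return (u16.astype(np.uint32) << np.uint32(16)).view(np.float32)

bits = np.array([0x3F80, 0xBF80, 0x3F00], dtype=np.uint16)  # 1.0, -1.0, 0.5
print(bf16_bits_to_float(bits))  # [ 1.  -1.   0.5]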
class TestUniformRandomInplaceOpIsInplace(unittest.TestCase):
    def setUp(self):
        self.shape = (1000, 784)
......