Unverified commit 648563dd, authored by chenxujun, committed by GitHub

Add margin_cross_entropy, transfer_layout, dropout_nd tests (#52369)

Parent e4d20cdd
@@ -591,7 +591,8 @@ PD_REGISTER_KERNEL(margin_cross_entropy,
phi::MarginCrossEntropyKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(margin_cross_entropy_grad,
GPU,
@@ -599,4 +600,5 @@ PD_REGISTER_KERNEL(margin_cross_entropy_grad,
phi::MarginCrossEntropyGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
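These registrations extend the margin_cross_entropy forward and gradient GPU kernels to bfloat16. A minimal sketch (not part of this commit) of exercising the new dtype through the public `paddle.nn.functional.margin_cross_entropy` API; the bfloat16 cast and GPU availability are assumptions:

```python
# Sketch only: assumes a CUDA build and a GPU/Paddle version with bfloat16 support.
import paddle

paddle.set_device("gpu")
logits = paddle.randn([4, 10]).astype("bfloat16")        # cosine-style logits in bf16
labels = paddle.randint(0, 10, shape=[4], dtype="int64")
loss, softmax = paddle.nn.functional.margin_cross_entropy(
    logits, labels, return_softmax=True, reduction=None
)
```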
@@ -28,6 +28,11 @@ static __device__ __forceinline__ phi::dtype::float16 Exp(
return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ phi::dtype::bfloat16 Exp(
phi::dtype::bfloat16 x) {
return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float Exp(float x) { return expf(x); }
static __device__ __forceinline__ double Exp(double x) { return exp(x); }
@@ -37,6 +42,11 @@ static __device__ __forceinline__ phi::dtype::float16 Log(
return ::Eigen::numext::log(x);
}
static __device__ __forceinline__ phi::dtype::bfloat16 Log(
phi::dtype::bfloat16 x) {
return ::Eigen::numext::log(x);
}
static __device__ __forceinline__ float Log(float x) { return logf(x); }
static __device__ __forceinline__ double Log(double x) { return log(x); }
@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import _legacy_C_ops, fluid
@@ -57,7 +57,7 @@ def dropout_nd(
helper = LayerHelper('dropout_nd', **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout'
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'dropout'
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -116,6 +116,64 @@ class TestDropoutNdOp(OpTest):
self.check_grad(['X'], 'Out', check_dygraph=False)
class TestDropoutNdFP16Op(OpTest):
def setUp(self):
self.op_type = "dropout_nd"
self.dtype = np.float16
self.inputs = {'X': np.random.random((2, 16, 8)).astype("float16")}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'axis': [1],
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((1, 16, 1)).astype('uint8'),
}
def test_check_output(self):
self.check_output(check_dygraph=False)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_dygraph=False)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestDropoutNdBF16Op(OpTest):
def setUp(self):
self.op_type = "dropout_nd"
self.dtype = np.uint16
self.np_dtype = "float32"
self.inputs = {
'X': convert_float_to_uint16(
np.random.random((2, 16, 8)).astype(self.np_dtype)
)
}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'axis': [1],
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((1, 16, 1)).astype('uint8'),
}
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0), check_dygraph=False)
def test_check_grad_normal(self):
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], 'Out', check_dygraph=False
)
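The bfloat16 test feeds `uint16` bit patterns produced by `convert_float_to_uint16`. A rough numpy sketch of that float32-to-bfloat16 conversion (simple truncation of the low mantissa bits; the actual helper in eager_op_test may round differently):

```python
import numpy as np

def float32_to_bf16_bits(x):
    # Reinterpret float32 as uint32 and keep the high 16 bits, which is the
    # raw bfloat16 bit pattern that OpTest stores in a uint16 ndarray.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

x = np.random.random((2, 16, 8)).astype(np.float32)
x_bf16 = float32_to_bf16_bits(x)  # same shape, dtype uint16
```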
class TestDropoutNdAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, paddle_static_guard
from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
import paddle
from paddle.fluid import Program, core, program_guard
@@ -191,6 +191,91 @@ class TestMarginCrossEntropyOpFP16(TestMarginCrossEntropyOp):
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestMarginCrossEntropyBF16Op(OpTest):
def initParams(self):
self.python_api = python_api
self.op_type = "margin_cross_entropy"
self.python_out_sig = ["Loss"]
self.axis = -1
self.batch_dim = 5
self.feat_dim = 41
self.num_class = 37
def init_loss_params(self):
self.margin1 = 1.0
self.margin2 = 0.5
self.margin3 = 0.0
self.scale = 2.0
def init_dtype(self):
self.dtype = np.uint16
# For bfloat16, converts float32 to uint16
self.np_dtype = "float32"
def setUp(self):
self.initParams()
self.init_loss_params()
self.init_dtype()
datas = np.random.uniform(
-0.99, 0.99, [self.batch_dim, self.feat_dim]
).astype(self.np_dtype)
datas = datas / np.sqrt(np.sum(np.square(datas), axis=1, keepdims=True))
weights = np.random.uniform(
-0.99, 0.99, [self.feat_dim, self.num_class]
).astype(self.np_dtype)
weights = weights / np.sqrt(
np.sum(np.square(weights), axis=0, keepdims=True)
)
logits = np.matmul(datas, weights)
labels = np.random.randint(
0, self.num_class, (self.batch_dim,), dtype="int64"
)
loss, softmax = margin_cross_entropy(
logits,
labels,
self.axis,
self.margin1,
self.margin2,
self.margin3,
self.scale,
)
self.inputs = {
"Logits": convert_float_to_uint16(logits),
"Label": labels,
}
self.outputs = {
"Softmax": convert_float_to_uint16(softmax.astype(self.np_dtype)),
"Loss": convert_float_to_uint16(loss.astype(self.np_dtype)),
}
self.attrs = {
'margin1': self.margin1,
'margin2': self.margin2,
'margin3': self.margin3,
'scale': self.scale,
}
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0), atol=5e-2)
def test_check_grad(self):
self.check_grad_with_place(
core.CUDAPlace(0),
["Logits"],
"Loss",
numeric_grad_delta=6e-1,
max_relative_error=6e-1,
)
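For reference, the `margin_cross_entropy` helper called in setUp computes a combined-margin softmax cross entropy over cosine logits. A hedged numpy sketch of that computation (function and variable names here are mine, not from the test file):

```python
import numpy as np

def margin_softmax_loss(cos_logits, labels, m1, m2, m3, s):
    # cos_logits: [N, C] cosine similarities from L2-normalized features and weights
    logits = cos_logits.astype("float64").copy()
    rows = np.arange(logits.shape[0])
    theta = np.arccos(np.clip(logits[rows, labels], -1.0, 1.0))
    logits[rows, labels] = np.cos(m1 * theta + m2) - m3  # combined margin on the target class
    logits *= s                                          # logit scale
    logits -= logits.max(axis=1, keepdims=True)          # numerically stable softmax
    e = np.exp(logits)
    softmax = e / e.sum(axis=1, keepdims=True)
    loss = -np.log(softmax[rows, labels])[:, None]
    return loss, softmax
```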
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
@@ -90,6 +90,43 @@ class TestTransferLayoutOpGpu(unittest.TestCase):
assert ret[0].shape == (n, h, w, c)
class TestTransferLayoutFP16Op(OpTest):
def setUp(self):
self.op_type = 'transfer_layout'
self.dtype = np.float16
x = np.random.random(size=[2, 5, 10, 10])
self.inputs = {'X': x.astype(self.dtype)}
self.outputs = {'Out': x.transpose([0, 2, 3, 1])}
self.attrs = {'src_layout': 0, 'dst_layout': 1}
self.python_api = transpose_layout
def test_check_output(self):
self.check_output()
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestTransferLayoutBF16Op(OpTest):
def setUp(self):
self.op_type = 'transfer_layout'
self.dtype = np.uint16
x = np.random.random(size=[2, 5, 10, 10])
self.inputs = {'X': convert_float_to_uint16(x.astype('float32'))}
self.outputs = {
'Out': convert_float_to_uint16(
x.transpose([0, 2, 3, 1]), data_format="NHWC"
)
}
self.attrs = {'src_layout': 0, 'dst_layout': 1}
self.python_api = transpose_layout
def test_check_output(self):
self.check_output()
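As in the FP16 case, the expected output is just the NCHW-to-NHWC permutation of the input; the `src_layout`/`dst_layout` integers (0 and 1 here) appear to encode those two layouts, and the bfloat16 variant then converts the permuted array to uint16 bit patterns. A small numpy sanity sketch of that expectation (not part of the diff):

```python
import numpy as np

x = np.random.random(size=[2, 5, 10, 10]).astype(np.float32)  # NCHW
nhwc = x.transpose([0, 2, 3, 1])                               # NHWC
assert nhwc.shape == (2, 10, 10, 5)
```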
if __name__ == '__main__':
paddle.enable_static()
unittest.main()