未验证 提交 1f7f8561 编写于 作者: Q QingshuChen 提交者: GitHub

update kunlun label_smooth unitest (#39611)

* update kunlun label_smooth unitest
*test=kunlun

* minor
*test=kunlun
上级 8f2d14ad
......@@ -20,45 +20,66 @@ import numpy as np
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class TestLabelSmoothOp(XPUOpTest):
def config(self):
self.op_type = "label_smooth"
self.epsilon = 0.1
self.use_xpu = True
batch_size, self.label_dim = 10, 12
self.label = np.zeros((batch_size, self.label_dim)).astype("float32")
nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
self.label[np.arange(batch_size), nonzero_index] = 1
class XPUTestLabelSmoothOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'label_smooth'
self.use_dynamic_create_class = True
def setUp(self):
self.config()
smoothed_label = (1 - self.epsilon
) * self.label + self.epsilon / self.label_dim
self.inputs = {'X': self.label}
self.attrs = {'epsilon': self.epsilon}
self.outputs = {'Out': smoothed_label}
def dynamic_create_class(self):
base_class = self.TestLabelSmoothOp
classes = []
batch_sizes = [1, 5, 1024]
label_dims = [1, 7, 12]
for bs in batch_sizes:
for label_dim in label_dims:
class_name = 'XPUTestLabelSmooth_' + \
str(bs) + "_" + str(label_dim)
attr_dict = {'batch_size': bs, 'label_dim': label_dim}
classes.append([class_name, attr_dict])
classes.append(['XPUTestLabelSmooth_3d', {'is_3d': True}])
return base_class, classes
def test_check_output(self):
if not paddle.is_compiled_with_xpu():
return
self.check_output_with_place(paddle.XPUPlace(0), atol=1e-6)
class TestLabelSmoothOp(XPUOpTest):
def setUp(self):
self.op_type = "label_smooth"
self.epsilon = 0.1
self.use_xpu = True
if not hasattr(self, 'batch_size'):
self.batch_size = 10
self.label_dim = 12
self.label = np.zeros(
(self.batch_size, self.label_dim)).astype("float32")
nonzero_index = np.random.randint(
self.label_dim, size=(self.batch_size))
self.label[np.arange(self.batch_size), nonzero_index] = 1
smoothed_label = (1 - self.epsilon
) * self.label + self.epsilon / self.label_dim
self.inputs = {'X': self.label}
self.attrs = {'epsilon': self.epsilon}
self.outputs = {'Out': smoothed_label}
if hasattr(self, 'is_3d') and self.is_3d:
self.inputs['X'] = self.inputs['X'].reshape(
[2, -1, self.inputs['X'].shape[-1]])
self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs[
'X'].shape)
def test_check_grad(self):
return
def test_check_output(self):
if not paddle.is_compiled_with_xpu():
return
self.check_output_with_place(paddle.XPUPlace(0), atol=1e-6)
def test_check_grad(self):
return
class TestLabelSmoothOp3D(TestLabelSmoothOp):
def setUp(self):
super(TestLabelSmoothOp3D, self).setUp()
self.inputs['X'] = self.inputs['X'].reshape(
[2, -1, self.inputs['X'].shape[-1]])
self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
.shape)
support_types = get_xpu_op_support_types('label_smooth')
for stype in support_types:
create_test_class(globals(), XPUTestLabelSmoothOp, stype)
if __name__ == '__main__':
unittest.main()
......@@ -213,8 +213,8 @@ fi
NO_NPU_FILE=`git diff --name-only upstream/$BRANCH | grep -v "_npu.py"`
HAS_UNITTEST_SKIP=`git diff -U0 upstream/$BRANCH ${NO_NPU_FILE} | grep "^+[[:space:]]\{0,\}@unittest.skip" || true`
if [ "${HAS_UNITTEST_SKIP}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
echo_line="Unittest is not allowed to be disabled.\nYou must have one RD (kolinwei(Recommend), wanghuancoder, luotao1 or qili93) approval for the usage of @unittest.skip or @unittest.skipIf.\n${HAS_UNITTEST_SKIP}\n"
check_approval 1 22165420 6836917 46661762 26922892 16605440
echo_line="Unittest is not allowed to be disabled.\nYou must have one RD (kolinwei(Recommend), wanghuancoder, luotao1, QingshuChen or qili93) approval for the usage of @unittest.skip or @unittest.skipIf.\n${HAS_UNITTEST_SKIP}\n"
check_approval 1 22165420 6836917 46661762 26922892 16605440 2002279
fi
HAS_MODIFIED_DEMO_CMAKE=`git diff --name-only upstream/$BRANCH | grep "paddle/fluid/inference/api/demo_ci/CMakeLists.txt" || true`
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册