未验证 提交 8fe43819 编写于 作者: J JYChen 提交者: GitHub

fix UT TestSetValueBFloat16 (#54497)

上级 5b97278e
...@@ -1657,22 +1657,35 @@ class TestSetValueIsSamePlace(unittest.TestCase): ...@@ -1657,22 +1657,35 @@ class TestSetValueIsSamePlace(unittest.TestCase):
class TestSetValueBFloat16(OpTest): class TestSetValueBFloat16(OpTest):
def setUp(self): def setUp(self):
self.dtype = np.uint16 self.dtype = np.uint16
self.shape = [2, 3, 4] self.shape = [22, 3, 4]
self.__class__.op_type = self.op_type self.op_type = 'set_value'
self.data = np.ones(self.shape).astype(self.dtype) self.data = np.ones(self.shape).astype(self.dtype)
x = np.random.rand([6]).astype('float32') value = np.random.rand(4).astype('float32')
self.data[0, 0] = np.random.rand([6]).astype('float32')
out = self.data[0, 0] expected_out = np.ones(self.shape).astype('float32')
self.inputs = {'X': convert_float_to_uint16(x)} expected_out[0, 0] = value
self.outputs = {'Out': convert_float_to_uint16(out)}
self.attrs = {
'axes': [0, 1],
'starts': [0, 0],
'ends': [1, 1],
'steps': [1, 1],
}
self.inputs = {
'Input': convert_float_to_uint16(self.data),
'ValueTensor': convert_float_to_uint16(value),
}
self.outputs = {'Out': convert_float_to_uint16(expected_out)}
def test_check_output(self): def test_check_output(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place) # NOTE(zoooo0820) Here we set check_dygraph=False since set_value OP has no corresponding python api
# to set self.python_api
self.check_output_with_place(place, check_dygraph=False)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out') self.check_grad_with_place(place, ['Input'], 'Out', check_dygraph=False)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册