未验证 提交 0900a790 编写于 作者: J JYChen 提交者: GitHub

fix setvalue dtype error when using dy2st and amp O2 (#56868)

* fix setvalue dtype error when using dy2st and amp O2

* add one test

* remove test share_buffer since win/linux have different number
上级 c96b9cbb
......@@ -642,6 +642,10 @@ def cast_model_to_fp16(
def need_process(op):
need_process = True
if op.type in ["set_value"]:
# NOTE(zoooo0820): OP set_value has attribute "dtype", but its output type is
# determined by the input.dtype instead of attribute. So, here we still process it.
return need_process
if op.type in ["create_py_reader", "read"]:
need_process = False
else:
......
......@@ -258,5 +258,51 @@ class TestFp16Guard(AmpTestBase):
paddle.disable_static()
class SimpleModelIncludeSetValue(nn.Layer):
def __init__(self):
super().__init__()
self.norm = nn.LayerNorm(3)
def forward(self, x):
x = x + 1
tmp = x * 1
y = self.norm(tmp)
x[:] = y
z = x * 1
return z
@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
"run test when gpu's compute capability is at least 7.0.",
)
class TestDy2STWithSetValue(AmpTestBase):
def test_op_called_as_expected(self):
expected_fp16_calls = {
"cast": 0,
"layer_norm": 1,
"scale": 3,
"set_value": 1,
}
func = SimpleModelIncludeSetValue()
func = paddle.amp.decorate(func, level='O2')
func = paddle.jit.to_static(func)
input = paddle.randn((2, 3))
with paddle.amp.auto_cast(level='O2'):
res = func(input)
loss = res.sum()
prog = func.forward.get_concrete_program(input)[1].forward_program
amp.debugging.collect_operator_stats(prog)
op_stats_list = amp.debugging._get_op_stats_list(prog)
loss.backward()
self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_fp16_calls
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册