Created by: zhouwei25
### PR types
Bug fixes

### PR changes
APIs
### Describe
Fix a dtype bug in `GradientClipByGlobalNorm` that surfaces when the default dtype is set to `float64`. The snippet below reproduces the issue before this fix:
```python
import numpy as np
import paddle

paddle.disable_static()

# With the default dtype set to float64, the layer's parameters
# (and hence their gradients) are created as float64.
paddle.set_default_dtype('float64')
linear = paddle.nn.Linear(10, 10)
print(linear.weight.dtype)  # paddle.float64
print(linear.bias.dtype)    # paddle.float64

# np.random.rand returns float64, matching the parameter dtype.
out = linear(paddle.to_tensor(np.random.rand(10, 10)))
out.backward()

# Clipping the float64 gradients by global norm triggers the dtype bug.
optimizer = paddle.optimizer.Adam(
    0.1,
    parameters=linear.parameters(),
    grad_clip=paddle.nn.GradientClipByGlobalNorm(1.0))
optimizer.step()
```
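For context, global-norm clipping rescales every gradient by `scale = clip_norm / max(global_norm, clip_norm)`, where `global_norm` is the square root of the sum of squared gradient norms. The sketch below is a minimal, dtype-consistent version of that computation; `clip_by_global_norm` is a hypothetical helper for illustration, not the implementation in this PR. The point it shows is that every intermediate, including the materialized `clip_norm`, must share the gradients' dtype so that float64 parameters work:

```python
import paddle

def clip_by_global_norm(grads, clip_norm):
    """Hypothetical sketch: clip a list of gradients by their global norm.

    Every intermediate tensor is created with the gradients' own dtype,
    which is the property a float64-safe implementation needs.
    """
    # global_norm = sqrt(sum over all gradients of ||g||^2)
    squared_sums = [paddle.sum(paddle.square(g)) for g in grads]
    global_norm = paddle.sqrt(paddle.add_n(squared_sums))
    # Materialize clip_norm in the gradients' dtype rather than as a
    # hard-coded float32 constant.
    max_norm = paddle.full([1], clip_norm, dtype=global_norm.dtype)
    # scale = clip_norm / max(global_norm, clip_norm): a no-op when the
    # global norm is already within the limit, a rescale otherwise.
    scale = max_norm / paddle.maximum(global_norm, max_norm)
    return [g * scale for g in grads]
```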