Unverified commit 1980e33a, authored by Leo Chen, committed by GitHub

add check for backward hook (#40041)

* add check for backward hook

* refine ut
Parent 09258040
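
The commit adds a CheckVar helper on the imperative (dygraph) backward path: a user-registered gradient hook may transform the gradient it receives, but it must not change whether the gradient is empty, its dtype, or its place; a hook that does so now fails loudly during backward instead of propagating an inconsistent gradient. Below is a minimal sketch of the failure mode, assuming the standard dygraph APIs used by the new unit tests in this diff (paddle.Tensor.register_hook, Tensor.cpu); the names are illustrative and not part of the diff.

import paddle

# Sketch only: assumes v lives on a GPU so that .cpu() actually changes the place.
v = paddle.rand([3, 3])
v.stop_gradient = False

def move_to_cpu(grad):
    return grad.cpu()  # returns a gradient on a different place

v.register_hook(move_to_cpu)
loss = paddle.mm(v, v).sum()
loss.backward()  # with this commit, raises RuntimeError instead of continuing
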
@@ -317,6 +317,7 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
        auto tmp_var = var;
        for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
          tmp_var = (*hook_pair.second)(tmp_var);
+         CheckVar(var, tmp_var);
        }
        (*tmp_ins_ptr)[pair.first][i] = tmp_var;
      }
......
@@ -732,6 +732,7 @@ void GradientAccumulator::CallGradientHooks() {
            << var_->GetVariableWrapperHooks().size();
    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
      tmp_var = (*hook_pair.second)(tmp_var);
+     CheckVar(inner_var_, tmp_var);
    }
    inner_var_ = tmp_var;
  }
......
@@ -179,5 +179,29 @@ void SelectedRowsAddTensor(const VarType& src_selected_rows_var,
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst);

+inline void CheckVar(const std::shared_ptr<VariableWrapper>& pre,
+                     const std::shared_ptr<VariableWrapper>& post) {
+  if (pre->IsEmpty() && !post->IsEmpty()) {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "The tensor(%s) in before and after hook are not consistent",
+        pre->Name()));
+  }
+  if (!pre->IsEmpty() && !post->IsEmpty()) {
+    VLOG(4) << pre->DataType() << " " << post->DataType();
+    PADDLE_ENFORCE_EQ(
+        pre->DataType(), post->DataType(),
+        platform::errors::PermissionDenied(
+            "The dtype of tensor(%s) before(%s) and after(%s) hook are not "
+            "consistent",
+            pre->Name(), framework::DataTypeToString(pre->DataType()),
+            framework::DataTypeToString(post->DataType())));
+    PADDLE_ENFORCE_EQ(pre->Place(), post->Place(),
+                      platform::errors::PermissionDenied(
+                          "The place of tensor(%s) before(%s) and after(%s) "
+                          "hook are not consistent",
+                          pre->Name(), pre->Place(), post->Place()));
+  }
+}
+
} // namespace imperative
} // namespace paddle
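
CheckVar only rejects hooks that change the gradient's emptiness, dtype, or place; hooks that merely transform the value remain valid. A small illustrative sketch follows, assuming the same dygraph APIs as above (register_hook plus ordinary paddle ops); it is not part of the diff.

import paddle

v = paddle.rand([3, 3])
v.stop_gradient = False

def scale_grad(grad):
    return grad * 2  # same dtype and place, so the check passes

v.register_hook(scale_grad)
loss = (v * v).sum()
loss.backward()
print(v.grad)  # gradient of sum(v*v) is 2*v, doubled to 4*v by the hook
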
@@ -1156,7 +1156,7 @@ class TestBf16(unittest.TestCase):
                out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))


-class TestPyLayerWithAmp(unittest.TestCase):
+class TestAmpWithPyLyer(unittest.TestCase):
    def test_pylayer(self):
        class MyMM(PyLayer):
            @staticmethod
@@ -1168,7 +1168,7 @@ class TestPyLayerWithAmp(unittest.TestCase):
            def backward(ctx, grad):
                a, b = ctx.saved_tensor()
                # NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
-                # thus, the mm operation raise errors because of the dtype of inputs are inconsistent
+                # thus, the mm operation raise errors because of the dtype of inputs are inconsistent before.
                return grad.mm(b.t()), a.t().mm(grad)

        x = paddle.rand([10, 10])
@@ -1182,5 +1182,39 @@ class TestPyLayerWithAmp(unittest.TestCase):
        loss.backward()


+class TestAmpWithHook(unittest.TestCase):
+    def test_hook_change_dtype(self):
+        with paddle.fluid.dygraph.guard():
+            v = paddle.rand([3, 3])
+            v.stop_gradient = False
+
+            def foo(grad):
+                print('grad', grad, grad.dtype)  # grad's dtype is float32
+                res = paddle.mm(grad, grad)  # mm runs in fp16
+                print('res', res, res.dtype)  # res's dtype is float16
+                return res
+
+            v.register_hook(foo)
+            with paddle.amp.auto_cast():
+                a = paddle.mm(v, v)
+            loss = a.sum()
+            self.assertRaises(RuntimeError, loss.backward)
+
+    def test_hook_change_place(self):
+        with paddle.fluid.dygraph.guard():
+            v = paddle.rand([3, 3])
+            v.stop_gradient = False
+
+            def foo(grad):
+                res = grad.cpu()  # change place
+                return res
+
+            v.register_hook(foo)
+            with paddle.amp.auto_cast():
+                a = paddle.mm(v, v)
+            loss = a.sum()
+            self.assertRaises(RuntimeError, loss.backward)
+
+
if __name__ == '__main__':
    unittest.main()