Unverified · Commit e3ee2ad8 authored by xiongkun, committed by GitHub

Sync stop_gradient in ParamBase; fix the different behavior between eval and train (#42899)

Parent fba94b9f
...
@@ -101,8 +101,11 @@ def monkey_patch_varbase():
         # Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() is only available in dygraph.
         # It will fail. So, for properties that differ between dynamic and static graph, do not use getattr(self, attr, None).
         attr_not_need_keys = ['grad', 'T', 'place', '_place_str']
+        param_keys = ['stop_gradient', 'trainable']
         if isinstance(self, (ParamBase, EagerParamBase)):
             attr_kwargs = self.__dict__.copy()
+            for key in param_keys:
+                attr_kwargs[key] = getattr(self, key)
         else:
             attr_names = []
             for name in dir(self):
...
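To see why copying __dict__ alone drops the flag, here is a minimal, self-contained sketch of the failure mode this commit fixes. The FakeNativeVar and FakeParam classes below are hypothetical stand-ins, not Paddle's API: on a real ParamBase, stop_gradient is backed by the C++ VarBase rather than the Python instance __dict__, so a copy built only from self.__dict__.copy() silently resets it, and a parameter frozen for eval becomes trainable again.

```python
class FakeNativeVar:
    """Stands in for the C++ VarBase: stop_gradient lives outside __dict__."""
    _native = {}  # per-object state keyed by id(), mimicking C++-side storage

    @property
    def stop_gradient(self):
        return FakeNativeVar._native.get(id(self), False)

    @stop_gradient.setter
    def stop_gradient(self, value):
        FakeNativeVar._native[id(self)] = value


class FakeParam(FakeNativeVar):
    def __init__(self):
        self.trainable = True  # ordinary attribute: it does live in __dict__


def clone_old(src):
    # Pre-fix behavior: only instance-dict entries survive the copy,
    # so the property-backed stop_gradient flag is silently dropped.
    dst = FakeParam()
    for key, value in src.__dict__.copy().items():
        setattr(dst, key, value)
    return dst


def clone_fixed(src):
    # Patched behavior: read the property-backed keys through getattr()
    # and merge them into the copied attributes, as the diff above does.
    attr_kwargs = src.__dict__.copy()
    for key in ['stop_gradient', 'trainable']:
        attr_kwargs[key] = getattr(src, key)
    dst = FakeParam()
    for key, value in attr_kwargs.items():
        setattr(dst, key, value)
    return dst


p = FakeParam()
p.stop_gradient = True                 # freeze the parameter, e.g. for eval
print(clone_old(p).stop_gradient)      # False -- the frozen flag was lost
print(clone_fixed(p).stop_gradient)    # True  -- the flag now survives
```

This is why the patch reads stop_gradient and trainable through getattr() instead of relying on __dict__: on ParamBase and EagerParamBase these names are exposed as properties forwarded to the underlying variable, not as plain instance attributes.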