未验证 提交 43876e8b 编写于 作者: L Leo Chen 提交者: GitHub

make stop_gradient=True for random op in static graph (#33959)

上级 3629bf4f
...@@ -74,6 +74,7 @@ def bernoulli(x, name=None): ...@@ -74,6 +74,7 @@ def bernoulli(x, name=None):
dtype=x.dtype) # maybe set out to int32 ? dtype=x.dtype) # maybe set out to int32 ?
helper.append_op( helper.append_op(
type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={}) type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={})
out.stop_gradient = True
return out return out
...@@ -143,6 +144,7 @@ def multinomial(x, num_samples=1, replacement=False, name=None): ...@@ -143,6 +144,7 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
outputs={'Out': out}, outputs={'Out': out},
attrs={'num_samples': num_samples, attrs={'num_samples': num_samples,
'replacement': replacement}) 'replacement': replacement})
out.stop_gradient = True
return out return out
...@@ -514,6 +516,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): ...@@ -514,6 +516,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
helper.append_op( helper.append_op(
type="uniform_random", inputs=inputs, attrs=attrs, type="uniform_random", inputs=inputs, attrs=attrs,
outputs={"Out": out}) outputs={"Out": out})
out.stop_gradient = True
return out return out
...@@ -615,6 +618,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -615,6 +618,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs) type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
out.stop_gradient = True
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册