From 43876e8b8db95e1116395e50974712db414506a6 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Mon, 5 Jul 2021 19:04:05 +0800 Subject: [PATCH] make stop_gradient=True for random op in static graph (#33959) --- python/paddle/tensor/random.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 69a4634544..9ddf12ffb4 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -74,6 +74,7 @@ def bernoulli(x, name=None): dtype=x.dtype) # maybe set out to int32 ? helper.append_op( type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={}) + out.stop_gradient = True return out @@ -143,6 +144,7 @@ def multinomial(x, num_samples=1, replacement=False, name=None): outputs={'Out': out}, attrs={'num_samples': num_samples, 'replacement': replacement}) + out.stop_gradient = True return out @@ -514,6 +516,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): helper.append_op( type="uniform_random", inputs=inputs, attrs=attrs, outputs={"Out": out}) + out.stop_gradient = True return out @@ -615,6 +618,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): out = helper.create_variable_for_type_inference(dtype=dtype) helper.append_op( type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs) + out.stop_gradient = True return out -- GitLab