未验证 提交 21957476 编写于 作者: W Wilber 提交者: GitHub

shape api should not backward (#37340)

* shape api should not backward

* fix stop_gradient

* update

* update doc
上级 6c4621f1
...@@ -11384,6 +11384,8 @@ def shape(input): ...@@ -11384,6 +11384,8 @@ def shape(input):
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import paddle
paddle.enable_static()
inputs = fluid.data(name="x", shape=[3, 100, 100], dtype="float32") inputs = fluid.data(name="x", shape=[3, 100, 100], dtype="float32")
output = fluid.layers.shape(inputs) output = fluid.layers.shape(inputs)
...@@ -11396,6 +11398,11 @@ def shape(input): ...@@ -11396,6 +11398,11 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)] print(res) # [array([ 3, 100, 100], dtype=int32)]
""" """
if in_dygraph_mode():
out = _C_ops.shape(input)
out.stop_gradient = True
return out
check_variable_and_dtype(input, 'input', [ check_variable_and_dtype(input, 'input', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64', 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'complex128' 'complex128'
...@@ -11403,7 +11410,10 @@ def shape(input): ...@@ -11403,7 +11410,10 @@ def shape(input):
helper = LayerHelper('shape', **locals()) helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32') out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op( helper.append_op(
type='shape', inputs={'Input': input}, outputs={'Out': out}) type='shape',
inputs={'Input': input},
outputs={'Out': out},
stop_gradient=True)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册