未验证 提交 2778fcd9 编写于 作者: W Wilber 提交者: GitHub

fix shape api (#37412)

上级 0fa96e91
......@@ -11384,6 +11384,8 @@ def shape(input):
import paddle.fluid as fluid
import numpy as np
import paddle
paddle.enable_static()
inputs = fluid.data(name="x", shape=[3, 100, 100], dtype="float32")
output = fluid.layers.shape(inputs)
......@@ -11396,6 +11398,11 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)]
"""
if in_dygraph_mode():
out = _C_ops.shape(input)
out.stop_gradient = True
return out
check_variable_and_dtype(input, 'input', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'complex128'
......@@ -11403,7 +11410,10 @@ def shape(input):
helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
type='shape', inputs={'Input': input}, outputs={'Out': out})
type='shape',
inputs={'Input': input},
outputs={'Out': out},
stop_gradient=True)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册