未验证 提交 1e5fec39 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Fix get var in prim when list of single tensor (#56114)

* fix get var in prim

* fix stack test case
上级 0c7fdda9
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import typing import typing
import paddle import paddle
...@@ -132,11 +133,13 @@ INT_DTYPE_2_STRING = { ...@@ -132,11 +133,13 @@ INT_DTYPE_2_STRING = {
} }
def get_var_block(block, names): def get_var_block(block, names, is_tensor_list=None):
assert isinstance(names, list) assert isinstance(names, list)
if len(names) == 0: if len(names) == 0:
return None return None
elif len(names) == 1: elif len(names) == 1:
if is_tensor_list:
return [block.var(names[0])]
return block.var(names[0]) return block.var(names[0])
else: else:
return [block.var(name) for name in names] return [block.var(name) for name in names]
...@@ -179,7 +182,7 @@ def _get_args_values(op, phi_name): ...@@ -179,7 +182,7 @@ def _get_args_values(op, phi_name):
"get attrs' values for api args' values" "get attrs' values for api args' values"
args = op_info[phi_name] args = op_info[phi_name]
args_list = args["args"].split(",") args_list = args["args"].split(",")
inputs = [] inputs = collections.OrderedDict()
attrs = [] attrs = []
for item in args_list: for item in args_list:
...@@ -212,9 +215,9 @@ def _get_args_values(op, phi_name): ...@@ -212,9 +215,9 @@ def _get_args_values(op, phi_name):
"inputs" in op_content.keys() "inputs" in op_content.keys()
and arg_name in op_content["inputs"].keys() and arg_name in op_content["inputs"].keys()
): ):
inputs.append(op_content["inputs"][arg_name]) inputs[op_content["inputs"][arg_name]] = arg_type
else: else:
inputs.append(arg_name) inputs[arg_name] = arg_type
else: else:
attr_value = _get_attr_value(op, arg_type, arg_name) attr_value = _get_attr_value(op, arg_type, arg_name)
attrs.append(attr_value) attrs.append(attr_value)
...@@ -237,8 +240,15 @@ def prepare_python_api_arguments(op): ...@@ -237,8 +240,15 @@ def prepare_python_api_arguments(op):
phi_name = op.type phi_name = op.type
inputs, attrs = _get_args_values(op, phi_name) inputs, attrs = _get_args_values(op, phi_name)
res = [] res = []
for item in inputs: for item, tensor_type in inputs.items():
if item in op.input_names: if item in op.input_names:
if tensor_type == "Tensor[]":
res.append(
get_var_block(
op.block, op.input(item), is_tensor_list=True
)
)
else:
res.append(get_var_block(op.block, op.input(item))) res.append(get_var_block(op.block, op.input(item)))
else: else:
# Note: in some cases, inputs may be optional, thus assign None. Such case must be recorded. # Note: in some cases, inputs may be optional, thus assign None. Such case must be recorded.
......
...@@ -375,5 +375,20 @@ class TestStackAPI_ZeroDim(unittest.TestCase): ...@@ -375,5 +375,20 @@ class TestStackAPI_ZeroDim(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
class TestStackListOfSingleTensor(unittest.TestCase):
def setUp(self):
paddle.disable_static()
paddle.seed(2022)
self.x = [paddle.randn((4, 2, 6), dtype="float32")]
def test_list_single_tensor(self):
expect = paddle.stack(self.x)
paddle.fluid.core._set_prim_all_enabled(True)
st_model = paddle.jit.to_static(paddle.stack)
actual = st_model(self.x)
np.testing.assert_allclose(expect, actual)
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册