未验证 提交 775fb43a 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Filter tensor type for int_array and scalar input in composite rule (#51208)

上级 60d04fa5
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
def fn(x, shape):
out = paddle.expand(x, shape=shape)
return out
class TestIntarrayInput(unittest.TestCase):
"""This case is set to test int_array input process during composite rule."""
def test_non_tensor_input(self):
core._set_prim_all_enabled(True)
np_data = np.random.random([3, 4]).astype("float32")
tensor_data = paddle.to_tensor(np_data)
net = paddle.jit.to_static(fn)
_ = net(tensor_data, shape=[2, 3, 4]).numpy()
core._set_prim_all_enabled(False)
def test_error_input(self):
"""In composite rules, tensor shape is not supported in int_array input"""
core._set_prim_all_enabled(True)
np_data = np.random.random([3, 4]).astype("float32")
tensor_data = paddle.to_tensor(np_data)
shape = paddle.to_tensor([2, 3, 4])
net = paddle.jit.to_static(fn)
with self.assertRaises(ValueError):
_ = net(tensor_data, shape).numpy()
core._set_prim_all_enabled(False)
if __name__ == '__main__':
unittest.main()
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
import functools import functools
import operator import operator
import paddle.framework.dtype as dtypes
from paddle.fluid import core from paddle.fluid import core
from .primitives import * # noqa: F403 from .primitives import * # noqa: F403
...@@ -361,7 +360,6 @@ def fill_any_like(x, fill_value, dtype, place=None): ...@@ -361,7 +360,6 @@ def fill_any_like(x, fill_value, dtype, place=None):
"""define composite rule of op full_like.""" """define composite rule of op full_like."""
"""op name: full_like op type name: fill_any_like.""" """op name: full_like op type name: fill_any_like."""
"""arg place is not used, add it here to keep same as python api.""" """arg place is not used, add it here to keep same as python api."""
dtype = dtypes.dtype(dtype)
val = full(x.shape, fill_value, dtype) val = full(x.shape, fill_value, dtype)
return val return val
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import typing import typing
import paddle import paddle
import paddle.framework.dtype as dtypes
from paddle.fluid import framework as framework from paddle.fluid import framework as framework
from .phi_ops_map import op_info, op_map from .phi_ops_map import op_info, op_map
...@@ -159,15 +160,52 @@ def _solve_arg(item): ...@@ -159,15 +160,52 @@ def _solve_arg(item):
return arg_type.strip(), arg_name.strip() return arg_type.strip(), arg_name.strip()
def _get_attr_value(op, arg_type, arg_name):
op_content = op_map[op.type]
if "attrs" in op_content.keys() and arg_name in op_content["attrs"].keys():
arg_name = op_content["attrs"][arg_name]
# Note: in some cases, attrs may be optional , thus assign None. Such case must be recorded.
if arg_name not in op.attr_names:
return None
else:
if arg_type == "DataType":
return dtypes.dtype(op.attr(arg_name))
return op.attr(arg_name)
def _get_args_values(op, phi_name): def _get_args_values(op, phi_name):
"get attrs' values for api args' values" "get attrs' values for api args' values"
args = op_info[phi_name] args = op_info[phi_name]
args_list = args["args"].split(",") args_list = args["args"].split(",")
inputs = [] inputs = []
attrs = [] attrs = []
for item in args_list: for item in args_list:
arg_type, arg_name = _solve_arg(item) arg_type, arg_name = _solve_arg(item)
op_content = op_map[op.type] op_content = op_map[op.type]
# IntArray and Scalar are special cases which may cause dynamic shape. In these case, tensor-relative types are removed in composite op.
if arg_type in ("IntArray", "Scalar"):
tensor_key = "int_array" if arg_type == "IntArray" else "scalar"
if op_content.get(tensor_key):
tensor_content = op_content[tensor_key].get(arg_name)
if not tensor_content:
raise ValueError(
f'No value found for {arg_name} of {arg_type} type for operator {op.type}.'
)
for item in ("tensor_name", "tensors_name"):
# name of intarray may differ from operator arg_name
arg_name_new = tensor_content.get(item)
if (
arg_name_new is not None
and arg_name_new in op.input_names
and get_var_block(op.block, op.input(arg_name_new))
):
raise ValueError(
f"Tensor type of {arg_type} is not supported in composite op. Please set other type value of input arg {arg_name_new} for operator {op.type}."
)
if arg_type in ("Tensor", "Tensor[]"): if arg_type in ("Tensor", "Tensor[]"):
# assume Tensor type must belong to inputs # assume Tensor type must belong to inputs
if ( if (
...@@ -178,19 +216,8 @@ def _get_args_values(op, phi_name): ...@@ -178,19 +216,8 @@ def _get_args_values(op, phi_name):
else: else:
inputs.append(arg_name) inputs.append(arg_name)
else: else:
op_content = op_map[op.type] attr_value = _get_attr_value(op, arg_type, arg_name)
if ( attrs.append(attr_value)
"attrs" in op_content.keys()
and arg_name in op_content["attrs"].keys()
):
arg_name = op_content["attrs"][arg_name]
# Note: in some cases, attrs may be optional , thus assign None. Such case must be recorded.
if arg_name not in op.attr_names:
attrs.append(None)
else:
attrs.append(op.attr(arg_name))
return inputs, attrs return inputs, attrs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册