未验证 提交 47d1d5af 编写于 作者: Z zyfncg 提交者: GitHub

[PHI] Support string type attr in yaml (#40218)

* support str attr in yaml

* fix bug
上级 061044a0
...@@ -25,10 +25,10 @@ core_ops_args_type_info = {} ...@@ -25,10 +25,10 @@ core_ops_args_type_info = {}
yaml_types_mapping = { yaml_types_mapping = {
'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ 'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>', 'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor', 'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>', 'Tensor[]' : 'std::vector<Tensor>',
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>', 'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
......
...@@ -121,7 +121,7 @@ ...@@ -121,7 +121,7 @@
backward : matmul_grad backward : matmul_grad
- api : mean - api : mean
args : (Tensor x, int64_t[] axis={}, bool keep_dim=false) args : (Tensor x, int64[] axis={}, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : MeanInferMeta func : MeanInferMeta
...@@ -181,7 +181,7 @@ ...@@ -181,7 +181,7 @@
func : subtract func : subtract
- api : sum - api : sum
args : (Tensor x, int64_t[] axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false) args : (Tensor x, int64[] axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : SumInferMeta func : SumInferMeta
......
...@@ -89,18 +89,20 @@ class BaseAPI(object): ...@@ -89,18 +89,20 @@ class BaseAPI(object):
attr_types_map = { attr_types_map = {
'ScalarArray': 'const ScalarArray&', 'ScalarArray': 'const ScalarArray&',
'Scalar': 'const Scalar&', 'Scalar': 'const Scalar&',
'uint8': 'uint8_t',
'int': 'int', 'int': 'int',
'int32_t': 'int32_t', 'int32': 'int32_t',
'int64_t': 'int64_t', 'int64': 'int64_t',
'long': 'long', 'long': 'long',
'size_t': 'size_t', 'size_t': 'size_t',
'float': 'float', 'float': 'float',
'double': 'double', 'double': 'double',
'bool': 'bool', 'bool': 'bool',
'str': 'const std::string&',
'Backend': 'Backend', 'Backend': 'Backend',
'DataLayout': 'DataLayout', 'DataLayout': 'DataLayout',
'DataType': 'DataType', 'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&', 'int64[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&', 'int[]': 'const std::vector<int>&',
'long[]': 'const std::vector<int64_t>&' 'long[]': 'const std::vector<int64_t>&'
} }
...@@ -110,8 +112,8 @@ class BaseAPI(object): ...@@ -110,8 +112,8 @@ class BaseAPI(object):
'ScalarArray': 'const paddle::optional<ScalarArray>&', 'ScalarArray': 'const paddle::optional<ScalarArray>&',
'Scalar': 'const paddle::optional<Scalar>&', 'Scalar': 'const paddle::optional<Scalar>&',
'int': 'paddle::optional<int>', 'int': 'paddle::optional<int>',
'int32_t': 'paddle::optional<int32_t>', 'int32': 'paddle::optional<int32_t>',
'int64_t': 'paddle::optional<int64_t>', 'int64': 'paddle::optional<int64_t>',
'size_t': 'paddle::optional<size_t>', 'size_t': 'paddle::optional<size_t>',
'float': 'paddle::optional<float>', 'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>', 'double': 'paddle::optional<double>',
...@@ -119,7 +121,7 @@ class BaseAPI(object): ...@@ -119,7 +121,7 @@ class BaseAPI(object):
'Backend': 'paddle::optional<Backend>', 'Backend': 'paddle::optional<Backend>',
'DataLayout': 'paddle::optional<DataLayout>', 'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>', 'DataType': 'paddle::optional<DataType>',
'int64_t[]': 'paddle::optional<std::vector<int64_t>>', 'int64[]': 'paddle::optional<std::vector<int64_t>>',
'int[]': 'paddle::optional<std::vector<int>>' 'int[]': 'paddle::optional<std::vector<int>>'
} }
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
invoke : to_dense_impl(x, backend) invoke : to_dense_impl(x, backend)
- sparse_api : to_sparse_coo - sparse_api : to_sparse_coo
args : (Tensor x, Backend backend, int64_t sparse_dim) args : (Tensor x, Backend backend, int64 sparse_dim)
output : Tensor(out@SparseCooTensor) output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, backend, sparse_dim) invoke : to_sparse_coo_impl(x, backend, sparse_dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册