未验证 提交 04087012 编写于 作者: Z zyfncg 提交者: GitHub

Scalar support marking data_type in yaml (#40867)

* Scalar support marking data_type in yaml

* fix code-gene bug
上级 0f5e90a2
...@@ -30,15 +30,19 @@ core_ops_args_info = {} ...@@ -30,15 +30,19 @@ core_ops_args_info = {}
core_ops_args_type_info = {} core_ops_args_type_info = {}
yaml_types_mapping = { yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \ 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \ 'str' : 'std::string', \
'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>', 'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor', 'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>', 'Tensor[]' : 'std::vector<Tensor>',
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>', 'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
'Scalar' : 'paddle::experimental::Scalar', 'Scalar' : 'paddle::experimental::Scalar',
'Scalar(int)' : 'paddle::experimental::Scalar',
'Scalar(int64_t)' : 'paddle::experimental::Scalar',
'Scalar(float)' : 'paddle::experimental::Scalar',
'Scalar(double)' : 'paddle::experimental::Scalar',
'ScalarArray' : 'paddle::experimental::ScalarArray' 'ScalarArray' : 'paddle::experimental::ScalarArray'
} }
...@@ -254,8 +258,8 @@ def ParseYamlForward(args_str, returns_str): ...@@ -254,8 +258,8 @@ def ParseYamlForward(args_str, returns_str):
fargs = r'(.*?)' fargs = r'(.*?)'
wspace = r'\s*' wspace = r'\s*'
args_pattern = f'\({fargs}\)' args_pattern = f'^\({fargs}\)$'
args_str = re.search(args_pattern, args_str).group(1) args_str = re.search(args_pattern, args_str.strip()).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str) inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturns(returns_str) returns_list = ParseYamlReturns(returns_str)
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
- api : concat - api : concat
args : (Tensor[] x, Scalar axis) args : (Tensor[] x, Scalar(int64_t) axis)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ConcatInferMeta func : ConcatInferMeta
...@@ -123,7 +123,7 @@ ...@@ -123,7 +123,7 @@
backward : matmul_grad backward : matmul_grad
- api : mean - api : mean
args : (Tensor x, int64[] axis={}, bool keep_dim=false) args : (Tensor x, int64_t[] axis={}, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ReduceInferMeta func : ReduceInferMeta
...@@ -198,7 +198,7 @@ ...@@ -198,7 +198,7 @@
func : sotfmax func : sotfmax
- api : split - api : split
args : (Tensor x, ScalarArray num_or_sections, Scalar axis) args : (Tensor x, ScalarArray num_or_sections, Scalar(int) axis)
output : Tensor[] output : Tensor[]
invoke : split_impl(x, num_or_sections, axis) invoke : split_impl(x, num_or_sections, axis)
...@@ -212,7 +212,7 @@ ...@@ -212,7 +212,7 @@
backward : subtract_grad backward : subtract_grad
- api : sum - api : sum
args : (Tensor x, int64[] axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false) args : (Tensor x, int64_t[] axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : SumInferMeta func : SumInferMeta
...@@ -227,7 +227,7 @@ ...@@ -227,7 +227,7 @@
- api : one_hot - api : one_hot
args : (Tensor x, Scalar num_classes) args : (Tensor x, Scalar(int) num_classes)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : OneHotInferMeta func : OneHotInferMeta
......
...@@ -89,10 +89,13 @@ class BaseAPI(object): ...@@ -89,10 +89,13 @@ class BaseAPI(object):
attr_types_map = { attr_types_map = {
'ScalarArray': 'const ScalarArray&', 'ScalarArray': 'const ScalarArray&',
'Scalar': 'const Scalar&', 'Scalar': 'const Scalar&',
'uint8': 'uint8_t', 'Scalar(int)': 'const Scalar&',
'Scalar(int64_t)': 'const Scalar&',
'Scalar(float)': 'const Scalar&',
'Scalar(dobule)': 'const Scalar&',
'int': 'int', 'int': 'int',
'int32': 'int32_t', 'int32_t': 'int32_t',
'int64': 'int64_t', 'int64_t': 'int64_t',
'long': 'long', 'long': 'long',
'size_t': 'size_t', 'size_t': 'size_t',
'float': 'float', 'float': 'float',
...@@ -102,27 +105,21 @@ class BaseAPI(object): ...@@ -102,27 +105,21 @@ class BaseAPI(object):
'Place': 'Place', 'Place': 'Place',
'DataLayout': 'DataLayout', 'DataLayout': 'DataLayout',
'DataType': 'DataType', 'DataType': 'DataType',
'int64[]': 'const std::vector<int64_t>&', 'int64_t[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&', 'int[]': 'const std::vector<int>&'
'long[]': 'const std::vector<int64_t>&'
} }
optional_types_trans = { optional_types_trans = {
'Tensor': 'const paddle::optional<Tensor>&', 'Tensor': 'const paddle::optional<Tensor>&',
'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&', 'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
'ScalarArray': 'const paddle::optional<ScalarArray>&',
'Scalar': 'const paddle::optional<Scalar>&',
'int': 'paddle::optional<int>', 'int': 'paddle::optional<int>',
'int32': 'paddle::optional<int32_t>', 'int32_t': 'paddle::optional<int32_t>',
'int64': 'paddle::optional<int64_t>', 'int64_t': 'paddle::optional<int64_t>',
'size_t': 'paddle::optional<size_t>',
'float': 'paddle::optional<float>', 'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>', 'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>', 'bool': 'paddle::optional<bool>',
'Place': 'paddle::optional<Place>', 'Place': 'paddle::optional<Place>',
'DataLayout': 'paddle::optional<DataLayout>', 'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>', 'DataType': 'paddle::optional<DataType>'
'int64[]': 'paddle::optional<std::vector<int64_t>>',
'int[]': 'paddle::optional<std::vector<int>>'
} }
args_declare_str = "" args_declare_str = ""
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
invoke : to_dense_impl(x) invoke : to_dense_impl(x)
- api : to_sparse_coo - api : to_sparse_coo
args : (Tensor x, int64 sparse_dim) args : (Tensor x, int64_t sparse_dim)
output : Tensor(out@SparseCooTensor) output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, sparse_dim) invoke : to_sparse_coo_impl(x, sparse_dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册