未验证 提交 ac5cc136 编写于 作者: F From00 提交者: GitHub

Add yaml config for pool2d (#40563)

* Add yaml config for pool2d

* Fix CI error

* Fix code format error
上级 8ffcf596
......@@ -28,6 +28,7 @@ namespace = ""
yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \
'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor',
......@@ -212,7 +213,8 @@ def ParseYamlArgs(string):
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys()
assert arg_type in yaml_types_mapping.keys(
), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name)
......@@ -247,7 +249,8 @@ def ParseYamlReturns(string):
else:
ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys()
assert ret_type in yaml_types_mapping.keys(
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type
......
......@@ -24,7 +24,7 @@ atype_to_parsing_function = {
"long": "CastPyArg2Long",
"int64_t": "CastPyArg2Long",
"float": "CastPyArg2Float",
"string": "CastPyArg2String",
"std::string": "CastPyArg2String",
"std::vector<bool>": "CastPyArg2Booleans",
"std::vector<int>": "CastPyArg2Ints",
"std::vector<long>": "CastPyArg2Longs",
......
......@@ -35,6 +35,12 @@ final_state_name_mapping = {
"x": "X",
"out": "Out",
},
"pool2d": {
"final_op_name": "final_state_pool2d",
"x": "X",
"kernel_size": "ksize",
"out": "Out",
},
"abs": {
"final_op_name": "final_state_abs",
"x": "X",
......
......@@ -141,6 +141,14 @@
output : Tensor
invoke : full_like(x, 1, dtype, place)
- api : pool2d
args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
output : Tensor(out)
infer_meta :
func : PoolInferMeta
kernel:
func : pool2d
- api : reshape
args : (Tensor x, ScalarArray shape)
output : Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册