未验证 提交 0dafbb03 编写于 作者: Z zyfncg 提交者: GitHub

Remove auto to_pascal_case for args in op generator (#44350)

* remove auto to_pascal_case for args in op generator

* fix yaml config
上级 270f25e9
...@@ -24,7 +24,7 @@ repos: ...@@ -24,7 +24,7 @@ repos:
files: (?!.*third_party)^.*$ | (?!.*book)^.*$ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer - id: end-of-file-fixer
- id: sort-simple-yaml - id: sort-simple-yaml
files: (api|backward)\.yaml$ files: (api|backward|api_[a-z_]+)\.yaml$
- repo: local - repo: local
hooks: hooks:
- id: clang-format - id: clang-format
......
- api : atan2 - api : atan2
inputs : inputs :
x : X1 {x : X1, y : X2}
y : X2
outputs : outputs :
out : Out out : Out
- api : bernoulli
inputs :
x : X
outputs :
out : Out
- api : cholesky
inputs :
x : X
outputs :
out : Out
- api : cholesky_solve
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : cross - api : cross
inputs : {x : X, y : Y} inputs :
{x : X, y : Y}
attrs : attrs :
axis : dim axis : dim
outputs : outputs :
...@@ -26,17 +53,50 @@ ...@@ -26,17 +53,50 @@
outputs : outputs :
out : Out out : Out
- api : digamma
inputs :
x : X
outputs :
out : Out
- api : dist
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : dot
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : erf
inputs :
x : X
outputs :
out : Out
- api : mv
inputs :
{x : X, vec : Vec}
outputs :
out : Out
- api : poisson
inputs :
x : X
outputs :
out : Out
- api : trace - api : trace
inputs : inputs :
x : Input x : Input
outputs : outputs :
out : Out out : Out
- api : conv2d - api : trunc
extra : inputs :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false, x : X
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false, outputs :
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false, out : Out
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
...@@ -79,9 +79,9 @@ def to_sr_output_type(s): ...@@ -79,9 +79,9 @@ def to_sr_output_type(s):
# -------------- transform argument names from yaml to opmaker ------------ # -------------- transform argument names from yaml to opmaker ------------
def to_opmaker_name(s): def to_opmaker_name(s):
if s.endswith("_grad"): if s.endswith("_grad"):
return 'GradVarName("{}")'.format(to_pascal_case(s[:-5])) return 'GradVarName("{}")'.format(s[:-5])
else: else:
return '"{}"'.format(to_pascal_case(s)) return '"{}"'.format(s)
def to_opmaker_name_cstr(s): def to_opmaker_name_cstr(s):
......
...@@ -358,15 +358,15 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -358,15 +358,15 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
input_orig_names, output_orig_names) %}{# inline #} input_orig_names, output_orig_names) %}{# inline #}
{% if name in input_names %} {% if name in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%} {% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input("{{name_in_forward_orig | to_pascal_case}}") Input("{{name_in_forward_orig}}")
{%- elif name in output_names %} {%- elif name in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name)]%} {% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name | to_pascal_case}}") Output("{{name}}")
{%- elif name.endswith("_grad") %}{# output grad#} {%- elif name.endswith("_grad") %}{# output grad#}
{% set name_in_forward = name[:-5] %} {% set name_in_forward = name[:-5] %}
{% if name_in_forward in output_names %} {% if name_in_forward in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %} {% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig | to_pascal_case}}") OutputGrad("{{name_in_forward_orig}}")
{%- endif %} {%- endif %}
{%- endif %} {%- endif %}
{%- endmacro %} {%- endmacro %}
...@@ -376,11 +376,11 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}") ...@@ -376,11 +376,11 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
{% if name[:-5] in input_names %} {% if name[:-5] in input_names %}
{% set name_in_forward = name[:-5] %} {% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%} {% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name[:-5] | to_pascal_case}}") InputGrad("{{name[:-5]}}")
{%- elif (name | to_input_name) in input_names %} {%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %} {% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%} {% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name | to_input_name | to_pascal_case}}") InputGrad("{{name | to_input_name}}")
{%- endif %} {%- endif %}
{%- endmacro %} {%- endmacro %}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册