未验证 提交 0dafbb03 编写于 作者: Z zyfncg 提交者: GitHub

Remove auto to_pascal_case for args in op generator (#44350)

* remove auto to_pascal_case for args in op generator

* fix yaml config
上级 270f25e9
......@@ -24,7 +24,7 @@ repos:
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- id: sort-simple-yaml
files: (api|backward)\.yaml$
files: (api|backward|api_[a-z_]+)\.yaml$
- repo: local
hooks:
- id: clang-format
......
- api : atan2
inputs :
x : X1
y : X2
{x : X1, y : X2}
outputs :
out : Out
- api : bernoulli
inputs :
x : X
outputs :
out : Out
- api : cholesky
inputs :
x : X
outputs :
out : Out
- api : cholesky_solve
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : cross
inputs : {x : X, y : Y}
inputs :
{x : X, y : Y}
attrs :
axis : dim
outputs :
......@@ -26,17 +53,50 @@
outputs :
out : Out
- api : digamma
inputs :
x : X
outputs :
out : Out
- api : dist
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : dot
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : erf
inputs :
x : X
outputs :
out : Out
- api : mv
inputs :
{x : X, vec : Vec}
outputs :
out : Out
- api : poisson
inputs :
x : X
outputs :
out : Out
- api : trace
inputs :
x : Input
outputs :
out : Out
- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
- api : trunc
inputs :
x : X
outputs :
out : Out
......@@ -79,9 +79,9 @@ def to_sr_output_type(s):
# -------------- transform argument names from yaml to opmaker ------------
def to_opmaker_name(s):
if s.endswith("_grad"):
return 'GradVarName("{}")'.format(to_pascal_case(s[:-5]))
return 'GradVarName("{}")'.format(s[:-5])
else:
return '"{}"'.format(to_pascal_case(s))
return '"{}"'.format(s)
def to_opmaker_name_cstr(s):
......
......@@ -358,15 +358,15 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
input_orig_names, output_orig_names) %}{# inline #}
{% if name in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input("{{name_in_forward_orig | to_pascal_case}}")
Input("{{name_in_forward_orig}}")
{%- elif name in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name | to_pascal_case}}")
Output("{{name}}")
{%- elif name.endswith("_grad") %}{# output grad#}
{% set name_in_forward = name[:-5] %}
{% if name_in_forward in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
OutputGrad("{{name_in_forward_orig}}")
{%- endif %}
{%- endif %}
{%- endmacro %}
......@@ -376,11 +376,11 @@ OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
{% if name[:-5] in input_names %}
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name[:-5] | to_pascal_case}}")
InputGrad("{{name[:-5]}}")
{%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name | to_input_name | to_pascal_case}}")
InputGrad("{{name | to_input_name}}")
{%- endif %}
{%- endmacro %}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册