未验证 提交 f2fe7c6e 编写于 作者: C Charles-hit 提交者: GitHub

fix static prim api code gen (#51445)

* fix static prim api code gen

* fix static prim api gen
上级 aa32aab2
...@@ -28,7 +28,7 @@ set(prim_api_h_path ...@@ -28,7 +28,7 @@ set(prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h" "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
) )
set(static_prim_api_template_path set(static_prim_api_template_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl" "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.j2"
) )
set(eager_prim_api_gen_file set(eager_prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/eager_gen.py) ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/eager_gen.py)
......
{% from "utils.tpl" import static_prim_api %} {% from "utils.cc.j2" import static_prim_api %}
// Generated by /paddle/fluid/prim/api/auto_code_generated/static_gen.py. // Generated by /paddle/fluid/prim/api/auto_code_generated/static_gen.py.
// DO NOT EDIT! // DO NOT EDIT!
......
...@@ -75,7 +75,7 @@ std::tuple<{{sequence('', '', ', ', names)}}> ...@@ -75,7 +75,7 @@ std::tuple<{{sequence('', '', ', ', names)}}>
{%- macro static_prim_api_input_optional(input) -%} {%- macro static_prim_api_input_optional(input) -%}
{%- if input.typename=='Tensor[]' -%} {#- render the input of type paddle::optional<std::Vector<Tensor>> -#} {%- if input.typename=='Tensor[]' -%} {#- render the input of type paddle::optional<std::Vector<Tensor>> -#}
if ({{input.name}}) { if ({{input.name}}) {
std::vector<std::string> {{input.name}}_names; std::vector<std::string> {{input.name}}_names({{input.name}}.get().size());
std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_names.begin(), [](const Tensor& t) { std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_names.begin(), [](const Tensor& t) {
return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name(); return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name();
}); });
...@@ -91,7 +91,7 @@ if ({{input.name}}) { ...@@ -91,7 +91,7 @@ if ({{input.name}}) {
{%- macro static_prim_api_input_without_optional(input) -%} {%- macro static_prim_api_input_without_optional(input) -%}
{%- if input.typename is tensor_sequence -%} {#- render the input of type std::Vector<Tensor> -#} {%- if input.typename is tensor_sequence -%} {#- render the input of type std::Vector<Tensor> -#}
std::vector<std::string> {{input.name}}_names; std::vector<std::string> {{input.name}}_names({{input.name}}.size());;
std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_names.begin(), [](const Tensor& t) { std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_names.begin(), [](const Tensor& t) {
return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name(); return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name();
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册