未验证 提交 f2fe7c6e 编写于 作者: C Charles-hit 提交者: GitHub

fix static prim api code gen (#51445)

* fix static prim api code gen

* fix static prim api gen
上级 aa32aab2
......@@ -28,7 +28,7 @@ set(prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
)
set(static_prim_api_template_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl"
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.j2"
)
set(eager_prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/eager_gen.py)
......
{% from "utils.tpl" import static_prim_api %}
{% from "utils.cc.j2" import static_prim_api %}
// Generated by /paddle/fluid/prim/api/auto_code_generated/static_gen.py.
// DO NOT EDIT!
......
......@@ -75,7 +75,7 @@ std::tuple<{{sequence('', '', ', ', names)}}>
{%- macro static_prim_api_input_optional(input) -%}
{%- if input.typename=='Tensor[]' -%} {#- render the input of type paddle::optional<std::Vector<Tensor>> -#}
if ({{input.name}}) {
std::vector<std::string> {{input.name}}_names;
std::vector<std::string> {{input.name}}_names({{input.name}}.get().size());
std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_names.begin(), [](const Tensor& t) {
return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name();
});
......@@ -91,7 +91,7 @@ if ({{input.name}}) {
{%- macro static_prim_api_input_without_optional(input) -%}
{%- if input.typename is tensor_sequence -%} {#- render the input of type std::Vector<Tensor> -#}
std::vector<std::string> {{input.name}}_names;
std::vector<std::string> {{input.name}}_names({{input.name}}.size());;
std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_names.begin(), [](const Tensor& t) {
return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name();
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册