Unverified commit d6a38532, authored by Xiaoxu Chen, committed by GitHub

[prim] support set output dtype for autogen (#52475)

Parent f6d4ae3d
@@ -46,5 +46,6 @@
- cos
- where
- split
- reshape
- erf
- tanh
@@ -7,6 +7,7 @@
{%- set attrs = api.attrs -%}
{%- set output_names = [] -%}
{%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
{%- set output_dtype = static_prim_api_output_dtype(api.inputs, api.attrs) -%}
{{static_prim_api_sig(phi_name, inputs, outputs, attrs)}} {
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
@@ -16,7 +17,7 @@
{{static_prim_api_input(input)}}
{% endfor %}
{% for output in outputs %}
{{static_prim_api_output(output)}}
{{static_prim_api_output(output, output_dtype)}}
{% endfor %}
{% for attr in attrs %}
{{static_prim_api_attr(attr)}}
@@ -102,33 +103,33 @@ op->SetInput("{{input.fluid_name | to_pascal}}", {std::static_pointer_cast<prim:
{%- endmacro -%}
{% macro static_prim_api_output(output) %}
{% macro static_prim_api_output(output, dtype) %}
{%- if output.optional -%}
{{static_prim_api_output_optional(output)}}
{{static_prim_api_output_optional(output, dtype)}}
{%- else -%}
{{static_prim_api_output_without_optional(output)}}
{{static_prim_api_output_without_optional(output, dtype)}}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_output_without_optional(output) -%}
{%- macro static_prim_api_output_without_optional(output, dtype) -%}
{%- if output.typename is tensor_sequence -%} {#- render the output of type std::vector<Tensor> -#}
std::vector<Tensor> {{output.name}};
std::vector<std::string> {{output.name}}_names;
for (size_t i=0; i<{{output.size}}; i++) {
auto tmp = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
auto tmp = empty<DescTensor>({}, {{dtype}}, paddle::Place());
{{output.name}}.push_back(tmp);
{{output.name}}_names.push_back(std::static_pointer_cast<prim::DescTensor>(tmp.impl())->Name());
}
op->SetOutput("{{output.fluid_name | to_pascal}}", {{output.name}}_names);
{%- else -%}
auto {{output.name}} = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
auto {{output.name}} = empty<DescTensor>({}, {{dtype}}, paddle::Place());
op->SetOutput("{{output.fluid_name | to_pascal}}", {std::static_pointer_cast<prim::DescTensor>({{output.name}}.impl())->Name()});
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_output_optional(output) -%}
{%- macro static_prim_api_output_optional(output, dtype) -%}
// TODO(cxxly): Render optional output
{%- endmacro -%}
@@ -173,3 +174,27 @@ paddle::framework::TransToProtoVarType({{src_name}})
{%- macro sequence(lsymbol, rsymbol, delimiter, items) -%}
{{lsymbol}}{%- for item in items -%}{{item}}{{delimiter if not loop.last else "" }}{%- endfor -%}{{rsymbol}}
{%- endmacro -%}
{%- macro static_prim_api_output_dtype(inputs, attrs) -%}
{%- set is_set = [] -%} {#- why not use boolean, ref: https://stackoverflow.com/questions/17925674/jinja2-local-global-variable -#}
{%- if not is_set -%} {#- use DataType attr as default output dtype -#}
{%- for attr in attrs -%}
{%- if attr.typename is datatype -%}
{{attr.name}}
{%- do is_set.append(1) -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if not is_set -%} {#- use first input named x dtype as default output dtype -#}
{%- for input in inputs -%}
{%- if input.typename == 'Tensor' and input.name == 'x' -%}
{{input.name}}.dtype()
{%- do is_set.append(1) -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if not is_set -%} {#- use fp32 as default output dtype -#}
phi::DataType::FLOAT32
{%- endif -%}
{%- endmacro -%}
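
To make the change concrete, here is a rough sketch (an illustration, not the generator's actual output) of the static prim API this template would now emit for a simple unary op such as tanh from the ops list above. tanh carries no DataType attribute, so static_prim_api_output_dtype falls through to the Tensor input x and the output is created with x.dtype(); an op that does declare a DataType attribute would forward that attribute instead, and only ops with neither fall back to the old phi::DataType::FLOAT32 default.

```cpp
// Sketch of generated code; the signature and op boilerplate are assumptions
// modeled on the template above and the existing DescTensor prim APIs.
template <>
Tensor tanh<DescTensor>(const Tensor& x) {
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
  op->SetType("tanh");
  op->SetInput("X",
               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
  // Before this patch the output dtype was hard-coded to FLOAT32; it now
  // follows the dtype chosen by static_prim_api_output_dtype (here x.dtype()).
  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
  op->CheckAttrs();
  op->InferVarType(block);
  op->InferShape(*block);
  return out;
}
```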
@@ -33,11 +33,5 @@ Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
return ::cast_ad_func(x, dtype);
}
template <>
Tensor reshape<Tensor>(const Tensor& x, const IntArray& shape) {
VLOG(4) << "Eager Prim API reshape_ad_func call";
return ::reshape_ad_func(x, shape);
}
} // namespace prim
} // namespace paddle
@@ -38,8 +38,5 @@ Tensor full(const IntArray& shape,
template <typename T>
Tensor cast(const Tensor& x, DataType dtype);
template <typename T>
Tensor reshape(const Tensor& x, const IntArray& shape);
} // namespace prim
} // namespace paddle
@@ -127,22 +127,5 @@ Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
return out;
}
template <>
Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("reshape2");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("shape", unsafe_vector_cast<int64_t, int>(shape.GetData()));
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
} // namespace prim
} // namespace paddle
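
The hand-written eager and static reshape specializations deleted in the three hunks above become redundant once reshape is added to the autogen list in the first hunk: the generated static version now creates its output with x.dtype() through the new output-dtype macro, which matches the behavior of the removed code. A hypothetical call site (the variable names here are illustrative only) behaves the same as before:

```cpp
// Hypothetical usage inside a composite rule: the output DescTensor follows
// the input's dtype, so a float64 x stays float64 through the reshape.
Tensor x_col = reshape<DescTensor>(x, IntArray(std::vector<int64_t>{-1, 1}));
```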