Unverified commit 82cf1fad, authored by Xiaoxu Chen, committed by GitHub

[prim] generate static prim api (#50315)

Parent 14e45f6b
@@ -14,7 +14,7 @@
 import itertools
 import re
-from typing import Dict, List
+from typing import Dict, List, Sequence

 from type_mapping import (
     attr_types_map,
@@ -80,6 +80,10 @@ def to_sr_output_type(s):
     return sr_output_types_map[s]

+
+def filter_intermediate(items: Sequence):
+    return tuple([item for item in items if not item.get('intermediate')])
+

 # -------------- transform argument names from yaml to opmaker ------------
 def to_opmaker_name(s):
     if s.endswith("_grad"):
...
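A quick sketch of what the new filter_intermediate helper does, using hypothetical parsed-op output entries (the dict keys mirror the parsed yaml format): any output flagged as intermediate is dropped before rendering.

outs = [{'name': 'out'}, {'name': 'xshape', 'intermediate': True}]
print(filter_intermediate(outs))  # -> ({'name': 'out'},)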
@@ -38,6 +38,14 @@ def is_scalar(s):
     return re.match(r"Scalar(\(\w+\))*", s) is not None

+
+def is_intarray(s):
+    return s == 'IntArray'
+
+
+def is_datatype(s):
+    return s == 'DataType'
+

 def is_initializer_list(s):
     return s == "{}"
@@ -63,3 +71,7 @@ def supports_no_need_buffer(op):
         if input["no_need_buffer"]:
             return True
     return False
+
+
+def is_tensor_list(s):
+    return s == 'Tensor[]'
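These predicates are meant to be registered as Jinja2 "tests" so templates can branch on a typename (see the env.tests block in static_gen.py further down). A minimal sketch of that wiring, with is_intarray and is_tensor_list as defined above:

import jinja2

env = jinja2.Environment()
env.tests.update({'intarray': is_intarray, 'tensor_sequence': is_tensor_list})
tpl = env.from_string(
    "{% if t is intarray %}IntArray{% elif t is tensor_sequence %}Tensor[]{% endif %}"
)
print(tpl.render(t='IntArray'))  # IntArray
print(tpl.render(t='Tensor[]'))  # Tensor[]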
@@ -3,11 +3,16 @@ add_subdirectory(manual_prim)
 add_subdirectory(generated_prim)

 if(NOT (NOT WITH_PYTHON AND ON_INFER))
+  cc_library(eager_prim_api DEPS generated_eager_prim_api manual_eager_prim_api)
+  cc_library(static_prim_api DEPS generated_static_prim_api
+             manual_static_prim_api)
   cc_library(
     prim_api
     SRCS all.cc
     DEPS static_utils static_prim_api eager_prim_api eager_api)
 else()
+  cc_library(static_prim_api DEPS generated_static_prim_api
+             manual_static_prim_api)
   cc_library(
     prim_api
     SRCS all.cc
...
- unsqueeze
- pow
- exp
- scale
- multiply
- matmul
- expand
- divide
- sum
- add
- abs
- assign
- concat
- elementwise_pow
- floor
- gather_nd
- log
- max
- maximum
- minimum
- prod
- roll
- scatter
- scatter_nd_add
- tile
@@ -4,36 +4,73 @@ set(api_yaml_path
 set(legacy_api_yaml_path
     "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
 )
+set(api_compat_yaml_path
+    "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml")
+set(api_prim_yaml_path "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/api.yaml")
+set(api_version_yaml_path
+    "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml")
 set(tmp_eager_prim_api_cc_path
-    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc"
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc.tmp"
+)
+set(tmp_static_prim_api_cc_path
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/static_prim_api.cc.tmp"
 )
 set(tmp_prim_api_h_path
-    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_prim_generated_api.h"
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h.tmp"
 )
 set(eager_prim_api_cc_path
     "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc"
 )
+set(static_prim_api_cc_path
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/static_prim_api.cc"
+)
 set(prim_api_h_path
     "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
 )
-set(prim_api_gen_file
-    ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py)
+set(static_prim_api_template_path
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/template/static_prim_api.cc.tpl"
+)
+set(eager_prim_api_gen_file
+    ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/eager_gen.py)
+set(static_prim_api_gen_file
+    ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/static_gen.py
+)

-message("prim api Code gen")
+message("Eager prim api code generator")
 execute_process(
   WORKING_DIRECTORY
     ${CMAKE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated
   COMMAND
-    ${PYTHON_EXECUTABLE} ${prim_api_gen_file} --api_yaml_path
+    ${PYTHON_EXECUTABLE} ${eager_prim_api_gen_file} --api_yaml_path
     ${legacy_api_yaml_path} ${api_yaml_path} --prim_api_header_path
     ${tmp_prim_api_h_path} --eager_prim_api_source_path
-    ${tmp_eager_prim_api_cc_path}
+    ${tmp_eager_prim_api_cc_path} --api_prim_yaml_path ${api_prim_yaml_path}
   RESULT_VARIABLE _result)
 if(${_result})
-  message(FATAL_ERROR "prim api genrate failed, exiting.")
+  message(FATAL_ERROR "Eager prim api generate failed, exiting.")
 endif()
 execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
                         ${tmp_prim_api_h_path} ${prim_api_h_path})
 execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
                         ${tmp_eager_prim_api_cc_path} ${eager_prim_api_cc_path})
 message("copy tmp_xxx_prim_api to xxx_prim_api")
+
+message("Static prim api code generator")
+execute_process(
+  WORKING_DIRECTORY
+    ${CMAKE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated
+  COMMAND
+    ${PYTHON_EXECUTABLE} ${static_prim_api_gen_file} --api_phi_yaml_path
+    ${api_yaml_path} --api_phi_legacy_yaml_path ${legacy_api_yaml_path}
+    --api_compat_yaml_path ${api_compat_yaml_path} --api_version_yaml_path
+    ${api_version_yaml_path} --api_prim_yaml_path ${api_prim_yaml_path}
+    --template_path ${static_prim_api_template_path} --output_path
+    ${tmp_static_prim_api_cc_path}
+  RESULT_VARIABLE _result)
+if(${_result})
+  message(FATAL_ERROR "Static prim api generate failed, exiting.")
+endif()
+execute_process(
+  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_static_prim_api_cc_path}
+          ${static_prim_api_cc_path})
+message("copy tmp_xxx_prim_api to xxx_prim_api")
@@ -55,7 +55,9 @@ using DataType = paddle::experimental::DataType;
 )


-def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
+def generate_api(
+    api_yaml_path, header_file_path, eager_prim_source_file_path, api_prim_path
+):
     apis = []
     for each_api_yaml in api_yaml_path:
@@ -76,8 +78,11 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
         eager_prim_source_file.write(eager_source_include())
         eager_prim_source_file.write(namespace[0])

+        with open(api_prim_path, 'rt') as f:
+            api_prims = yaml.safe_load(f)
+
         for api in apis:
-            prim_api = EagerPrimAPI(api)
+            prim_api = EagerPrimAPI(api, api_prims)
             if prim_api.is_prim_api:
                 header_file.write(prim_api.gene_prim_api_declaration())
                 eager_prim_source_file.write(prim_api.gene_eager_prim_api_code())
@@ -112,16 +117,24 @@ def main():
         default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc',
     )

+    parser.add_argument(
+        '--api_prim_yaml_path',
+        help='Primitive API list yaml file.',
+        default='paddle/fluid/prim/api/auto_code_generated/api.yaml',
+    )
+
     options = parser.parse_args()

     api_yaml_path = options.api_yaml_path
     prim_api_header_file_path = options.prim_api_header_path
     eager_prim_api_source_file_path = options.eager_prim_api_source_path
+    api_prim_yaml_path = options.api_prim_yaml_path

     generate_api(
         api_yaml_path,
         prim_api_header_file_path,
         eager_prim_api_source_file_path,
+        api_prim_yaml_path,
     )
...
@@ -39,12 +39,12 @@ inplace_optional_out_type_map = {

 class BaseAPI:
-    def __init__(self, api_item_yaml):
+    def __init__(self, api_item_yaml, prims=tuple()):
         # self.api = api_item_yaml['op']
         self.api = api_item_yaml['name']

         self.is_prim_api = False
-        if api_item_yaml['name'] in white_ops_list:
+        if api_item_yaml['name'] in prims:
             self.is_prim_api = True

         #######################################
@@ -253,8 +253,8 @@ class BaseAPI:

 class EagerPrimAPI(BaseAPI):
-    def __init__(self, api_item_yaml):
-        super().__init__(api_item_yaml)
+    def __init__(self, api_item_yaml, prims=tuple()):
+        super().__init__(api_item_yaml, prims)

     def get_api__func_name(self):
         api_func_name = self.api
...
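The effect of threading prims through BaseAPI is that the allow-list now comes from api.yaml instead of the hard-coded white_ops_list. A minimal sketch of that gating (FakeAPI is a stand-in for illustration, not the real class):

import yaml

api_prims = yaml.safe_load('- reshape\n- full\n')  # stands in for api.yaml

class FakeAPI:
    def __init__(self, api_item_yaml, prims=tuple()):
        self.api = api_item_yaml['name']
        # Prim-ness is now a simple membership test against the yaml list.
        self.is_prim_api = api_item_yaml['name'] in prims

print(FakeAPI({'name': 'reshape'}, api_prims).is_prim_api)  # True
print(FakeAPI({'name': 'conv2d'}, api_prims).is_prim_api)   # False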
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import pathlib
import sys

import jinja2
import yaml

# fmt: off
# import from paddle/fluid/operators/generator
sys.path.append(
    str(pathlib.Path(__file__).parents[3].joinpath('operators/generator'))
)
import filters as op_gen_filters
import generate_op as op_gen_utils
import parse_utils as op_gen_parse_utils
import tests as op_gen_tests

# fmt: on


def load_yaml(path, mode="rt"):
    with open(path, mode) as f:
        return yaml.safe_load(f)


def render(tpl, *args, **kwargs):
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(pathlib.Path(tpl).parent),
        keep_trailing_newline=True,
        trim_blocks=True,
        lstrip_blocks=True,
        undefined=jinja2.StrictUndefined,
        extensions=['jinja2.ext.do'],
    )
    env.filters.update(
        {
            'to_paddle_attr_type': op_gen_filters.to_paddle_attr_type,
            'to_paddle_input_type': op_gen_filters.to_paddle_input_type,
            'to_paddle_output_type': op_gen_filters.to_paddle_output_type,
            'to_pascal': op_gen_filters.to_pascal_case,
            "trip_intermediate": op_gen_filters.filter_intermediate,
        }
    )
    env.tests.update(
        {
            'scalar': op_gen_tests.is_scalar,
            'intarray': op_gen_tests.is_intarray,
            'datatype': op_gen_tests.is_datatype,
            'tensor_sequence': op_gen_tests.is_tensor_list,
        }
    )
    return env.get_template(pathlib.Path(tpl).name).render(*args, **kwargs)
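Assuming the operators/generator imports above resolved, render() can be exercised with a throwaway template file; the template text here is hypothetical, not the real static_prim_api.cc.tpl:

import pathlib
import tempfile

with tempfile.TemporaryDirectory() as d:
    tpl = pathlib.Path(d) / 'demo.tpl'
    tpl.write_text('{% for api in apis %}{{ api.name }}\n{% endfor %}')
    print(render(str(tpl), apis=[{'name': 'reshape'}, {'name': 'full'}]))
    # reshape
    # full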
def filter_prim(apis, prims):
    return [api for api in apis if api.get('name') in prims]


def extend_compat(apis, compats):
    dicts = op_gen_parse_utils.to_named_dict(copy.deepcopy(apis))
    for api in dicts.values():
        op_gen_utils.restruct_io(api)
        api['op_name'] = api['name']
        op_gen_utils.add_fluid_name(api['inputs'])
        op_gen_utils.add_fluid_name(api['attrs'])
        op_gen_utils.add_fluid_name(api['outputs'])
        api['backward'] = None
    op_gen_utils.add_compat_name(compats, dicts, {})
    return tuple(dicts.values())


def extend_version(apis, versions):
    apis = copy.deepcopy(apis)
    for api in apis:
        for version in versions:
            if version.get('op') == api.get('name'):
                api['version'] = version['version']
    return apis
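The selection helpers are small enough to trace by hand. With hypothetical yaml contents, filter_prim keeps only the ops named in the prim list and extend_version attaches matching entries from op_version.yaml:

apis = [{'name': 'reshape'}, {'name': 'conv2d'}]
prims = ['reshape']
versions = [{'op': 'reshape', 'version': [{'checkpoint': 'Upgrade reshape'}]}]

selected = filter_prim(apis, prims)            # [{'name': 'reshape'}]
selected = extend_version(selected, versions)
print(selected[0]['version'])                  # [{'checkpoint': 'Upgrade reshape'}]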
def generate(
    api_prim_yaml_path,
    api_phi_yaml_path,
    api_phi_legacy_yaml_path,
    api_compat_yaml_path,
    api_version_yaml_path,
    template_path,
    output_op_path,
):
    prims, phis, legacy_phis, compats, versions = (
        load_yaml(api_prim_yaml_path),
        load_yaml(api_phi_yaml_path),
        load_yaml(api_phi_legacy_yaml_path),
        load_yaml(api_compat_yaml_path),
        load_yaml(api_version_yaml_path),
    )

    apis = phis + legacy_phis
    apis = filter_prim(apis, prims)
    apis = extend_version(apis, versions)
    apis = extend_compat(apis, compats)

    if len(apis) > 0:
        with open(output_op_path, "wt") as f:
            msg = render(template_path, apis=apis)
            f.write(msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate Static Primitive API"
    )
    parser.add_argument(
        '--api_prim_yaml_path', type=str, help="Primitive API yaml file."
    )
    parser.add_argument(
        '--api_phi_yaml_path', type=str, help="Parsed ops yaml file."
    )
    parser.add_argument(
        '--api_phi_legacy_yaml_path', type=str, help="Parsed legacy ops yaml file."
    )
    parser.add_argument(
        '--api_compat_yaml_path', type=str, help="Ops args compat yaml file."
    )
    parser.add_argument(
        '--api_version_yaml_path', type=str, help="Ops version yaml file."
    )
    parser.add_argument(
        "--template_path", type=str, help="Jinja2 template file path."
    )
    parser.add_argument("--output_path", type=str, help="Output path.")

    args = parser.parse_args()
    generate(
        args.api_prim_yaml_path,
        args.api_phi_yaml_path,
        args.api_phi_legacy_yaml_path,
        args.api_compat_yaml_path,
        args.api_version_yaml_path,
        args.template_path,
        args.output_path,
    )
{% from "utils.tpl" import static_prim_api %}
// Generated by /paddle/fluid/prim/api/auto_code_generated/static_gen.py.
// DO NOT EDIT!
#include <string.h>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include <algorithm>
#include <tuple>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace prim {
{% for api in apis %}
{{static_prim_api(api)}}
{% endfor %}
} // namespace prim
} // namespace paddle
{% macro static_prim_api(api) %}
{%- set fluid_name = api.op_name -%}
{%- set phi_name = api.name -%}
{%- set inputs = api.inputs -%}
{%- set outputs = api.outputs|trip_intermediate -%} {#- ignore intermediate output -#}
{%- set attrs = api.attrs -%}
{%- set output_names = [] -%}
{%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
{{static_prim_api_sig(phi_name, inputs, outputs, attrs)}} {
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
  op->SetType("{{fluid_name}}");
{% filter indent(2, True) %}
{% for input in inputs %}
{{static_prim_api_input(input)}}
{% endfor %}
{% for output in outputs %}
{{static_prim_api_output(output)}}
{% endfor %}
{% for attr in attrs %}
{{static_prim_api_attr(attr)}}
{% endfor %}
{% endfilter %}
  op->CheckAttrs();
  op->InferVarType(block);
  op->InferShape(*block);
{% if outputs|length > 1 %}
  return std::make_tuple{{sequence('(', ')', ', ', output_names)}};
{% elif outputs|length == 1 %}
  return {{outputs[0].name}};
{% else %} {#- render nothing -#}
{% endif %}
}
{% endmacro %}
{%- macro static_prim_api_sig(name, inputs, outputs, attrs) -%}
template <>
{{static_prim_api_sig_ret(outputs)}} {{name}}<DescTensor>({{static_prim_api_sig_params(inputs, attrs)}})
{%- endmacro %}
{%- macro static_prim_api_sig_params(inputs, attrs) -%}
{%- set input_params = [] -%}
{%- for i in inputs -%} {%- do input_params.append(i.typename|to_paddle_input_type(i.optional)~' '~i.name) -%} {%- endfor -%}
{%- set attr_params = [] -%}
{%- for i in attrs -%} {%- do attr_params.append(i.typename|to_paddle_attr_type~' '~i.name) -%} {%- endfor -%}
{{sequence('', '', ', ', input_params)}}
{%- if attr_params|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{{sequence('', '', ', ', attr_params)}}
{%- endmacro -%}
{%- macro static_prim_api_sig_ret(outputs) -%}
{%- set names = [] -%}
{%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type) -%} {%- endfor -%}
{%- if names|length > 1 -%}
std::tuple<{{sequence('', '', ', ', names)}}>
{%- else -%}
{{names[0]}}
{%- endif -%}
{%- endmacro -%}
{% macro static_prim_api_input(input) %}
{%- if input.optional -%}
{{static_prim_api_input_optional(input)}}
{%- else -%}
{{static_prim_api_input_without_optional(input)}}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_input_optional(input) -%}
{%- if input.typename=='Tensor[]' -%} {#- render the input of type paddle::optional<std::vector<Tensor>> -#}
if ({{input.name}}) {
  std::vector<std::string> {{input.name}}_names;
  std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), std::back_inserter({{input.name}}_names), [](const Tensor& t) {
    return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name();
  });
  op->SetInput("{{input.fluid_name | to_pascal}}", {{input.name}}_names);
}
{%- else -%}
if ({{input.name}}) {
  op->SetInput("{{input.fluid_name | to_pascal}}", {std::static_pointer_cast<prim::DescTensor>({{input.name}}->impl())->Name()});
}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_input_without_optional(input) -%}
{%- if input.typename is tensor_sequence -%} {#- render the input of type std::vector<Tensor> -#}
std::vector<std::string> {{input.name}}_names;
std::transform({{input.name}}.begin(), {{input.name}}.end(), std::back_inserter({{input.name}}_names), [](const Tensor& t) {
  return std::static_pointer_cast<prim::DescTensor>(t.impl())->Name();
});
op->SetInput("{{input.fluid_name | to_pascal}}", {{input.name}}_names);
{%- else -%}
op->SetInput("{{input.fluid_name | to_pascal}}", {std::static_pointer_cast<prim::DescTensor>({{input.name}}.impl())->Name()});
{%- endif -%}
{%- endmacro -%}
{% macro static_prim_api_output(output) %}
{%- if output.optional -%}
{{static_prim_api_output_optional(output)}}
{%- else -%}
{{static_prim_api_output_without_optional(output)}}
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_output_without_optional(output) -%}
{%- if output.typename is tensor_sequence -%} {#- render the output of type std::vector<Tensor> -#}
std::vector<Tensor> {{output.name}};
std::vector<std::string> {{output.name}}_names;
for (auto i = 0; i < {{output.size}}; i++) {
  auto tmp = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
  {{output.name}}.push_back(tmp);
  {{output.name}}_names.push_back(std::static_pointer_cast<prim::DescTensor>(tmp.impl())->Name());
}
op->SetOutput("{{output.fluid_name | to_pascal}}", {{output.name}}_names);
{%- else -%}
auto {{output.name}} = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
op->SetOutput("{{output.fluid_name | to_pascal}}", {std::static_pointer_cast<prim::DescTensor>({{output.name}}.impl())->Name()});
{%- endif -%}
{%- endmacro -%}
{%- macro static_prim_api_output_optional(output) -%}
// TODO(cxxly): Render optional output
{%- endmacro -%}
{% macro static_prim_api_attr(attr) %}
op->SetAttr("{{attr.fluid_name}}", {{phi_attr_to_fluid(attr)}});
{%- endmacro %}
{%- macro phi_attr_to_fluid(attr) -%}
{%- if attr.typename is intarray -%}
{{int_array_to_fluid(attr.name, attr.typename, attr.fluid_name, attr.data_type)}}
{%- elif attr.typename is scalar -%}
{{scalar_to_fluid(attr.name, attr.typename, attr.fluid_name, attr.data_type)}}
{%- elif attr.typename is datatype -%}
{{datatype_to_fluid(attr.name, attr.typename, attr.fluid_name, attr.data_type)}}
{%- else -%}
{{attr.name}}
{%- endif -%}
{%- endmacro %}
{%- macro int_array_to_fluid(src_name, src_type, dst_name, dst_type) -%}
{%- if dst_type=='std::vector<int>' -%}
unsafe_vector_cast<int64_t, int>({{src_name}}.GetData())
{%- else -%}
{{src_name}}.GetData()
{%- endif -%}
{%- endmacro -%}
{%- macro scalar_to_fluid(src_name, src_type, dst_name, dst_type) -%}
{{src_name}}.to<{{dst_type}}>()
{%- endmacro -%}
{%- macro datatype_to_fluid(src_name, src_type, dst_name, dst_type) -%}
paddle::framework::TransToProtoVarType({{src_name}})
{%- endmacro -%}
{%- macro sequence(lsymbol, rsymbol, delimiter, items) -%}
{{lsymbol}}{%- for item in items -%}{{item}}{{delimiter if not loop.last else "" }}{%- endfor -%}{{rsymbol}}
{%- endmacro -%}
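The sequence() macro at the end is the workhorse for argument lists and tuple returns. Reproducing its body inline shows what it emits, as a sketch independent of the surrounding macros:

import jinja2

macro = (
    "{%- macro sequence(lsymbol, rsymbol, delimiter, items) -%}"
    "{{lsymbol}}{%- for item in items -%}"
    '{{item}}{{delimiter if not loop.last else ""}}'
    "{%- endfor -%}{{rsymbol}}"
    "{%- endmacro -%}"
    "{{ sequence('(', ')', ', ', names) }}"
)
print(jinja2.Template(macro).render(names=['x', 'y', 'z']))  # (x, y, z)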
 if(NOT (NOT WITH_PYTHON AND ON_INFER))
   cc_library(
-    eager_prim_api
+    generated_eager_prim_api
     SRCS eager_prim_api.cc
     DEPS final_dygraph_function eager_utils)
 endif()
+cc_library(
+  generated_static_prim_api
+  SRCS static_prim_api.cc
+  DEPS proto_desc static_utils)

 add_subdirectory(utils)
+if(NOT (NOT WITH_PYTHON AND ON_INFER))
+  cc_library(
+    manual_eager_prim_api
+    SRCS eager_prim_api.cc
+    DEPS final_dygraph_function eager_utils)
+endif()
 cc_library(
-  static_prim_api
+  manual_static_prim_api
   SRCS static_prim_api.cc
   DEPS proto_desc static_utils)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
namespace paddle {
namespace prim {
template <>
Tensor reshape<Tensor>(const Tensor& x, const IntArray& shape) {
  VLOG(4) << "Eager Prim API reshape_ad_func call";
  return ::reshape_ad_func(x, shape);
}

template <>
Tensor full<Tensor>(const IntArray& shape,
                    const Scalar& value,
                    DataType dtype,
                    const Place& place) {
  VLOG(4) << "Eager Prim API full_ad_func call";
  return ::full_ad_func(shape, value, dtype, place);
}
}  // namespace prim
}  // namespace paddle
@@ -14,14 +14,27 @@
 #pragma once

+#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/common/scalar.h"
 #include "paddle/utils/optional.h"

+// TODO(jiabin): Make this Header only for handwritten api, instead of include
+// prim_generated_api.h
 namespace paddle {
-namespace prim {}  // namespace prim
+namespace prim {
+using Tensor = paddle::experimental::Tensor;
+using Scalar = paddle::experimental::Scalar;
+using IntArray = paddle::experimental::IntArray;
+using DataType = paddle::experimental::DataType;
+
+template <typename T>
+Tensor reshape(const Tensor& x, const IntArray& shape);
+
+template <typename T>
+Tensor full(const IntArray& shape,
+            const Scalar& value,
+            DataType dtype = DataType::FLOAT32,
+            const Place& place = CPUPlace());
+}  // namespace prim
 }  // namespace paddle
@@ -38,111 +38,18 @@ namespace paddle {
 namespace prim {
 template <>
-Tensor pow<DescTensor>(const Tensor& x, const Scalar& y) {
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("pow");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  op->SetAttr("factor", y.to<float>());
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
-
-template <>
-Tensor scale<DescTensor>(const Tensor& x,
-                         const Scalar& scale,
-                         float bias,
-                         bool bias_after_scale) {
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("scale");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  op->SetAttr("scale", scale.to<float>());
-  op->SetAttr("bias", bias);
-  op->SetAttr("bias_after_scale", bias_after_scale);
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
-
-template <>
-Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
-  // Grad infershape
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("elementwise_mul");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetInput("Y",
-               {std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
-
-template <>
-Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) {
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("unsqueeze2");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  std::vector<int> new_shape(axis.GetData().begin(), axis.GetData().end());
-  op->SetAttr("axes", new_shape);
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
-
-template <>
-Tensor expand<DescTensor>(const Tensor& x, const IntArray& shape) {
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("expand_v2");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
-  op->SetAttr("shape", new_shape);
-  op->CheckAttrs();
-  op->InferVarType(block);
-  return out;
-}
-
-template <>
-Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
-  // Grad infershape
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
   framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
   framework::OpDesc* op = block->AppendOp();
-  op->SetType("elementwise_div");
+  // TODO(cxxly): Fix test_resnet_prim_cinn error when SetType("reshape2")
+  op->SetType("reshape");
   op->SetInput("X",
                {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetInput("Y",
-               {std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
+  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
   op->SetOutput(
       "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  op->SetAttr("shape", unsafe_vector_cast<int64_t, int>(shape.GetData()));
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
@@ -186,70 +93,5 @@ Tensor full<DescTensor>(const IntArray& shape,
   return out;
 }
-
-template <>
-Tensor sum<DescTensor>(const Tensor& x,
-                       const IntArray& axis,
-                       DataType dtype,
-                       bool keepdim) {
-  // Grad infershape
-  Tensor out = empty<DescTensor>({}, dtype, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("reduce_sum");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  std::vector<int> res;
-  for (auto value : axis.GetData()) {
-    res.push_back(static_cast<int>(value));
-  }
-  op->SetAttr("dim", res);
-  op->SetAttr("keep_dim", keepdim);
-  op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
-
-template <>
-Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
-  // Grad infershape
-  Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("reshape");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  std::vector<int> res;
-  for (auto value : shape.GetData()) {
-    // TODO(jiabin): This cast is not safe for now, find a way to handle this.
-    res.push_back(static_cast<int>(value));
-  }
-  op->SetAttr("shape", res);
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
-
-template <>
-Tensor exp<DescTensor>(const Tensor& x) {
-  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
-  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
-  framework::OpDesc* op = block->AppendOp();
-  op->SetType("exp");
-  op->SetInput("X",
-               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
-  op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
-  op->CheckAttrs();
-  op->InferVarType(block);
-  op->InferShape(*block);
-  return out;
-}
 }  // namespace prim
 }  // namespace paddle
@@ -78,5 +78,11 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
   return get_reduce_dims_from_out(out_dims, x_dims);
 }

+// TODO(cxxly): Check and throws InvalidCastException when overflow.
+template <typename SRC_T, typename DST_T>
+static std::vector<DST_T> unsafe_vector_cast(const std::vector<SRC_T>& src) {
+  std::vector<DST_T> dst(src.begin(), src.end());
+  return dst;
+}
 }  // namespace prim
 }  // namespace paddle
@@ -231,6 +231,16 @@
 - op : concat
   backward : concat_grad
+  inputs:
+    x: X
+  outputs:
+    out: Out
+  attrs:
+    axis: axis
+  scalar :
+    axis :
+      data_type : int
+      tensor_name : AxisTensor
   extra :
     attrs : [bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32"]
@@ -395,6 +405,10 @@
 - op : divide (elementwise_div)
   backward : divide_grad (elementwise_div)
+  inputs :
+    {x: X, y : Y}
+  outputs :
+    out: Out
   extra :
     attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
              bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
@@ -486,6 +500,17 @@
 - op : expand (expand_v2)
   backward : expand_grad (expand_v2_grad)
+  inputs :
+    x : X
+  attrs :
+    shape : shape
+  outputs :
+    out : Out
+  int_array:
+    shape :
+      data_type : int
+      tensor_name : Shape
+      tensors_name : expand_shapes_tensor
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
@@ -898,6 +923,12 @@
 - op : matmul (matmul_v2)
   backward : matmul_grad (matmul_v2_grad)
+  inputs :
+    {x : X, y : Y}
+  attrs :
+    {transpose_x : trans_x, transpose_y : trans_y}
+  outputs :
+    out : Out
   extra :
     attrs : [bool use_mkldnn = false, 'int[] fused_reshape_Out = {}', 'int[] fused_transpose_Out = {}',
              str mkldnn_data_type = "float32", 'int[] fused_reshape_X = {}', 'int[] fused_reshape_Y = {}',
@@ -915,6 +946,20 @@
   outputs :
     out : Out

+- op : max (reduce_max)
+  backward : max_grad (reduce_max_grad)
+  inputs:
+    x : X
+  attrs:
+    { axis : dim, keepdim : keep_dim}
+  outputs:
+    out : Out
+  int_array:
+    axis :
+      data_type : int
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - op : maximum (elementwise_max)
   backward : maximum_grad (elementwise_max_grad)
   extra :
@@ -981,6 +1026,10 @@
 - op : multiply (elementwise_mul)
   backward : multiply_grad (elementwise_mul_grad)
+  inputs :
+    {x : X, y : Y}
+  outputs :
+    out : Out
   extra :
     attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
              bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
@@ -1079,6 +1128,20 @@
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

+- op : prod (reduce_prod)
+  backward : prod_grad (reduce_prod_grad)
+  inputs:
+    x : X
+  attrs:
+    { dims : dim, keep_dim : keep_dim}
+  outputs:
+    out : Out
+  int_array:
+    axis :
+      data_type : int
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - op : put_along_axis
   backward : put_along_axis_grad
   inputs :
@@ -1133,11 +1196,6 @@
   extra :
     attrs : [bool use_mkldnn = false]

-- op : reduce_max
-  backward : reduce_max_grad
-  extra :
-    attrs : [bool use_mkldnn = false]
-
 - op : reduce_mean
   backward : reduce_mean_grad
   extra :
@@ -1148,16 +1206,6 @@
   extra :
     attrs : [bool use_mkldnn = false]

-- op : reduce_prod
-  backward : reduce_prod_grad
-  extra :
-    attrs : [bool use_mkldnn = false]
-
-- op : reduce_sum
-  backward : reduce_sum_grad
-  extra :
-    attrs : [bool use_mkldnn = false]
-
 - op : relu
   backward : relu_grad, relu_double_grad (relu_grad_grad)
   inputs :
@@ -1186,6 +1234,20 @@
   extra :
     attrs : [bool use_mkldnn = false, bool use_cudnn = false]

+- op : reshape (reshape2)
+  backward : reshape_grad (reshape2_grad)
+  inputs:
+    x : X
+  outputs:
+    out : Out
+  int_array:
+    shape :
+      data_type : int
+      tensor_name : Shape
+      tensors_name : ShapeTensor
+  extra :
+    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool use_quantizer = false]
+
 - op : roll
   backward : roll_grad
   inputs :
@@ -1216,6 +1278,10 @@
     attrs : [bool use_mkldnn = false, bool use_cudnn = false]

 - op : scale
+  inputs :
+    x : X
+  outputs :
+    out: Out
   extra :
     attrs : [bool use_mkldnn = false]
@@ -1437,10 +1503,28 @@
 - op : subtract (elementwise_sub)
   backward : subtract_grad (elementwise_sub_grad)
+  inputs :
+    {x : X, y: Y}
+  outputs :
+    out : Out
   extra :
     attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
              bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]

+- op : sum (reduce_sum)
+  backward : sum_grad (reduce_sum_grad)
+  inputs:
+    {x : X}
+  attrs:
+    { axis : dim, keepdim : keep_dim, dtype : out_dtype}
+  outputs:
+    out : Out
+  int_array:
+    axis :
+      data_type : int
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - op : svd
   backward : svd_grad
   inputs :
...