未验证 提交 638aab6e 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Support inplace and intermediate in C++ API (#39651)

* support inplace and intermediate in yaml

* add cmake for dygraph_api
上级 70b9f2ac
...@@ -2,16 +2,17 @@ paddle/fluid/operators/distributed/send_recv.proto ...@@ -2,16 +2,17 @@ paddle/fluid/operators/distributed/send_recv.proto
paddle/fluid/API.spec paddle/fluid/API.spec
paddle/fluid/API_DEV.spec paddle/fluid/API_DEV.spec
paddle/fluid/API_PR.spec paddle/fluid/API_PR.spec
paddle/fluid/eager/api/generated/*
paddle/fluid/op_use_default_grad_maker_DEV.spec paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/pten/api/backward/backward_api.h
paddle/pten/api/include/api.h paddle/pten/api/include/api.h
paddle/pten/api/lib/api.cc paddle/pten/api/lib/api.cc
paddle/pten/api/backward/backward_api.h paddle/pten/api/lib/dygraph_api.*
paddle/pten/api/lib/backward_api.cc paddle/pten/api/lib/backward_api.cc
paddle/pten/extension.h
paddle/pten/include/* paddle/pten/include/*
paddle/pten/infermeta/generated.* paddle/pten/infermeta/generated.*
paddle/pten/extension.h
paddle/fluid/eager/api/generated/*
*.DS_Store *.DS_Store
*.vs *.vs
......
...@@ -17,8 +17,12 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py) ...@@ -17,8 +17,12 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml) set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/include/api.h) set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/include/api.h)
set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/api.cc) set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/api.cc)
set(dygraph_api_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/dygraph_api.h)
set(dygraph_api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/dygraph_api.cc)
set(api_header_file_tmp ${api_header_file}.tmp) set(api_header_file_tmp ${api_header_file}.tmp)
set(api_source_file_tmp ${api_source_file}.tmp) set(api_source_file_tmp ${api_source_file}.tmp)
set(dygraph_api_header_file_tmp ${dygraph_api_header_file}.tmp)
set(dygraph_api_source_file_tmp ${dygraph_api_source_file}.tmp)
# backward api file # backward api file
set(bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/backward_api_gen.py) set(bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/backward_api_gen.py)
...@@ -40,14 +44,18 @@ endif() ...@@ -40,14 +44,18 @@ endif()
# generate forward api # generate forward api
add_custom_command( add_custom_command(
OUTPUT ${api_header_file} ${api_source_file} OUTPUT ${api_header_file} ${api_source_file} ${dygraph_api_header_file} ${dygraph_api_source_file}
COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml
COMMAND ${PYTHON_EXECUTABLE} ${api_gen_file} COMMAND ${PYTHON_EXECUTABLE} ${api_gen_file}
--api_yaml_path ${api_yaml_file} --api_yaml_path ${api_yaml_file}
--api_header_path ${api_header_file_tmp} --api_header_path ${api_header_file_tmp}
--api_source_path ${api_source_file_tmp} --api_source_path ${api_source_file_tmp}
--dygraph_api_header_path ${dygraph_api_header_file_tmp}
--dygraph_api_source_path ${dygraph_api_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_header_file_tmp} ${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp} ${dygraph_api_source_file}
COMMENT "copy_if_different ${api_header_file} ${api_source_file}" COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base} DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base}
VERBATIM) VERBATIM)
...@@ -86,5 +94,6 @@ cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw) ...@@ -86,5 +94,6 @@ cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_dygraph_api SRCS ${dygraph_api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api) cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten) cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten)
...@@ -72,11 +72,14 @@ inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) { ...@@ -72,11 +72,14 @@ inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) {
/* ------------------ for output ----------------------- */ /* ------------------ for output ----------------------- */
inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
if (!out->initialized()) {
auto dense_tensor = std::make_shared<pten::DenseTensor>( auto dense_tensor = std::make_shared<pten::DenseTensor>(
pten::make_intrusive<SharedStorage>(pten::TransToPtenPlace(backend)), pten::make_intrusive<SharedStorage>(pten::TransToPtenPlace(backend)),
pten::DenseTensorMeta()); pten::DenseTensorMeta());
out->set_impl(dense_tensor); out->set_impl(dense_tensor);
return dense_tensor.get(); return dense_tensor.get();
}
return static_cast<pten::DenseTensor*>(out->impl().get());
} }
inline std::vector<pten::DenseTensor*> SetKernelOutput( inline std::vector<pten::DenseTensor*> SetKernelOutput(
...@@ -96,9 +99,12 @@ inline std::vector<pten::DenseTensor*> SetKernelOutput( ...@@ -96,9 +99,12 @@ inline std::vector<pten::DenseTensor*> SetKernelOutput(
inline pten::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, inline pten::SelectedRows* SetSelectedRowsKernelOutput(Backend backend,
Tensor* out) { Tensor* out) {
if (!out->initialized()) {
auto select_rows = std::make_shared<pten::SelectedRows>(); auto select_rows = std::make_shared<pten::SelectedRows>();
out->set_impl(select_rows); out->set_impl(select_rows);
return select_rows.get(); return select_rows.get();
}
return static_cast<pten::SelectedRows*>(out->impl().get());
} }
} // namespace experimental } // namespace experimental
......
...@@ -249,10 +249,13 @@ Tensor::data<pten::dtype::bfloat16>() const; ...@@ -249,10 +249,13 @@ Tensor::data<pten::dtype::bfloat16>() const;
template <typename T> template <typename T>
T *Tensor::data() { T *Tensor::data() {
PADDLE_THROW(pten::errors::Unimplemented( if (is_dense_tensor()) {
"It is not currently supported to directly obtain the modifiable data " return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->data<T>();
"address through the tensor::data<T>() method, please use the " } else if (pten::SelectedRows::classof(impl_.get())) {
"tensor::mutable_data<T>() method.")); return std::dynamic_pointer_cast<pten::SelectedRows>(impl_)
->mutable_value()
->data<T>();
}
return nullptr; return nullptr;
} }
......
...@@ -67,6 +67,25 @@ TEST(API, reshape) { ...@@ -67,6 +67,25 @@ TEST(API, reshape) {
ASSERT_EQ(value_equal, true); ASSERT_EQ(value_equal, true);
} }
TEST(API, reshape_) {
// 1. create tensor
auto x = paddle::experimental::full(
{3, 2, 2, 3}, 1.0, experimental::DataType::FLOAT32);
// 2. test API
paddle::experimental::Tensor out = paddle::experimental::reshape_(x, {12, 3});
// 3. check result
std::vector<int64_t> expect_shape = {12, 3};
ASSERT_EQ(out.shape()[0], expect_shape[0]);
ASSERT_EQ(out.shape()[1], expect_shape[1]);
ASSERT_EQ(out.numel(), 36);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
ASSERT_EQ(out.data<float>(), x.data<float>());
}
TEST(Tensor, old_reshape) { TEST(Tensor, old_reshape) {
paddle::experimental::Tensor x(paddle::PlaceType::kCPU); paddle::experimental::Tensor x(paddle::PlaceType::kCPU);
x.reshape({3, 4}); x.reshape({3, 4});
......
...@@ -62,7 +62,7 @@ TEST(API, scale_sr) { ...@@ -62,7 +62,7 @@ TEST(API, scale_sr) {
experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32).impl()); experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32).impl());
*(selected_rows->mutable_value()) = *dense_tensor; *(selected_rows->mutable_value()) = *dense_tensor;
experimental::Tensor x(selected_rows); experimental::Tensor x(selected_rows);
const auto out = experimental::scale(x, 2.0, 1.0, true); auto out = experimental::scale(x, 2.0, 1.0, true);
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
ASSERT_EQ(out.dims()[0], 3); ASSERT_EQ(out.dims()[0], 3);
......
...@@ -142,11 +142,12 @@ ...@@ -142,11 +142,12 @@
- api : reshape - api : reshape
args : (Tensor x, ScalarArray shape) args : (Tensor x, ScalarArray shape)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ReshapeInferMeta func : ReshapeInferMeta
kernel : kernel :
func : reshape func : reshape
inplace : (x -> out)
- api : scale - api : scale
args : (Tensor x, Scalar scale, float bias, bool bias_after_scale) args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
......
...@@ -48,10 +48,14 @@ class BaseAPI(object): ...@@ -48,10 +48,14 @@ class BaseAPI(object):
self.support_selected_rows_kernel = False if len(self.kernel[ self.support_selected_rows_kernel = False if len(self.kernel[
'func']) == 1 else True 'func']) == 1 else True
self.data_transform = self.parse_data_transform(api_item_yaml) self.data_transform = self.parse_data_transform(api_item_yaml)
self.inplace_map = self.parse_inplace(api_item_yaml)
def get_api_name(self, api_item_yaml): def get_api_name(self, api_item_yaml):
return api_item_yaml['api'] return api_item_yaml['api']
def get_api_func_name(self):
return self.api
def parse_args(self, api_name, api_item_yaml): def parse_args(self, api_name, api_item_yaml):
inputs, attrs, args_str = self.parse_input_and_attr( inputs, attrs, args_str = self.parse_input_and_attr(
api_name, api_item_yaml['args']) api_name, api_item_yaml['args'])
...@@ -225,13 +229,37 @@ class BaseAPI(object): ...@@ -225,13 +229,37 @@ class BaseAPI(object):
return data_transform return data_transform
def parse_inplace(self, api_item_yaml):
if 'inplace' in api_item_yaml:
inplace_map = {}
inplace_list = api_item_yaml['inplace'].split(',')
for item in inplace_list:
result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
in_val = result.group('in')
out_val = result.group('out')
assert in_val in self.inputs['names'], \
f"{self.api} : Inplace input error: the input var name('{in_val}') is not found in the input args of {self.api}."
assert out_val in self.outputs['names'], \
f"{self.api} : Inplace output error: the output var name('{out_val}') is not found in the output args of {self.api}."
inplace_map[out_val] = in_val
return inplace_map
else:
return None
# Override by child class # Override by child class
def get_return_type(self, out_type_list): def get_return_type(self, out_type_list):
return None return None
def gene_api_declaration(self): def gene_api_declaration(self):
api_declaration = f""" api_declaration = f"""
PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare']}); PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']});
"""
if self.is_base_api and self.inplace_map is not None:
api_declaration = api_declaration + f"""
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.args_str['args_declare']});
""" """
return api_declaration return api_declaration
...@@ -527,14 +555,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare ...@@ -527,14 +555,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
return input_tensor_code, kernel_args[:-2], kernel_signature return input_tensor_code, kernel_args[:-2], kernel_signature
# Override by child class # Override by child class
def gene_output(self, output_type_list, set_out_func, code_indent): def gene_output(self,
output_type_list,
set_out_func,
code_indent,
inplace_flag=False):
return None, None, None return None, None, None
def gen_dense_tensor_kernel_code(self, code_indent): def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False):
input_tensors, kernel_args, kernel_signature = self.get_kernel_args( input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
code_indent) code_indent)
outputs_args, kernel_output_names, output_create = self.gene_output( outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetKernelOutput', code_indent) self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag)
return f""" return f"""
{code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); {code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
...@@ -552,11 +584,12 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare ...@@ -552,11 +584,12 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
{code_indent} return out;""" {code_indent} return out;"""
def gen_selected_rows_kernel_code(self, code_indent): def gen_selected_rows_kernel_code(self, code_indent, inplace_flag=False):
input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args( input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args(
code_indent) code_indent)
outputs_args, kernel_output_names, output_create = self.gene_output( outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent) self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent,
inplace_flag)
return f""" return f"""
{code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][1]}", {{kernel_backend, kernel_layout, kernel_data_type}}); {code_indent} "{self.kernel['func'][1]}", {{kernel_backend, kernel_layout, kernel_data_type}});
...@@ -574,32 +607,38 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare ...@@ -574,32 +607,38 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
{code_indent} return out;""" {code_indent} return out;"""
def gene_api_code(self): def gene_base_api_code(self, inplace_flag=False):
if self.is_base_api: api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
api_code = f""" api_code = f"""
PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str["args_define"]}) {{ PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{
{self.gene_kernel_select()} {self.gene_kernel_select()}
""" """
if self.support_selected_rows_kernel: if self.support_selected_rows_kernel:
code_indent = ' ' code_indent = ' '
api_code = api_code + f""" return api_code + f"""
if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{ if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{
{self.gen_dense_tensor_kernel_code(code_indent)} {self.gen_dense_tensor_kernel_code(code_indent, inplace_flag)}
}} else {{ }} else {{
{self.gen_selected_rows_kernel_code(code_indent)} {self.gen_selected_rows_kernel_code(code_indent, inplace_flag)}
}} }}
}} }}
""" """
return api_code
else: else:
code_indent = '' code_indent = ''
return api_code + self.gen_dense_tensor_kernel_code( return api_code + self.gen_dense_tensor_kernel_code(
code_indent) + """ code_indent, inplace_flag) + """
} }
""" """
def gene_api_code(self):
if self.is_base_api:
api_code = self.gene_base_api_code()
if self.inplace_map is not None:
api_code = api_code + self.gene_base_api_code(inplace_flag=True)
return api_code
else: else:
inveke_func_name = self.invoke.split('(')[0].strip() inveke_func_name = self.invoke.split('(')[0].strip()
if inveke_func_name in self.attrs['names']: if inveke_func_name in self.attrs['names']:
......
...@@ -15,22 +15,38 @@ ...@@ -15,22 +15,38 @@
import os import os
import yaml import yaml
import argparse import argparse
import re
from api_base import BaseAPI from api_base import BaseAPI
class ForwardAPI(BaseAPI): class ForwardAPI(BaseAPI):
prefix_tensor_name = 'dense_'
def __init__(self, api_item_yaml): def __init__(self, api_item_yaml):
super(ForwardAPI, self).__init__(api_item_yaml) super(ForwardAPI, self).__init__(api_item_yaml)
self.is_dygraph_api = self.parse_intermediate(api_item_yaml)
def get_api_func_name(self):
if self.is_dygraph_api:
return self.api + '_intermediate'
else:
return self.api
def parse_intermediate(self, api_item_yaml):
if 'intermediate' in api_item_yaml:
return True
else:
return False
def get_return_type(self, out_type_list): def get_return_type(self, out_type_list):
return out_type_list[0] if len( return out_type_list[0] if len(
out_type_list) == 1 else "std::tuple<" + ",".join( out_type_list) == 1 else "std::tuple<" + ",".join(
out_type_list) + ">" out_type_list) + ">"
def gene_output(self, output_type_list, set_out_func, code_indent): def gene_output(self,
output_type_list,
set_out_func,
code_indent,
inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
...@@ -38,8 +54,11 @@ class ForwardAPI(BaseAPI): ...@@ -38,8 +54,11 @@ class ForwardAPI(BaseAPI):
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} out; {code_indent} {self.outputs['return_type']} out{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
...@@ -49,6 +68,11 @@ class ForwardAPI(BaseAPI): ...@@ -49,6 +68,11 @@ class ForwardAPI(BaseAPI):
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));"""
...@@ -110,12 +134,15 @@ namespace experimental { ...@@ -110,12 +134,15 @@ namespace experimental {
""") """)
def generate_api(api_yaml_path, header_file_path, source_file_path): def generate_api(api_yaml_path, header_file_path, source_file_path,
dygraph_header_file_path, dygraph_source_file_path):
with open(api_yaml_path, 'r') as f: with open(api_yaml_path, 'r') as f:
apis = yaml.load(f, Loader=yaml.FullLoader) apis = yaml.load(f, Loader=yaml.FullLoader)
header_file = open(header_file_path, 'w') header_file = open(header_file_path, 'w')
source_file = open(source_file_path, 'w') source_file = open(source_file_path, 'w')
dygraph_header_file = open(dygraph_header_file_path, 'w')
dygraph_source_file = open(dygraph_source_file_path, 'w')
namespace = api_namespace() namespace = api_namespace()
...@@ -127,20 +154,37 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): ...@@ -127,20 +154,37 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file.write(source_include(include_header_file)) source_file.write(source_include(include_header_file))
source_file.write(namespace[0]) source_file.write(namespace[0])
dygraph_header_file.write("#pragma once\n")
dygraph_header_file.write(header_include())
dygraph_header_file.write(namespace[0])
dygraph_include_header_file = "paddle/pten/api/lib/dygraph_api.h"
dygraph_source_file.write(source_include(dygraph_include_header_file))
dygraph_source_file.write(namespace[0])
for api in apis: for api in apis:
api_code = ForwardAPI(api) foward_api = ForwardAPI(api)
print(api_code.gene_api_declaration()) if foward_api.is_dygraph_api:
header_file.write(api_code.gene_api_declaration()) dygraph_header_file.write(foward_api.gene_api_declaration())
source_file.write(api_code.gene_api_code()) dygraph_source_file.write(foward_api.gene_api_code())
else:
header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code())
header_file.write(namespace[1]) header_file.write(namespace[1])
source_file.write(namespace[1]) source_file.write(namespace[1])
dygraph_header_file.write(namespace[1])
dygraph_source_file.write(namespace[1])
source_file.write(api_register()) source_file.write(api_register())
header_file.close() header_file.close()
source_file.close() source_file.close()
dygraph_header_file.close()
dygraph_source_file.close()
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -149,6 +193,7 @@ def main(): ...@@ -149,6 +193,7 @@ def main():
'--api_yaml_path', '--api_yaml_path',
help='path to api yaml file', help='path to api yaml file',
default='python/paddle/utils/code_gen/api.yaml') default='python/paddle/utils/code_gen/api.yaml')
parser.add_argument( parser.add_argument(
'--api_header_path', '--api_header_path',
help='output of generated api header code file', help='output of generated api header code file',
...@@ -159,13 +204,26 @@ def main(): ...@@ -159,13 +204,26 @@ def main():
help='output of generated api source code file', help='output of generated api source code file',
default='paddle/pten/api/lib/api.cc') default='paddle/pten/api/lib/api.cc')
parser.add_argument(
'--dygraph_api_header_path',
help='output of generated dygraph api header code file',
default='paddle/pten/api/lib/dygraph_api.h')
parser.add_argument(
'--dygraph_api_source_path',
help='output of generated dygraph api source code file',
default='paddle/pten/api/lib/dygraph_api.cc')
options = parser.parse_args() options = parser.parse_args()
api_yaml_path = options.api_yaml_path api_yaml_path = options.api_yaml_path
header_file_path = options.api_header_path header_file_path = options.api_header_path
source_file_path = options.api_source_path source_file_path = options.api_source_path
dygraph_header_file_path = options.dygraph_api_header_path
dygraph_source_file_path = options.dygraph_api_source_path
generate_api(api_yaml_path, header_file_path, source_file_path) generate_api(api_yaml_path, header_file_path, source_file_path,
dygraph_header_file_path, dygraph_source_file_path)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -69,7 +69,11 @@ class BackwardAPI(BaseAPI): ...@@ -69,7 +69,11 @@ class BackwardAPI(BaseAPI):
return out_type_list[0] if len( return out_type_list[0] if len(
out_type_list) == 1 else "std::vector<std::vector<Tensor>>" out_type_list) == 1 else "std::vector<std::vector<Tensor>>"
def gene_output(self, output_type_list, set_out_func, code_indent): def gene_output(self,
output_type_list,
set_out_func,
code_indent,
inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
...@@ -77,8 +81,11 @@ class BackwardAPI(BaseAPI): ...@@ -77,8 +81,11 @@ class BackwardAPI(BaseAPI):
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} out; {code_indent} {self.outputs['return_type']} out{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
...@@ -90,11 +97,22 @@ class BackwardAPI(BaseAPI): ...@@ -90,11 +97,22 @@ class BackwardAPI(BaseAPI):
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if out_type_item == 'Tensor': if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]' get_out_code = f'&out[{i}][0]'
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
else:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} out[{i}].emplace_back();""" {code_indent} out[{i}].emplace_back();"""
else: else:
get_out_code = f'&out[{i}]' get_out_code = f'&out[{i}]'
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} out[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import yaml import yaml
import argparse import argparse
from api_base import BaseAPI from api_gen import ForwardAPI
def get_wrapped_infermeta_name(api_name): def get_wrapped_infermeta_name(api_name):
...@@ -24,7 +24,7 @@ def get_wrapped_infermeta_name(api_name): ...@@ -24,7 +24,7 @@ def get_wrapped_infermeta_name(api_name):
def gene_wrapped_infermeta_and_register(api): def gene_wrapped_infermeta_and_register(api):
if api.is_base_api: if api.is_base_api and not api.is_dygraph_api:
register_code = f""" register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});""" PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});"""
...@@ -76,20 +76,6 @@ PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_ ...@@ -76,20 +76,6 @@ PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_
return '', '', '' return '', '', ''
def gene_infermeta_register(api):
if api.is_base_api:
if api.infer_meta['param'] is None:
return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});"""
else:
return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_name(api.kernel['func'][0])});"""
else:
return ''
def header_include(): def header_include():
return """ return """
#include "paddle/pten/core/meta_tensor.h" #include "paddle/pten/core/meta_tensor.h"
...@@ -138,7 +124,7 @@ def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path, ...@@ -138,7 +124,7 @@ def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path,
infermeta_register_code = '' infermeta_register_code = ''
for api in apis: for api in apis:
api_item = BaseAPI(api) api_item = ForwardAPI(api)
declare_code, defind_code, register_code = gene_wrapped_infermeta_and_register( declare_code, defind_code, register_code = gene_wrapped_infermeta_and_register(
api_item) api_item)
header_file.write(declare_code) header_file.write(declare_code)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册