From 638aab6e5c2be347e22404788e3157b5c1e7e79b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 18 Feb 2022 11:08:29 +0800 Subject: [PATCH] [Pten] Support inplace and intermediate in C++ API (#39651) * support inplace and intermediate in yaml * add cmake for dygraph_api --- .gitignore | 7 +- paddle/pten/api/lib/CMakeLists.txt | 11 ++- paddle/pten/api/lib/api_utils.h | 22 ++++-- paddle/pten/api/lib/tensor.cc | 11 ++- paddle/pten/tests/api/test_reshape_api.cc | 19 +++++ paddle/pten/tests/api/test_scale_api.cc | 2 +- python/paddle/utils/code_gen/api.yaml | 3 +- python/paddle/utils/code_gen/api_base.py | 79 ++++++++++++++----- python/paddle/utils/code_gen/api_gen.py | 78 +++++++++++++++--- .../paddle/utils/code_gen/backward_api_gen.py | 24 +++++- .../utils/code_gen/wrapped_infermeta_gen.py | 20 +---- 11 files changed, 208 insertions(+), 68 deletions(-) diff --git a/.gitignore b/.gitignore index ce0cd3bc27..13f6a427ac 100644 --- a/.gitignore +++ b/.gitignore @@ -2,16 +2,17 @@ paddle/fluid/operators/distributed/send_recv.proto paddle/fluid/API.spec paddle/fluid/API_DEV.spec paddle/fluid/API_PR.spec +paddle/fluid/eager/api/generated/* paddle/fluid/op_use_default_grad_maker_DEV.spec paddle/fluid/op_use_default_grad_maker_PR.spec +paddle/pten/api/backward/backward_api.h paddle/pten/api/include/api.h paddle/pten/api/lib/api.cc -paddle/pten/api/backward/backward_api.h +paddle/pten/api/lib/dygraph_api.* paddle/pten/api/lib/backward_api.cc +paddle/pten/extension.h paddle/pten/include/* paddle/pten/infermeta/generated.* -paddle/pten/extension.h -paddle/fluid/eager/api/generated/* *.DS_Store *.vs diff --git a/paddle/pten/api/lib/CMakeLists.txt b/paddle/pten/api/lib/CMakeLists.txt index 969ac51751..3ce99b213e 100644 --- a/paddle/pten/api/lib/CMakeLists.txt +++ b/paddle/pten/api/lib/CMakeLists.txt @@ -17,8 +17,12 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py) set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml) set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/include/api.h) set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/api.cc) +set(dygraph_api_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/dygraph_api.h) +set(dygraph_api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/dygraph_api.cc) set(api_header_file_tmp ${api_header_file}.tmp) set(api_source_file_tmp ${api_source_file}.tmp) +set(dygraph_api_header_file_tmp ${dygraph_api_header_file}.tmp) +set(dygraph_api_source_file_tmp ${dygraph_api_source_file}.tmp) # backward api file set(bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/backward_api_gen.py) @@ -40,14 +44,18 @@ endif() # generate forward api add_custom_command( - OUTPUT ${api_header_file} ${api_source_file} + OUTPUT ${api_header_file} ${api_source_file} ${dygraph_api_header_file} ${dygraph_api_source_file} COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml COMMAND ${PYTHON_EXECUTABLE} ${api_gen_file} --api_yaml_path ${api_yaml_file} --api_header_path ${api_header_file_tmp} --api_source_path ${api_source_file_tmp} + --dygraph_api_header_path ${dygraph_api_header_file_tmp} + --dygraph_api_source_path ${dygraph_api_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_header_file_tmp} ${dygraph_api_header_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp} ${dygraph_api_source_file} COMMENT "copy_if_different ${api_header_file} ${api_source_file}" DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base} VERBATIM) @@ -86,5 +94,6 @@ cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw) cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) +cc_library(pten_dygraph_api SRCS ${dygraph_api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api) cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten) diff --git a/paddle/pten/api/lib/api_utils.h b/paddle/pten/api/lib/api_utils.h index 1df3b5964f..42c940975c 100644 --- a/paddle/pten/api/lib/api_utils.h +++ b/paddle/pten/api/lib/api_utils.h @@ -72,11 +72,14 @@ inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) { /* ------------------ for output ----------------------- */ inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { - auto dense_tensor = std::make_shared( - pten::make_intrusive(pten::TransToPtenPlace(backend)), - pten::DenseTensorMeta()); - out->set_impl(dense_tensor); - return dense_tensor.get(); + if (!out->initialized()) { + auto dense_tensor = std::make_shared( + pten::make_intrusive(pten::TransToPtenPlace(backend)), + pten::DenseTensorMeta()); + out->set_impl(dense_tensor); + return dense_tensor.get(); + } + return static_cast(out->impl().get()); } inline std::vector SetKernelOutput( @@ -96,9 +99,12 @@ inline std::vector SetKernelOutput( inline pten::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out) { - auto select_rows = std::make_shared(); - out->set_impl(select_rows); - return select_rows.get(); + if (!out->initialized()) { + auto select_rows = std::make_shared(); + out->set_impl(select_rows); + return select_rows.get(); + } + return static_cast(out->impl().get()); } } // namespace experimental diff --git a/paddle/pten/api/lib/tensor.cc b/paddle/pten/api/lib/tensor.cc index cc68b19bf8..0d96ec99b9 100644 --- a/paddle/pten/api/lib/tensor.cc +++ b/paddle/pten/api/lib/tensor.cc @@ -249,10 +249,13 @@ Tensor::data() const; template T *Tensor::data() { - PADDLE_THROW(pten::errors::Unimplemented( - "It is not currently supported to directly obtain the modifiable data " - "address through the tensor::data() method, please use the " - "tensor::mutable_data() method.")); + if (is_dense_tensor()) { + return std::dynamic_pointer_cast(impl_)->data(); + } else if (pten::SelectedRows::classof(impl_.get())) { + return std::dynamic_pointer_cast(impl_) + ->mutable_value() + ->data(); + } return nullptr; } diff --git a/paddle/pten/tests/api/test_reshape_api.cc b/paddle/pten/tests/api/test_reshape_api.cc index 27e47a9183..1a27e3142d 100644 --- a/paddle/pten/tests/api/test_reshape_api.cc +++ b/paddle/pten/tests/api/test_reshape_api.cc @@ -67,6 +67,25 @@ TEST(API, reshape) { ASSERT_EQ(value_equal, true); } +TEST(API, reshape_) { + // 1. create tensor + auto x = paddle::experimental::full( + {3, 2, 2, 3}, 1.0, experimental::DataType::FLOAT32); + + // 2. test API + paddle::experimental::Tensor out = paddle::experimental::reshape_(x, {12, 3}); + // 3. check result + std::vector expect_shape = {12, 3}; + ASSERT_EQ(out.shape()[0], expect_shape[0]); + ASSERT_EQ(out.shape()[1], expect_shape[1]); + ASSERT_EQ(out.numel(), 36); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + ASSERT_EQ(out.data(), x.data()); +} + TEST(Tensor, old_reshape) { paddle::experimental::Tensor x(paddle::PlaceType::kCPU); x.reshape({3, 4}); diff --git a/paddle/pten/tests/api/test_scale_api.cc b/paddle/pten/tests/api/test_scale_api.cc index 77c5f1b44f..f7a2a72d15 100644 --- a/paddle/pten/tests/api/test_scale_api.cc +++ b/paddle/pten/tests/api/test_scale_api.cc @@ -62,7 +62,7 @@ TEST(API, scale_sr) { experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32).impl()); *(selected_rows->mutable_value()) = *dense_tensor; experimental::Tensor x(selected_rows); - const auto out = experimental::scale(x, 2.0, 1.0, true); + auto out = experimental::scale(x, 2.0, 1.0, true); ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims()[0], 3); diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 66411d00f1..60e64c0284 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -142,11 +142,12 @@ - api : reshape args : (Tensor x, ScalarArray shape) - output : Tensor + output : Tensor(out) infer_meta : func : ReshapeInferMeta kernel : func : reshape + inplace : (x -> out) - api : scale args : (Tensor x, Scalar scale, float bias, bool bias_after_scale) diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 7515981490..26abfdc031 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -48,10 +48,14 @@ class BaseAPI(object): self.support_selected_rows_kernel = False if len(self.kernel[ 'func']) == 1 else True self.data_transform = self.parse_data_transform(api_item_yaml) + self.inplace_map = self.parse_inplace(api_item_yaml) def get_api_name(self, api_item_yaml): return api_item_yaml['api'] + def get_api_func_name(self): + return self.api + def parse_args(self, api_name, api_item_yaml): inputs, attrs, args_str = self.parse_input_and_attr( api_name, api_item_yaml['args']) @@ -225,13 +229,37 @@ class BaseAPI(object): return data_transform + def parse_inplace(self, api_item_yaml): + if 'inplace' in api_item_yaml: + inplace_map = {} + inplace_list = api_item_yaml['inplace'].split(',') + for item in inplace_list: + result = re.search(r"(?P\w+)\s*->\s(?P\w+)", item) + in_val = result.group('in') + out_val = result.group('out') + assert in_val in self.inputs['names'], \ + f"{self.api} : Inplace input error: the input var name('{in_val}') is not found in the input args of {self.api}." + assert out_val in self.outputs['names'], \ + f"{self.api} : Inplace output error: the output var name('{out_val}') is not found in the output args of {self.api}." + + inplace_map[out_val] = in_val + + return inplace_map + else: + return None + # Override by child class def get_return_type(self, out_type_list): return None def gene_api_declaration(self): api_declaration = f""" -PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare']}); +PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']}); +""" + + if self.is_base_api and self.inplace_map is not None: + api_declaration = api_declaration + f""" +PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.args_str['args_declare']}); """ return api_declaration @@ -527,14 +555,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare return input_tensor_code, kernel_args[:-2], kernel_signature # Override by child class - def gene_output(self, output_type_list, set_out_func, code_indent): + def gene_output(self, + output_type_list, + set_out_func, + code_indent, + inplace_flag=False): return None, None, None - def gen_dense_tensor_kernel_code(self, code_indent): + def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False): input_tensors, kernel_args, kernel_signature = self.get_kernel_args( code_indent) outputs_args, kernel_output_names, output_create = self.gene_output( - self.outputs['types'], 'SetKernelOutput', code_indent) + self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag) return f""" {code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); @@ -552,11 +584,12 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare {code_indent} return out;""" - def gen_selected_rows_kernel_code(self, code_indent): + def gen_selected_rows_kernel_code(self, code_indent, inplace_flag=False): input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args( code_indent) outputs_args, kernel_output_names, output_create = self.gene_output( - self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent) + self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent, + inplace_flag) return f""" {code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} "{self.kernel['func'][1]}", {{kernel_backend, kernel_layout, kernel_data_type}}); @@ -574,32 +607,38 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare {code_indent} return out;""" - def gene_api_code(self): - if self.is_base_api: - api_code = f""" -PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str["args_define"]}) {{ + def gene_base_api_code(self, inplace_flag=False): + api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') + api_code = f""" +PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{ {self.gene_kernel_select()} """ - if self.support_selected_rows_kernel: - code_indent = ' ' - api_code = api_code + f""" + if self.support_selected_rows_kernel: + code_indent = ' ' + return api_code + f""" if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{ -{self.gen_dense_tensor_kernel_code(code_indent)} +{self.gen_dense_tensor_kernel_code(code_indent, inplace_flag)} }} else {{ -{self.gen_selected_rows_kernel_code(code_indent)} +{self.gen_selected_rows_kernel_code(code_indent, inplace_flag)} }} }} """ - return api_code - else: - code_indent = '' - return api_code + self.gen_dense_tensor_kernel_code( - code_indent) + """ + else: + code_indent = '' + return api_code + self.gen_dense_tensor_kernel_code( + code_indent, inplace_flag) + """ } """ + def gene_api_code(self): + if self.is_base_api: + api_code = self.gene_base_api_code() + if self.inplace_map is not None: + api_code = api_code + self.gene_base_api_code(inplace_flag=True) + return api_code + else: inveke_func_name = self.invoke.split('(')[0].strip() if inveke_func_name in self.attrs['names']: diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 2bdc5890a0..243cbc93b8 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -15,22 +15,38 @@ import os import yaml import argparse +import re from api_base import BaseAPI class ForwardAPI(BaseAPI): - prefix_tensor_name = 'dense_' - def __init__(self, api_item_yaml): super(ForwardAPI, self).__init__(api_item_yaml) + self.is_dygraph_api = self.parse_intermediate(api_item_yaml) + + def get_api_func_name(self): + if self.is_dygraph_api: + return self.api + '_intermediate' + else: + return self.api + + def parse_intermediate(self, api_item_yaml): + if 'intermediate' in api_item_yaml: + return True + else: + return False def get_return_type(self, out_type_list): return out_type_list[0] if len( out_type_list) == 1 else "std::tuple<" + ",".join( out_type_list) + ">" - def gene_output(self, output_type_list, set_out_func, code_indent): + def gene_output(self, + output_type_list, + set_out_func, + code_indent, + inplace_flag=False): kernel_output = "" output_names = [] output_create = "" @@ -38,8 +54,11 @@ class ForwardAPI(BaseAPI): if len(output_type_list) == 1: kernel_output = 'kernel_out' output_names.append('kernel_out') + inplace_assign = " = " + self.inplace_map[self.outputs['names'][ + 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][0] in self.inplace_map else "" output_create = f""" -{code_indent} {self.outputs['return_type']} out; +{code_indent} {self.outputs['return_type']} out{inplace_assign}; {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" elif len(output_type_list) > 1: @@ -49,6 +68,11 @@ class ForwardAPI(BaseAPI): for i in range(len(output_type_list)): kernel_output = kernel_output + f'kernel_out_{i}, ' output_names.append(f'kernel_out_{i}') + if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][i] in self.inplace_map: + output_create = output_create + f""" +{code_indent} std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};""" + output_create = output_create + f""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));""" @@ -110,12 +134,15 @@ namespace experimental { """) -def generate_api(api_yaml_path, header_file_path, source_file_path): +def generate_api(api_yaml_path, header_file_path, source_file_path, + dygraph_header_file_path, dygraph_source_file_path): with open(api_yaml_path, 'r') as f: apis = yaml.load(f, Loader=yaml.FullLoader) header_file = open(header_file_path, 'w') source_file = open(source_file_path, 'w') + dygraph_header_file = open(dygraph_header_file_path, 'w') + dygraph_source_file = open(dygraph_source_file_path, 'w') namespace = api_namespace() @@ -127,20 +154,37 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): source_file.write(source_include(include_header_file)) source_file.write(namespace[0]) + dygraph_header_file.write("#pragma once\n") + dygraph_header_file.write(header_include()) + dygraph_header_file.write(namespace[0]) + + dygraph_include_header_file = "paddle/pten/api/lib/dygraph_api.h" + dygraph_source_file.write(source_include(dygraph_include_header_file)) + dygraph_source_file.write(namespace[0]) + for api in apis: - api_code = ForwardAPI(api) - print(api_code.gene_api_declaration()) - header_file.write(api_code.gene_api_declaration()) - source_file.write(api_code.gene_api_code()) + foward_api = ForwardAPI(api) + if foward_api.is_dygraph_api: + dygraph_header_file.write(foward_api.gene_api_declaration()) + dygraph_source_file.write(foward_api.gene_api_code()) + else: + header_file.write(foward_api.gene_api_declaration()) + source_file.write(foward_api.gene_api_code()) header_file.write(namespace[1]) source_file.write(namespace[1]) + dygraph_header_file.write(namespace[1]) + dygraph_source_file.write(namespace[1]) + source_file.write(api_register()) header_file.close() source_file.close() + dygraph_header_file.close() + dygraph_source_file.close() + def main(): parser = argparse.ArgumentParser( @@ -149,6 +193,7 @@ def main(): '--api_yaml_path', help='path to api yaml file', default='python/paddle/utils/code_gen/api.yaml') + parser.add_argument( '--api_header_path', help='output of generated api header code file', @@ -159,13 +204,26 @@ def main(): help='output of generated api source code file', default='paddle/pten/api/lib/api.cc') + parser.add_argument( + '--dygraph_api_header_path', + help='output of generated dygraph api header code file', + default='paddle/pten/api/lib/dygraph_api.h') + + parser.add_argument( + '--dygraph_api_source_path', + help='output of generated dygraph api source code file', + default='paddle/pten/api/lib/dygraph_api.cc') + options = parser.parse_args() api_yaml_path = options.api_yaml_path header_file_path = options.api_header_path source_file_path = options.api_source_path + dygraph_header_file_path = options.dygraph_api_header_path + dygraph_source_file_path = options.dygraph_api_source_path - generate_api(api_yaml_path, header_file_path, source_file_path) + generate_api(api_yaml_path, header_file_path, source_file_path, + dygraph_header_file_path, dygraph_source_file_path) if __name__ == '__main__': diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index c63fb9bff0..e2e48e25ab 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -69,7 +69,11 @@ class BackwardAPI(BaseAPI): return out_type_list[0] if len( out_type_list) == 1 else "std::vector>" - def gene_output(self, output_type_list, set_out_func, code_indent): + def gene_output(self, + output_type_list, + set_out_func, + code_indent, + inplace_flag=False): kernel_output = "" output_names = [] output_create = "" @@ -77,8 +81,11 @@ class BackwardAPI(BaseAPI): if len(output_type_list) == 1: kernel_output = 'kernel_out' output_names.append('kernel_out') + inplace_assign = " = " + self.inplace_map[self.outputs['names'][ + 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][0] in self.inplace_map else "" output_create = f""" -{code_indent} {self.outputs['return_type']} out; +{code_indent} {self.outputs['return_type']} out{inplace_assign}; {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" elif len(output_type_list) > 1: @@ -90,11 +97,22 @@ class BackwardAPI(BaseAPI): output_names.append(f'kernel_out_{i}') if out_type_item == 'Tensor': get_out_code = f'&out[{i}][0]' - output_create = output_create + f""" + if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][i] in self.inplace_map: + output_create = output_create + f""" +{code_indent} out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});""" + + else: + output_create = output_create + f""" {code_indent} out[{i}].emplace_back();""" else: get_out_code = f'&out[{i}]' + if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][i] in self.inplace_map: + output_create = output_create + f""" +{code_indent} out[{i}] = {self.inplace_map[self.outputs['names'][i]]};""" + output_create = output_create + f""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});""" diff --git a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py index 6a434b60e6..6972b9af25 100644 --- a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py +++ b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py @@ -16,7 +16,7 @@ import os import yaml import argparse -from api_base import BaseAPI +from api_gen import ForwardAPI def get_wrapped_infermeta_name(api_name): @@ -24,7 +24,7 @@ def get_wrapped_infermeta_name(api_name): def gene_wrapped_infermeta_and_register(api): - if api.is_base_api: + if api.is_base_api and not api.is_dygraph_api: register_code = f""" PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});""" @@ -76,20 +76,6 @@ PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_ return '', '', '' -def gene_infermeta_register(api): - if api.is_base_api: - if api.infer_meta['param'] is None: - return f""" -PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});""" - - else: - return f""" -PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_name(api.kernel['func'][0])});""" - - else: - return '' - - def header_include(): return """ #include "paddle/pten/core/meta_tensor.h" @@ -138,7 +124,7 @@ def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path, infermeta_register_code = '' for api in apis: - api_item = BaseAPI(api) + api_item = ForwardAPI(api) declare_code, defind_code, register_code = gene_wrapped_infermeta_and_register( api_item) header_file.write(declare_code) -- GitLab