From 5bb3b66834a1038e6f10a92ccf228f2d2a3b922a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 15 Feb 2022 10:50:17 +0800 Subject: [PATCH] [Pten] Support SelectedRows in C++ API (#39497) * add data_transform in pten api * support GetKernelTypeForVar * fix compile problem of bfloat16 * add scale_sr in api * support select_row in C++ api * merge code --- paddle/pten/api/lib/api_utils.h | 17 ++ paddle/pten/api/lib/kernel_dispatch.h | 11 + paddle/pten/api/lib/tensor.cc | 8 +- paddle/pten/core/meta_tensor.cc | 24 ++ .../kernels/selected_rows/scale_kernel.cc | 3 +- paddle/pten/kernels/transfer_layout_kernel.h | 8 +- paddle/pten/tests/api/test_scale_api.cc | 29 +- python/paddle/utils/code_gen/api.yaml | 2 +- python/paddle/utils/code_gen/api_base.py | 258 ++++++++++++++---- python/paddle/utils/code_gen/api_gen.py | 18 +- .../paddle/utils/code_gen/backward_api_gen.py | 22 +- .../utils/code_gen/wrapped_infermeta_gen.py | 8 +- 12 files changed, 312 insertions(+), 96 deletions(-) diff --git a/paddle/pten/api/lib/api_utils.h b/paddle/pten/api/lib/api_utils.h index 765a81b726..1df3b5964f 100644 --- a/paddle/pten/api/lib/api_utils.h +++ b/paddle/pten/api/lib/api_utils.h @@ -19,6 +19,7 @@ limitations under the License.
*/ #include "paddle/pten/core/compat/convert_utils.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/meta_tensor.h" +#include "paddle/pten/core/selected_rows.h" namespace paddle { namespace experimental { @@ -43,6 +44,11 @@ inline std::unique_ptr> TensorToDenseTensor( return std::move(pt_tensors); } +inline std::shared_ptr TensorToSelectedRows( + const Tensor& tensor) { + return std::dynamic_pointer_cast(tensor.impl()); +} + /* ----------------- for infer_meta --------------------- */ inline pten::MetaTensor MakeMetaTensor(const pten::DenseTensor& tensor) { @@ -59,6 +65,10 @@ inline std::vector MakeMetaTensor( return meta_tensors; } +inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) { + return pten::MetaTensor(tensor); +} + /* ------------------ for output ----------------------- */ inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { @@ -84,5 +94,12 @@ inline std::vector SetKernelOutput( return results; } +inline pten::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, + Tensor* out) { + auto select_rows = std::make_shared(); + out->set_impl(select_rows); + return select_rows.get(); +} + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/kernel_dispatch.h b/paddle/pten/api/lib/kernel_dispatch.h index 9c83e0d3d4..de753669af 100644 --- a/paddle/pten/api/lib/kernel_dispatch.h +++ b/paddle/pten/api/lib/kernel_dispatch.h @@ -24,6 +24,7 @@ limitations under the License. 
*/ #include "paddle/pten/backends/all_context.h" #include "paddle/pten/common/data_type.h" #include "paddle/pten/common/layout.h" +#include "paddle/pten/core/selected_rows.h" // TODO(chenweihang): split Key, Kernel, Factory into diff files #include "paddle/pten/core/kernel_factory.h" @@ -38,8 +39,15 @@ std::size_t CountLeadingZeros(uint64_t val); pten::DeviceContext* GetDeviceContextByBackend(pten::Backend backend); +enum class KernelType { + DENSE_TENSOR_KENREL, // kernel for DenseTensor + SELECTED_ROWS_KENREL // kernel for SelectedRows +}; + // TODO(chenweihang): support DataLayout and DataType selected struct KernelKeySet { + KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL}; + BackendSet backend_set{Backend::UNDEFINED}; DataLayout layout{DataLayout::UNDEFINED}; DataType dtype{DataType::UNDEFINED}; @@ -89,6 +97,9 @@ struct KernelKeyParser : ArgsIterator { void operator()(const Tensor& x) { key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x); // TODO(chenweihang): selecte multi layout and dtype + if (pten::SelectedRows::classof(x.impl().get())) { + key_set.kernel_type = KernelType::SELECTED_ROWS_KENREL; + } key_set.layout = x.layout(); key_set.dtype = x.type(); dtype_set = dtype_set | DataTypeSet(x.dtype()); diff --git a/paddle/pten/api/lib/tensor.cc b/paddle/pten/api/lib/tensor.cc index f1a54ee960..6fb0d2706c 100644 --- a/paddle/pten/api/lib/tensor.cc +++ b/paddle/pten/api/lib/tensor.cc @@ -25,6 +25,7 @@ limitations under the License. 
*/ #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/compat/convert_utils.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/selected_rows.h" #include "paddle/pten/core/tensor_base.h" #include "paddle/pten/core/tensor_meta.h" #include "paddle/pten/core/tensor_utils.h" @@ -222,8 +223,11 @@ Tensor::mutable_data(const PlaceType &place); template const T *Tensor::data() const { if (is_dense_tensor()) { - return std::dynamic_pointer_cast(impl_)->mutable_data( - ConvertExtPlaceToInnerPlace(place())); + return std::dynamic_pointer_cast(impl_)->data(); + } else if (pten::SelectedRows::classof(impl_.get())) { + return std::dynamic_pointer_cast(impl_) + ->value() + .data(); } return nullptr; } diff --git a/paddle/pten/core/meta_tensor.cc b/paddle/pten/core/meta_tensor.cc index d205ee1ca4..383aa0487d 100644 --- a/paddle/pten/core/meta_tensor.cc +++ b/paddle/pten/core/meta_tensor.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/enforce.h" +#include "paddle/pten/core/selected_rows.h" #include "paddle/pten/core/tensor_utils.h" namespace pten { @@ -32,6 +33,10 @@ void MetaTensor::set_dims(const DDim& dims) { if (pten::DenseTensor::classof(tensor_)) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_))->dims = dims; + } else if (pten::SelectedRows::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_)->mutable_value()) + ->dims = dims; } else { PADDLE_THROW(pten::errors::Unimplemented( "Unsupported setting dims for `%s`.", tensor_->type_info().name())); @@ -42,6 +47,10 @@ void MetaTensor::set_dtype(DataType dtype) { if (pten::DenseTensor::classof(tensor_)) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) ->dtype = dtype; + } else if (pten::SelectedRows::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_)->mutable_value()) + ->dtype = dtype; } else { PADDLE_THROW(pten::errors::Unimplemented( 
"Unsupported settting dtype for `%s`.", tensor_->type_info().name())); @@ -52,6 +61,10 @@ void MetaTensor::set_layout(DataLayout layout) { if (pten::DenseTensor::classof(tensor_)) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) ->layout = layout; + } else if (pten::SelectedRows::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_)->mutable_value()) + ->layout = layout; } else { PADDLE_THROW(pten::errors::Unimplemented( "Unsupported settting layout for `%s`.", tensor_->type_info().name())); @@ -62,6 +75,10 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) { if (pten::DenseTensor::classof(tensor_)) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_))->lod = meta_tensor.lod(); + } else if (pten::SelectedRows::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta( + static_cast(tensor_)->mutable_value()) + ->lod = meta_tensor.lod(); } else { PADDLE_THROW( pten::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.", @@ -72,6 +89,8 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) { const LoD& MetaTensor::lod() const { if (pten::DenseTensor::classof(tensor_)) { return static_cast(tensor_)->lod(); + } else if (pten::SelectedRows::classof(tensor_)) { + return static_cast(tensor_)->value().lod(); } else { PADDLE_THROW(pten::errors::Unimplemented("Unsupported getting lod of `%s`.", tensor_->type_info().name())); @@ -84,6 +103,11 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) { set_dtype(meta_tensor.dtype()); set_layout(meta_tensor.layout()); share_lod(meta_tensor); + } else if (pten::SelectedRows::classof(tensor_)) { + set_dims(meta_tensor.dims()); + set_dtype(meta_tensor.dtype()); + set_layout(meta_tensor.layout()); + share_lod(meta_tensor); } else { PADDLE_THROW(pten::errors::Unimplemented( "Unsupported sharing meta for `%s`.", tensor_->type_info().name())); diff --git a/paddle/pten/kernels/selected_rows/scale_kernel.cc b/paddle/pten/kernels/selected_rows/scale_kernel.cc index 
8b29f1d6c5..09700d8afe 100644 --- a/paddle/pten/kernels/selected_rows/scale_kernel.cc +++ b/paddle/pten/kernels/selected_rows/scale_kernel.cc @@ -29,7 +29,8 @@ void ScaleSR(const Context& dev_ctx, float bias, bool bias_after_scale, SelectedRows* out) { - if (x.value().data() != out->value().data()) { + if (x.value().Holder() != out->value().Holder() || + x.value().data() != out->value().data()) { out->set_rows(x.rows()); out->set_height(x.height()); } diff --git a/paddle/pten/kernels/transfer_layout_kernel.h b/paddle/pten/kernels/transfer_layout_kernel.h index 24854842e8..6e1b434fec 100644 --- a/paddle/pten/kernels/transfer_layout_kernel.h +++ b/paddle/pten/kernels/transfer_layout_kernel.h @@ -30,12 +30,8 @@ template DenseTensor TransferLayout(const Context& dev_ctx, const DenseTensor& x, DataLayout dst_layout) { - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - {x.dtype(), x.dims(), dst_layout}); - MetaTensor meta_out(&dense_out); - TransferLayoutInferMeta(x, dst_layout, &meta_out); + pten::DenseTensor dense_out = + pten::Empty(dev_ctx, {x.dtype(), x.dims(), dst_layout}); TransferLayoutKernel(dev_ctx, x, dst_layout, &dense_out); return dense_out; } diff --git a/paddle/pten/tests/api/test_scale_api.cc b/paddle/pten/tests/api/test_scale_api.cc index bb5523d26c..77c5f1b44f 100644 --- a/paddle/pten/tests/api/test_scale_api.cc +++ b/paddle/pten/tests/api/test_scale_api.cc @@ -19,6 +19,7 @@ limitations under the License. 
*/ #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/selected_rows.h" namespace paddle { namespace tests { @@ -26,7 +27,7 @@ namespace tests { namespace framework = paddle::framework; using DDim = pten::framework::DDim; -void CheckScaleResult(experimental::Tensor* out) { +void CheckScaleResult(const experimental::Tensor* out) { ASSERT_EQ(out->dims().size(), 2); ASSERT_EQ(out->dims()[0], 3); ASSERT_EQ(out->dims()[1], 4); @@ -36,7 +37,7 @@ void CheckScaleResult(experimental::Tensor* out) { ASSERT_EQ(out->layout(), pten::DataLayout::NCHW); ASSERT_EQ(out->initialized(), true); for (int64_t i = 0; i < out->numel(); ++i) { - ASSERT_EQ(out->mutable_data()[i], 3.0); + ASSERT_NEAR(3.0, out->data()[i], 1e-6f); } } @@ -52,5 +53,29 @@ TEST(API, scale) { CheckScaleResult(&out2); } +TEST(API, scale_sr) { + // 1. check `scale` is float value + std::vector rows{0, 4, 7}; + int64_t height = 10; + auto selected_rows = std::make_shared(rows, height); + auto dense_tensor = std::dynamic_pointer_cast( + experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32).impl()); + *(selected_rows->mutable_value()) = *dense_tensor; + experimental::Tensor x(selected_rows); + const auto out = experimental::scale(x, 2.0, 1.0, true); + + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.dims()[1], 4); + ASSERT_EQ(out.numel(), 12); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + for (int64_t i = 0; i < out.numel(); ++i) { + ASSERT_NEAR(3.0, out.data()[i], 1e-6f); + } +} + } // namespace tests } // namespace paddle diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 6f64eaadc8..8b8b001739 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -155,7 +155,7 @@ func : UnchangedInferMeta param : [x] 
kernel : - func : scale + func : scale, scale_sr - api : sign args : (const Tensor& x) diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 05c8861dcf..2e1ed58e1c 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -14,7 +14,7 @@ import re -PREFIX_TENSOR_NAME = 'dense_' +PREFIX_TENSOR_NAME = 'input_' PREFIX_META_TENSOR_NAME = 'meta_' @@ -33,8 +33,8 @@ class BaseAPI(object): # types : [], list of output types # return_type : Tensor, vector, ..., the return type of api # args_str: - # args_declare : "str" // str of funtion params with default value. Example: (..., bool flag=false) - # args_define : "str" // str of funtion params without default value. Example: (..., bool flag) + # args_declare : "str" // str of function params with default value. Example: (..., bool flag=false) + # args_define : "str" // str of function params without default value. Example: (..., bool flag) self.inputs, self.attrs, self.outputs, self.args_str = self.parse_args( self.api, api_item_yaml) @@ -43,32 +43,11 @@ class BaseAPI(object): self.is_base_api = False self.invoke = api_item_yaml['invoke'] else: - self.kernel = api_item_yaml['kernel'] - if 'backend' not in self.kernel or len(self.kernel['backend']) == 0: - self.kernel['backend'] = None - if 'layout' not in self.kernel or len(self.kernel['layout']) == 0: - self.kernel['layout'] = None - if 'data_type' not in self.kernel or len(self.kernel[ - 'data_type']) == 0: - self.kernel['data_type'] = None - if 'param' not in self.kernel: - self.kernel['param'] = None - - self.infer_meta = api_item_yaml['infer_meta'] - if 'param' not in self.infer_meta: - self.infer_meta['param'] = None - - self.data_transform = { - 'skip_transform': [], - 'support_trans_dtype': [] - } - if 'data_transform' in api_item_yaml: - if 'skip_transform' in api_item_yaml['data_transform']: - self.data_transform['skip_transform'] = api_item_yaml[ - 
'data_transform']['skip_transform'] - if 'support_trans_dtype' in api_item_yaml['data_transform']: - self.data_transform['support_trans_dtype'] = api_item_yaml[ - 'data_transform']['support_trans_dtype'] + self.infer_meta = self.parse_infer_meta(api_item_yaml['infer_meta']) + self.kernel = self.parse_kernel(api_item_yaml['kernel']) + self.support_selected_rows_kernel = False if len(self.kernel[ + 'func']) == 1 else True + self.data_transform = self.parse_data_transform(api_item_yaml) def get_api_name(self, api_item_yaml): return api_item_yaml['api'] @@ -185,6 +164,61 @@ class BaseAPI(object): return out_type_list, out_name_list, self.get_return_type( out_type_list) + def parse_infer_meta(self, infer_meta_config): + infer_meta = infer_meta_config + if 'param' not in infer_meta_config: + infer_meta['param'] = None + + return infer_meta + + def parse_kernel(self, kernel_config): + # kernel : + # func : [], Kernel functions (example: scale, scale_sr) + # param : [], Input params of kernel + # backend : str, the names of param to choose the kernel backend, default is None + # layout : str, the names of param to choose the kernel layout, default is None + # data_type : str, the names of param to choose the kernel data_type, default is None + kernel = { + 'func': [], + 'param': None, + 'backend': None, + 'layout': None, + 'data_type': None + } + if 'backend' in kernel_config and len(kernel_config['backend']) > 0: + kernel['backend'] = kernel_config['backend'] + if 'layout' in kernel_config and len(kernel_config['layout']) > 0: + kernel['layout'] = kernel_config['layout'] + if 'data_type' in kernel_config and len(kernel_config['data_type']) > 0: + kernel['data_type'] = kernel_config['data_type'] + if 'param' in kernel_config: + kernel['param'] = kernel_config['param'] + kernel['func'] = [ + kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',') + ] + + if len(kernel['func']) == 2: + assert kernel['func'][0] == self.api, \ + f"{self.api} : Kernel func error: If 
kernel has two func config, the name of first func should be same with api name({self.api}), \ + but now is {kernel['func'][0]}." + assert kernel['func'][1].endswith('_sr'), \ + f"{self.api} : Kernel func error: If kernel has two func config, the name of second func should be a selected_rows kernel (the func name endwith '_sr'), \ + but now is {kernel['func'][1]}." + + return kernel + + def parse_data_transform(self, api_item_yaml): + data_transform = {'skip_transform': [], 'support_trans_dtype': []} + if 'data_transform' in api_item_yaml: + if 'skip_transform' in api_item_yaml['data_transform']: + data_transform['skip_transform'] = api_item_yaml[ + 'data_transform']['skip_transform'] + if 'support_trans_dtype' in api_item_yaml['data_transform']: + data_transform['support_trans_dtype'] = api_item_yaml[ + 'data_transform']['support_trans_dtype'] + + return data_transform + # Override by child class def get_return_type(self, out_type_list): return None @@ -303,12 +337,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare kernel_select_code = kernel_key_item_init + kernel_select_code if len(input_names) > 0: + if self.support_selected_rows_kernel: + kernel_select_code = kernel_select_code + f""" + KernelType kernel_type; +""" + kernel_select_code = kernel_select_code + f""" if (kernel_backend == Backend::UNDEFINED || kernel_layout == DataLayout::UNDEFINED || kernel_data_type == DataType::UNDEFINED ) {{ auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args}); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + {'kernel_type = kernel_key_set.kernel_type;' if self.support_selected_rows_kernel else ''} if (kernel_backend == Backend::UNDEFINED) {{ kernel_backend = kernel_key.backend(); }} @@ -320,15 +360,9 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare }} }}""" - kernel_select_code = kernel_select_code + f""" - auto kernel = 
pten::KernelFactory::Instance().SelectKernelOrThrowError( - "{kernel['func']}", {{kernel_backend, kernel_layout, kernel_data_type}}); - VLOG(6) << "{api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; - VLOG(6) << "{api} API kernel: " << kernel;""" - return kernel_select_code - def gene_infer_meta(self, kernel_output_names) -> str: + def gene_infer_meta(self, kernel_output_names, code_indent) -> str: input_names = self.inputs['names'] attr_names = self.attrs['names'] infer_meta = self.infer_meta @@ -343,11 +377,10 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare if param in input_names: param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), " elif param in kernel_output_names: - meta_tensor_code = meta_tensor_code + " pten::MetaTensor " + param.replace( - PREFIX_TENSOR_NAME, - PREFIX_META_TENSOR_NAME) + "(" + param + ");\n" + meta_tensor_code = meta_tensor_code + code_indent + " pten::MetaTensor " + param.replace( + 'kernel_', PREFIX_META_TENSOR_NAME) + "(" + param + ");\n" param_code = param_code + "&" + param.replace( - PREFIX_TENSOR_NAME, PREFIX_META_TENSOR_NAME) + ", " + 'kernel_', PREFIX_META_TENSOR_NAME) + ", " elif param in attr_names: param_code = param_code + param + ", " elif isinstance(param, str): @@ -359,10 +392,10 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare param_code = param_code[:-2] return f"""{meta_tensor_code} - pten::{infer_meta['func']}({param_code}); +{code_indent} pten::{infer_meta['func']}({param_code}); """ - def get_kernel_args(self): + def get_kernel_args(self, code_indent): input_trans_map = { 'const Tensor&': 'const pten::DenseTensor&', 'const Tensor &': 'const pten::DenseTensor&', @@ -394,11 +427,69 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare elif input_name in self.data_transform['support_trans_dtype']: trans_flag = "{false, true}" 
input_tensor_code = input_tensor_code + f""" - auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});""" else: input_tensor_code = input_tensor_code + f""" - auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});""" + + kernel_args = "*dev_ctx, " + for param in kernel_param: + if param in input_names: + kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", " + kernel_args_type_list.append(input_trans_map[input_infos[ + param]]) + elif param in attr_names: + # set attr for kernel_context + if 'ScalarArray' in self.attrs['attr_info'][param][0]: + kernel_args_type_list.append('const pten::ScalarArray&') + param = 'pten::ScalarArray(' + param + ')' + elif 'Scalar' in self.attrs['attr_info'][param][0]: + kernel_args_type_list.append('const pten::Scalar&') + param = 'pten::Scalar(' + param + ')' + else: + kernel_args_type_list.append(self.attrs['attr_info'][param][ + 0]) + kernel_args = kernel_args + param + ", " + elif isinstance(param, bool): + kernel_args = kernel_args + str(param).lower() + ", " + else: + kernel_args = kernel_args + str(param) + ", " + + for out_type in self.outputs['types']: + kernel_args_type_list.append(out_trans_map[out_type]) + + kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")" + + return input_tensor_code, kernel_args[:-2], kernel_signature + + def get_selected_rows_kernel_args(self, code_indent): + input_trans_map = { + 'const Tensor&': 'const pten::SelectedRows&', + 'const Tensor &': 'const pten::SelectedRows&' + } + out_trans_map = {'Tensor': 'pten::SelectedRows*'} + input_names = self.inputs['names'] + input_infos = self.inputs['input_info'] + kernel_args_type_list = ['const platform::DeviceContext&'] + + input_tensor_code = "" + 
for input_name in input_names: + # set input code + input_tensor_code = input_tensor_code + f""" + auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});""" + + attr_names = self.attrs['names'] + + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + + input_tensor_code = "" + for i, input_name in enumerate(input_names): + # set input code + input_tensor_code = input_tensor_code + f""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});""" kernel_args = "*dev_ctx, " for param in kernel_param: @@ -431,30 +522,77 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare return input_tensor_code, kernel_args[:-2], kernel_signature # Override by child class - def gene_output(self, output_type_list): + def gene_output(self, output_type_list, set_out_func, code_indent): return None, None, None + def gen_dense_tensor_kernel_code(self, code_indent): + input_tensors, kernel_args, kernel_signature = self.get_kernel_args( + code_indent) + outputs_args, kernel_output_names, output_create = self.gene_output( + self.outputs['types'], 'SetKernelOutput', code_indent) + return f""" +{code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( +{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); +{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; +{code_indent} VLOG(6) << "{self.api} API kernel: " << kernel; + +{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); +{input_tensors} +{output_create} +{self.gene_infer_meta(kernel_output_names, code_indent)} + +{code_indent} using kernel_signature = {kernel_signature}; +{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn(); +{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); + +{code_indent} return out;""" + + def 
gen_selected_rows_kernel_code(self, code_indent): + input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args( + code_indent) + outputs_args, kernel_output_names, output_create = self.gene_output( + self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent) + return f""" +{code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( +{code_indent} "{self.kernel['func'][1]}", {{kernel_backend, kernel_layout, kernel_data_type}}); +{code_indent} VLOG(6) << "{self.api} API SelectedRows kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; +{code_indent} VLOG(6) << "{self.api} API SelectedRows kernel: " << kernel; + +{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); +{input_tensors} +{output_create} +{self.gene_infer_meta(kernel_output_names, code_indent)} + +{code_indent} using kernel_signature = {kernel_signature}; +{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn(); +{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); + +{code_indent} return out;""" + def gene_api_code(self): if self.is_base_api: - input_tensors, kernel_args, kernel_signature = self.get_kernel_args( - ) - outputs_args, kernel_output_names, output_create = self.gene_output( - self.outputs['types']) - return f""" + api_code = f""" PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str["args_define"]}) {{ {self.gene_kernel_select()} +""" - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); -{input_tensors} -{output_create} -{self.gene_infer_meta(kernel_output_names)} - - using kernel_signature = {kernel_signature}; - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)({kernel_args}, {outputs_args}); - - return out; + if self.support_selected_rows_kernel: + code_indent = ' ' + api_code = api_code + f""" + if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{ +{self.gen_dense_tensor_kernel_code(code_indent)} + }} else {{ 
+{self.gen_selected_rows_kernel_code(code_indent)} + }} }} +""" + + return api_code + else: + code_indent = '' + return api_code + self.gen_dense_tensor_kernel_code( + code_indent) + """ +} """ else: diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 629d68230a..2bdc5890a0 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -30,27 +30,27 @@ class ForwardAPI(BaseAPI): out_type_list) == 1 else "std::tuple<" + ",".join( out_type_list) + ">" - def gene_output(self, output_type_list): + def gene_output(self, output_type_list, set_out_func, code_indent): kernel_output = "" output_names = [] output_create = "" if len(output_type_list) == 1: - kernel_output = 'dense_out' - output_names.append('dense_out') + kernel_output = 'kernel_out' + output_names.append('kernel_out') output_create = f""" - {self.outputs['return_type']} out; - auto dense_out = SetKernelOutput(kernel_backend, &out);""" +{code_indent} {self.outputs['return_type']} out; +{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" elif len(output_type_list) > 1: output_create = f""" - {self.outputs['return_type']} out;""" +{code_indent} {self.outputs['return_type']} out;""" for i in range(len(output_type_list)): - kernel_output = kernel_output + f'dense_out_{i}, ' - output_names.append(f'dense_out_{i}') + kernel_output = kernel_output + f'kernel_out_{i}, ' + output_names.append(f'kernel_out_{i}') output_create = output_create + f""" - auto dense_out_{i} = SetKernelOutput(kernel_backend, &std::get<{i}>(out));""" +{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));""" kernel_output = kernel_output[:-2] else: diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index 96fabfc3db..c63fb9bff0 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py 
@@ -62,41 +62,41 @@ class BackwardAPI(BaseAPI): # check the output of backward assert len(self.outputs['types']) <= len(fw_inputs['names']), \ - f"{self.api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \ + f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \ Please check the output of {self.api} in yaml." def get_return_type(self, out_type_list): return out_type_list[0] if len( out_type_list) == 1 else "std::vector>" - def gene_output(self, output_type_list): + def gene_output(self, output_type_list, set_out_func, code_indent): kernel_output = "" output_names = [] output_create = "" if len(output_type_list) == 1: - kernel_output = 'dense_out' - output_names.append('dense_out') + kernel_output = 'kernel_out' + output_names.append('kernel_out') output_create = f""" - {self.outputs['return_type']} out; - auto dense_out = SetKernelOutput(kernel_backend, &out);""" +{code_indent} {self.outputs['return_type']} out; +{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" elif len(output_type_list) > 1: output_create = f""" - {self.outputs['return_type']} out({len(output_type_list)});""" +{code_indent} {self.outputs['return_type']} out({len(output_type_list)});""" for i, out_type_item in enumerate(output_type_list): - kernel_output = kernel_output + f'dense_out_{i}, ' - output_names.append(f'dense_out_{i}') + kernel_output = kernel_output + f'kernel_out_{i}, ' + output_names.append(f'kernel_out_{i}') if out_type_item == 'Tensor': get_out_code = f'&out[{i}][0]' output_create = output_create + f""" - out[{i}].emplace_back();""" +{code_indent} out[{i}].emplace_back();""" else: get_out_code = f'&out[{i}]' output_create = output_create + f""" - auto dense_out_{i} = SetKernelOutput(kernel_backend, {get_out_code});""" +{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});""" kernel_output = kernel_output[:-2] else: diff --git 
a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py index ad26062e6b..6a434b60e6 100644 --- a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py +++ b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py @@ -26,7 +26,7 @@ def get_wrapped_infermeta_name(api_name): def gene_wrapped_infermeta_and_register(api): if api.is_base_api: register_code = f""" -PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});""" +PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});""" if api.infer_meta['param'] is not None: tensor_type_map = { @@ -67,7 +67,7 @@ void {wrapped_infermeta_name}({", ".join(args)}) {{ """ register_code = f""" -PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});""" +PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_name(api.kernel['func'][0])});""" return declare_code, defind_code, register_code else: @@ -80,11 +80,11 @@ def gene_infermeta_register(api): if api.is_base_api: if api.infer_meta['param'] is None: return f""" -PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});""" +PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});""" else: return f""" -PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});""" +PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_name(api.kernel['func'][0])});""" else: return '' -- GitLab