未验证 提交 5bb3b668 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Support SelectedRows in C++ API (#39497)

* add data_transform in pten api

* support GetKernelTypeForVar

* fix complie problem of bfloat16

* add scale_sr in api

* suppport select_row in C++ api

* merge code
上级 9fd67ffe
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/pten/core/compat/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/core/selected_rows.h"
namespace paddle {
namespace experimental {
......@@ -43,6 +44,11 @@ inline std::unique_ptr<std::vector<pten::DenseTensor>> TensorToDenseTensor(
return std::move(pt_tensors);
}
inline std::shared_ptr<pten::SelectedRows> TensorToSelectedRows(
const Tensor& tensor) {
return std::dynamic_pointer_cast<pten::SelectedRows>(tensor.impl());
}
/* ----------------- for infer_meta --------------------- */
inline pten::MetaTensor MakeMetaTensor(const pten::DenseTensor& tensor) {
......@@ -59,6 +65,10 @@ inline std::vector<pten::MetaTensor> MakeMetaTensor(
return meta_tensors;
}
inline pten::MetaTensor MakeMetaTensor(const pten::SelectedRows& tensor) {
return pten::MetaTensor(tensor);
}
/* ------------------ for output ----------------------- */
inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
......@@ -84,5 +94,12 @@ inline std::vector<pten::DenseTensor*> SetKernelOutput(
return results;
}
inline pten::SelectedRows* SetSelectedRowsKernelOutput(Backend backend,
Tensor* out) {
auto select_rows = std::make_shared<pten::SelectedRows>();
out->set_impl(select_rows);
return select_rows.get();
}
} // namespace experimental
} // namespace paddle
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
#include "paddle/pten/core/selected_rows.h"
// TODO(chenweihang): split Key, Kernel, Factory into diff files
#include "paddle/pten/core/kernel_factory.h"
......@@ -38,8 +39,15 @@ std::size_t CountLeadingZeros(uint64_t val);
pten::DeviceContext* GetDeviceContextByBackend(pten::Backend backend);
enum class KernelType {
DENSE_TENSOR_KENREL, // kernel for DenseTensor
SELECTED_ROWS_KENREL // kernel for SelectedRows
};
// TODO(chenweihang): support DataLayout and DataType selected
struct KernelKeySet {
KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};
BackendSet backend_set{Backend::UNDEFINED};
DataLayout layout{DataLayout::UNDEFINED};
DataType dtype{DataType::UNDEFINED};
......@@ -89,6 +97,9 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
void operator()(const Tensor& x) {
key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x);
// TODO(chenweihang): selecte multi layout and dtype
if (pten::SelectedRows::classof(x.impl().get())) {
key_set.kernel_type = KernelType::SELECTED_ROWS_KENREL;
}
key_set.layout = x.layout();
key_set.dtype = x.type();
dtype_set = dtype_set | DataTypeSet(x.dtype());
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/compat/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/core/tensor_base.h"
#include "paddle/pten/core/tensor_meta.h"
#include "paddle/pten/core/tensor_utils.h"
......@@ -222,8 +223,11 @@ Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template <typename T>
const T *Tensor::data() const {
if (is_dense_tensor()) {
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->mutable_data<T>(
ConvertExtPlaceToInnerPlace(place()));
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->data<T>();
} else if (pten::SelectedRows::classof(impl_.get())) {
return std::dynamic_pointer_cast<pten::SelectedRows>(impl_)
->value()
.data<T>();
}
return nullptr;
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/core/tensor_utils.h"
namespace pten {
......@@ -32,6 +33,10 @@ void MetaTensor::set_dims(const DDim& dims) {
if (pten::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
dims;
} else if (pten::SelectedRows::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->dims = dims;
} else {
PADDLE_THROW(pten::errors::Unimplemented(
"Unsupported setting dims for `%s`.", tensor_->type_info().name()));
......@@ -42,6 +47,10 @@ void MetaTensor::set_dtype(DataType dtype) {
if (pten::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
->dtype = dtype;
} else if (pten::SelectedRows::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->dtype = dtype;
} else {
PADDLE_THROW(pten::errors::Unimplemented(
"Unsupported settting dtype for `%s`.", tensor_->type_info().name()));
......@@ -52,6 +61,10 @@ void MetaTensor::set_layout(DataLayout layout) {
if (pten::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
->layout = layout;
} else if (pten::SelectedRows::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->layout = layout;
} else {
PADDLE_THROW(pten::errors::Unimplemented(
"Unsupported settting layout for `%s`.", tensor_->type_info().name()));
......@@ -62,6 +75,10 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
if (pten::DenseTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
meta_tensor.lod();
} else if (pten::SelectedRows::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<SelectedRows*>(tensor_)->mutable_value())
->lod = meta_tensor.lod();
} else {
PADDLE_THROW(
pten::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
......@@ -72,6 +89,8 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
const LoD& MetaTensor::lod() const {
if (pten::DenseTensor::classof(tensor_)) {
return static_cast<DenseTensor*>(tensor_)->lod();
} else if (pten::SelectedRows::classof(tensor_)) {
return static_cast<SelectedRows*>(tensor_)->value().lod();
} else {
PADDLE_THROW(pten::errors::Unimplemented("Unsupported getting lod of `%s`.",
tensor_->type_info().name()));
......@@ -84,6 +103,11 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
share_lod(meta_tensor);
} else if (pten::SelectedRows::classof(tensor_)) {
set_dims(meta_tensor.dims());
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
share_lod(meta_tensor);
} else {
PADDLE_THROW(pten::errors::Unimplemented(
"Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
......
......@@ -29,7 +29,8 @@ void ScaleSR(const Context& dev_ctx,
float bias,
bool bias_after_scale,
SelectedRows* out) {
if (x.value().data() != out->value().data()) {
if (x.value().Holder() != out->value().Holder() ||
x.value().data() != out->value().data()) {
out->set_rows(x.rows());
out->set_height(x.height());
}
......
......@@ -30,12 +30,8 @@ template <typename Context>
DenseTensor TransferLayout(const Context& dev_ctx,
const DenseTensor& x,
DataLayout dst_layout) {
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
{x.dtype(), x.dims(), dst_layout});
MetaTensor meta_out(&dense_out);
TransferLayoutInferMeta(x, dst_layout, &meta_out);
pten::DenseTensor dense_out =
pten::Empty(dev_ctx, {x.dtype(), x.dims(), dst_layout});
TransferLayoutKernel<Context>(dev_ctx, x, dst_layout, &dense_out);
return dense_out;
}
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"
namespace paddle {
namespace tests {
......@@ -26,7 +27,7 @@ namespace tests {
namespace framework = paddle::framework;
using DDim = pten::framework::DDim;
void CheckScaleResult(experimental::Tensor* out) {
void CheckScaleResult(const experimental::Tensor* out) {
ASSERT_EQ(out->dims().size(), 2);
ASSERT_EQ(out->dims()[0], 3);
ASSERT_EQ(out->dims()[1], 4);
......@@ -36,7 +37,7 @@ void CheckScaleResult(experimental::Tensor* out) {
ASSERT_EQ(out->layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out->initialized(), true);
for (int64_t i = 0; i < out->numel(); ++i) {
ASSERT_EQ(out->mutable_data<float>()[i], 3.0);
ASSERT_NEAR(3.0, out->data<float>()[i], 1e-6f);
}
}
......@@ -52,5 +53,29 @@ TEST(API, scale) {
CheckScaleResult(&out2);
}
TEST(API, scale_sr) {
// 1. check `scale` is float value
std::vector<int64_t> rows{0, 4, 7};
int64_t height = 10;
auto selected_rows = std::make_shared<pten::SelectedRows>(rows, height);
auto dense_tensor = std::dynamic_pointer_cast<pten::DenseTensor>(
experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32).impl());
*(selected_rows->mutable_value()) = *dense_tensor;
experimental::Tensor x(selected_rows);
const auto out = experimental::scale(x, 2.0, 1.0, true);
ASSERT_EQ(out.dims().size(), 2);
ASSERT_EQ(out.dims()[0], 3);
ASSERT_EQ(out.dims()[1], 4);
ASSERT_EQ(out.numel(), 12);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
for (int64_t i = 0; i < out.numel(); ++i) {
ASSERT_NEAR(3.0, out.data<float>()[i], 1e-6f);
}
}
} // namespace tests
} // namespace paddle
......@@ -155,7 +155,7 @@
func : UnchangedInferMeta
param : [x]
kernel :
func : scale
func : scale, scale_sr
- api : sign
args : (const Tensor& x)
......
......@@ -14,7 +14,7 @@
import re
PREFIX_TENSOR_NAME = 'dense_'
PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'
......@@ -33,8 +33,8 @@ class BaseAPI(object):
# types : [], list of output types
# return_type : Tensor, vector<Tensor>, ..., the return type of api
# args_str:
# args_declare : "str" // str of funtion params with default value. Example: (..., bool flag=false)
# args_define : "str" // str of funtion params without default value. Example: (..., bool flag)
# args_declare : "str" // str of function params with default value. Example: (..., bool flag=false)
# args_define : "str" // str of function params without default value. Example: (..., bool flag)
self.inputs, self.attrs, self.outputs, self.args_str = self.parse_args(
self.api, api_item_yaml)
......@@ -43,32 +43,11 @@ class BaseAPI(object):
self.is_base_api = False
self.invoke = api_item_yaml['invoke']
else:
self.kernel = api_item_yaml['kernel']
if 'backend' not in self.kernel or len(self.kernel['backend']) == 0:
self.kernel['backend'] = None
if 'layout' not in self.kernel or len(self.kernel['layout']) == 0:
self.kernel['layout'] = None
if 'data_type' not in self.kernel or len(self.kernel[
'data_type']) == 0:
self.kernel['data_type'] = None
if 'param' not in self.kernel:
self.kernel['param'] = None
self.infer_meta = api_item_yaml['infer_meta']
if 'param' not in self.infer_meta:
self.infer_meta['param'] = None
self.data_transform = {
'skip_transform': [],
'support_trans_dtype': []
}
if 'data_transform' in api_item_yaml:
if 'skip_transform' in api_item_yaml['data_transform']:
self.data_transform['skip_transform'] = api_item_yaml[
'data_transform']['skip_transform']
if 'support_trans_dtype' in api_item_yaml['data_transform']:
self.data_transform['support_trans_dtype'] = api_item_yaml[
'data_transform']['support_trans_dtype']
self.infer_meta = self.parse_infer_meta(api_item_yaml['infer_meta'])
self.kernel = self.parse_kernel(api_item_yaml['kernel'])
self.support_selected_rows_kernel = False if len(self.kernel[
'func']) == 1 else True
self.data_transform = self.parse_data_transform(api_item_yaml)
def get_api_name(self, api_item_yaml):
return api_item_yaml['api']
......@@ -185,6 +164,61 @@ class BaseAPI(object):
return out_type_list, out_name_list, self.get_return_type(
out_type_list)
def parse_infer_meta(self, infer_meta_config):
infer_meta = infer_meta_config
if 'param' not in infer_meta_config:
infer_meta['param'] = None
return infer_meta
def parse_kernel(self, kernel_config):
# kernel :
# func : [], Kernel functions (example: scale, scale_sr)
# param : [], Input params of kernel
# backend : str, the names of param to choose the kernel backend, default is None
# layout : str, the names of param to choose the kernel layout, default is None
# data_type : str, the names of param to choose the kernel data_type, default is None
kernel = {
'func': [],
'param': None,
'backend': None,
'layout': None,
'data_type': None
}
if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
kernel['backend'] = kernel_config['backend']
if 'layout' in kernel_config and len(kernel_config['layout']) > 0:
kernel['layout'] = kernel_config['layout']
if 'data_type' in kernel_config and len(kernel_config['data_type']) > 0:
kernel['data_type'] = kernel_config['data_type']
if 'param' in kernel_config:
kernel['param'] = kernel_config['param']
kernel['func'] = [
kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',')
]
if len(kernel['func']) == 2:
assert kernel['func'][0] == self.api, \
f"{self.api} : Kernel func error: If kernel has two func config, the name of first func should be same with api name({self.api}), \
but now is {kernel['func'][0]}."
assert kernel['func'][1].endswith('_sr'), \
f"{self.api} : Kernel func error: If kernel has two func config, the name of second func should be a selected_rows kernel (the func name endwith '_sr'), \
but now is {kernel['func'][1]}."
return kernel
def parse_data_transform(self, api_item_yaml):
data_transform = {'skip_transform': [], 'support_trans_dtype': []}
if 'data_transform' in api_item_yaml:
if 'skip_transform' in api_item_yaml['data_transform']:
data_transform['skip_transform'] = api_item_yaml[
'data_transform']['skip_transform']
if 'support_trans_dtype' in api_item_yaml['data_transform']:
data_transform['support_trans_dtype'] = api_item_yaml[
'data_transform']['support_trans_dtype']
return data_transform
# Override by child class
def get_return_type(self, out_type_list):
return None
......@@ -303,12 +337,18 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
kernel_select_code = kernel_key_item_init + kernel_select_code
if len(input_names) > 0:
if self.support_selected_rows_kernel:
kernel_select_code = kernel_select_code + f"""
KernelType kernel_type;
"""
kernel_select_code = kernel_select_code + f"""
if (kernel_backend == Backend::UNDEFINED
|| kernel_layout == DataLayout::UNDEFINED
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
{'kernel_type = kernel_key_set.kernel_type;' if self.support_selected_rows_kernel else ''}
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
......@@ -320,15 +360,9 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
}}
}}"""
kernel_select_code = kernel_select_code + f"""
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"{kernel['func']}", {{kernel_backend, kernel_layout, kernel_data_type}});
VLOG(6) << "{api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
VLOG(6) << "{api} API kernel: " << kernel;"""
return kernel_select_code
def gene_infer_meta(self, kernel_output_names) -> str:
def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
input_names = self.inputs['names']
attr_names = self.attrs['names']
infer_meta = self.infer_meta
......@@ -343,11 +377,10 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
if param in input_names:
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
elif param in kernel_output_names:
meta_tensor_code = meta_tensor_code + " pten::MetaTensor " + param.replace(
PREFIX_TENSOR_NAME,
PREFIX_META_TENSOR_NAME) + "(" + param + ");\n"
meta_tensor_code = meta_tensor_code + code_indent + " pten::MetaTensor " + param.replace(
'kernel_', PREFIX_META_TENSOR_NAME) + "(" + param + ");\n"
param_code = param_code + "&" + param.replace(
PREFIX_TENSOR_NAME, PREFIX_META_TENSOR_NAME) + ", "
'kernel_', PREFIX_META_TENSOR_NAME) + ", "
elif param in attr_names:
param_code = param_code + param + ", "
elif isinstance(param, str):
......@@ -359,10 +392,10 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
param_code = param_code[:-2]
return f"""{meta_tensor_code}
pten::{infer_meta['func']}({param_code});
{code_indent} pten::{infer_meta['func']}({param_code});
"""
def get_kernel_args(self):
def get_kernel_args(self, code_indent):
input_trans_map = {
'const Tensor&': 'const pten::DenseTensor&',
'const Tensor &': 'const pten::DenseTensor&',
......@@ -394,11 +427,69 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
elif input_name in self.data_transform['support_trans_dtype']:
trans_flag = "{false, true}"
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
else:
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[
param]])
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in self.attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::ScalarArray&')
param = 'pten::ScalarArray(' + param + ')'
elif 'Scalar' in self.attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::Scalar&')
param = 'pten::Scalar(' + param + ')'
else:
kernel_args_type_list.append(self.attrs['attr_info'][param][
0])
kernel_args = kernel_args + param + ", "
elif isinstance(param, bool):
kernel_args = kernel_args + str(param).lower() + ", "
else:
kernel_args = kernel_args + str(param) + ", "
for out_type in self.outputs['types']:
kernel_args_type_list.append(out_trans_map[out_type])
kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
return input_tensor_code, kernel_args[:-2], kernel_signature
def get_selected_rows_kernel_args(self, code_indent):
input_trans_map = {
'const Tensor&': 'const pten::SelectedRows&',
'const Tensor &': 'const pten::SelectedRows&'
}
out_trans_map = {'Tensor': 'pten::SelectedRows*'}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']
input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
......@@ -431,30 +522,77 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
return input_tensor_code, kernel_args[:-2], kernel_signature
# Override by child class
def gene_output(self, output_type_list):
def gene_output(self, output_type_list, set_out_func, code_indent):
return None, None, None
def gen_dense_tensor_kernel_code(self, code_indent):
input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
code_indent)
outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetKernelOutput', code_indent)
return f"""
{code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} VLOG(6) << "{self.api} API kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors}
{output_create}
{self.gene_infer_meta(kernel_output_names, code_indent)}
{code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} return out;"""
def gen_selected_rows_kernel_code(self, code_indent):
input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args(
code_indent)
outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent)
return f"""
{code_indent} auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][1]}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} VLOG(6) << "{self.api} API SelectedRows kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} VLOG(6) << "{self.api} API SelectedRows kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors}
{output_create}
{self.gene_infer_meta(kernel_output_names, code_indent)}
{code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} return out;"""
def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
)
outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'])
return f"""
api_code = f"""
PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str["args_define"]}) {{
{self.gene_kernel_select()}
"""
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors}
{output_create}
{self.gene_infer_meta(kernel_output_names)}
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
if self.support_selected_rows_kernel:
code_indent = ' '
api_code = api_code + f"""
if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{
{self.gen_dense_tensor_kernel_code(code_indent)}
}} else {{
{self.gen_selected_rows_kernel_code(code_indent)}
}}
}}
"""
return api_code
else:
code_indent = ''
return api_code + self.gen_dense_tensor_kernel_code(
code_indent) + """
}
"""
else:
......
......@@ -30,27 +30,27 @@ class ForwardAPI(BaseAPI):
out_type_list) == 1 else "std::tuple<" + ",".join(
out_type_list) + ">"
def gene_output(self, output_type_list):
def gene_output(self, output_type_list, set_out_func, code_indent):
kernel_output = ""
output_names = []
output_create = ""
if len(output_type_list) == 1:
kernel_output = 'dense_out'
output_names.append('dense_out')
kernel_output = 'kernel_out'
output_names.append('kernel_out')
output_create = f"""
{self.outputs['return_type']} out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
{code_indent} {self.outputs['return_type']} out;
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);"""
elif len(output_type_list) > 1:
output_create = f"""
{self.outputs['return_type']} out;"""
{code_indent} {self.outputs['return_type']} out;"""
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'dense_out_{i}, '
output_names.append(f'dense_out_{i}')
kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}')
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(kernel_backend, &std::get<{i}>(out));"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));"""
kernel_output = kernel_output[:-2]
else:
......
......@@ -62,41 +62,41 @@ class BackwardAPI(BaseAPI):
# check the output of backward
assert len(self.outputs['types']) <= len(fw_inputs['names']), \
f"{self.api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \
f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
Please check the output of {self.api} in yaml."
def get_return_type(self, out_type_list):
return out_type_list[0] if len(
out_type_list) == 1 else "std::vector<std::vector<Tensor>>"
def gene_output(self, output_type_list):
def gene_output(self, output_type_list, set_out_func, code_indent):
kernel_output = ""
output_names = []
output_create = ""
if len(output_type_list) == 1:
kernel_output = 'dense_out'
output_names.append('dense_out')
kernel_output = 'kernel_out'
output_names.append('kernel_out')
output_create = f"""
{self.outputs['return_type']} out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
{code_indent} {self.outputs['return_type']} out;
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);"""
elif len(output_type_list) > 1:
output_create = f"""
{self.outputs['return_type']} out({len(output_type_list)});"""
{code_indent} {self.outputs['return_type']} out({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
output_names.append(f'dense_out_{i}')
kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}')
if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]'
output_create = output_create + f"""
out[{i}].emplace_back();"""
{code_indent} out[{i}].emplace_back();"""
else:
get_out_code = f'&out[{i}]'
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(kernel_backend, {get_out_code});"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""
kernel_output = kernel_output[:-2]
else:
......
......@@ -26,7 +26,7 @@ def get_wrapped_infermeta_name(api_name):
def gene_wrapped_infermeta_and_register(api):
if api.is_base_api:
register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});"""
if api.infer_meta['param'] is not None:
tensor_type_map = {
......@@ -67,7 +67,7 @@ void {wrapped_infermeta_name}({", ".join(args)}) {{
"""
register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_name(api.kernel['func'][0])});"""
return declare_code, defind_code, register_code
else:
......@@ -80,11 +80,11 @@ def gene_infermeta_register(api):
if api.is_base_api:
if api.infer_meta['param'] is None:
return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});"""
else:
return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});"""
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{get_wrapped_infermeta_name(api.kernel['func'][0])});"""
else:
return ''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册