未验证 提交 aa45f931 编写于 作者: Z zyfncg 提交者: GitHub

Support code auto-gene for optimizer api in yaml (#43915)

* support complex selected_rows kernel in yaml

* support configuring optimizer api in yaml

* fix data transform bug
上级 78023658
...@@ -224,7 +224,8 @@ add_custom_command( ...@@ -224,7 +224,8 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp}
${api_source_file} ${api_source_file}
COMMENT "copy_if_different ${api_header_file} ${api_source_file}" COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base} DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${api_gen_file}
${api_gen_base}
VERBATIM) VERBATIM)
# generate backward api # generate backward api
...@@ -240,7 +241,8 @@ add_custom_command( ...@@ -240,7 +241,8 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_source_file_tmp}
${bw_api_source_file} ${bw_api_source_file}
COMMENT "copy_if_different ${bw_api_header_file} ${bw_api_source_file}" COMMENT "copy_if_different ${bw_api_header_file} ${bw_api_source_file}"
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base} DEPENDS ${bw_api_yaml_file} ${legacy_bw_api_yaml_file} ${bw_api_gen_file}
${api_gen_base}
VERBATIM) VERBATIM)
# generate sparse api # generate sparse api
......
...@@ -32,237 +32,6 @@ limitations under the License. */ ...@@ -32,237 +32,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
// Forward implementation of the fused Adam optimizer API.
// Selects the "adam" kernel when `grad` holds a DenseTensor, or the
// "adam_dense_param_sparse_grad" kernel when it holds a SelectedRows, and
// updates the optimizer state (param, moments, beta pows, and the optional
// master_param) in place through the prepared input tensors.
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
    const Tensor& param,
    const Tensor& grad,
    const Tensor& learning_rate,
    const Tensor& moment1,
    const Tensor& moment2,
    const Tensor& beta1_pow,
    const Tensor& beta2_pow,
    const paddle::optional<Tensor>& master_param,
    const paddle::optional<Tensor>& skip_update,
    const Scalar& beta1,
    const Scalar& beta2,
    const Scalar& epsilon,
    bool lazy_mode,
    int64_t min_row_size_to_use_multithread,
    bool multi_precision,
    bool use_global_beta_pow) {
  // Resolve the kernel key (backend / layout / dtype). No explicit key is
  // supplied by the caller, so every component is derived from `param`.
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;
  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    auto kernel_key_set = ParseKernelKeyByInputArgs(param);
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {
      kernel_backend = kernel_key.backend();
    }
    if (kernel_layout == DataLayout::UNDEFINED) {
      kernel_layout = kernel_key.layout();
    }
    if (kernel_data_type == DataType::UNDEFINED) {
      kernel_data_type = kernel_key.dtype();
    }
  }
  // A non-dense gradient routes to the sparse-grad kernel variant.
  std::string kernel_name = "adam";
  if (!phi::DenseTensor::classof(grad.impl().get())) {
    kernel_name = "adam_dense_param_sparse_grad";
  }
  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << kernel_name << " API kernel: " << kernel;
  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
  // Prepare kernel inputs. The InputAt indices follow the kernel's argument
  // order; slot 1 is `grad`, which is prepared later per tensor type.
  auto input_param = PrepareData(param, kernel.InputAt(0), {});
  auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {});
  auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {});
  auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {});
  auto input_beta1_pow = PrepareData(beta1_pow, kernel.InputAt(5), {});
  auto input_beta2_pow = PrepareData(beta2_pow, kernel.InputAt(6), {});
  auto input_master_param = PrepareData(master_param, kernel.InputAt(7), {});
  auto input_skip_update = PrepareData(skip_update, kernel.InputAt(8), {});
  std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
  // Outputs alias the prepared inputs: the kernel writes the updated
  // optimizer state back into them (in-place update).
  auto kernel_out_0 = input_param.get();
  auto kernel_out_1 = input_moment1.get();
  auto kernel_out_2 = input_moment2.get();
  auto kernel_out_3 = input_beta1_pow.get();
  auto kernel_out_4 = input_beta2_pow.get();
  // master_param output only exists when the optional input was provided.
  phi::DenseTensor* kernel_out_5 = nullptr;
  if (input_master_param) {
    kernel_out_5 = input_master_param.get_ptr();
  }
  auto input_meta_ref_master_param = MakeMetaTensor(input_master_param);
  auto input_meta_ref_skip_update = MakeMetaTensor(input_skip_update);
  phi::MetaTensor meta_out_0(kernel_out_0);
  phi::MetaTensor meta_out_1(kernel_out_1);
  phi::MetaTensor meta_out_2(kernel_out_2);
  phi::MetaTensor meta_out_3(kernel_out_3);
  phi::MetaTensor meta_out_4(kernel_out_4);
  phi::MetaTensor meta_out_5(kernel_out_5);
  if (phi::DenseTensor::classof(grad.impl().get())) {
    // Dense-gradient path: grad participates as a phi::DenseTensor.
    auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
    phi::AdamInferMeta(MakeMetaTensor(*input_param),
                       MakeMetaTensor(*input_grad),
                       MakeMetaTensor(*input_lr),
                       MakeMetaTensor(*input_moment1),
                       MakeMetaTensor(*input_moment2),
                       MakeMetaTensor(*input_beta1_pow),
                       MakeMetaTensor(*input_beta2_pow),
                       input_meta_ref_master_param,
                       input_meta_ref_skip_update,
                       beta1,
                       beta2,
                       epsilon,
                       lazy_mode,
                       min_row_size_to_use_multithread,
                       multi_precision,
                       use_global_beta_pow,
                       &meta_out_0,
                       &meta_out_1,
                       &meta_out_2,
                       &meta_out_3,
                       &meta_out_4,
                       &meta_out_5);
    // Variadic kernel signature must match the registered kernel exactly.
    using kernel_signature = void (*)(const platform::DeviceContext&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const paddle::optional<phi::DenseTensor>&,
                                      const paddle::optional<phi::DenseTensor>&,
                                      const Scalar&,
                                      const Scalar&,
                                      const Scalar&,
                                      bool,
                                      int64_t,
                                      bool,
                                      bool,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*);
    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
    (*kernel_fn)(*dev_ctx,
                 *input_param,
                 *input_grad,
                 *input_lr,
                 *input_moment1,
                 *input_moment2,
                 *input_beta1_pow,
                 *input_beta2_pow,
                 input_master_param,
                 input_skip_update,
                 beta1,
                 beta2,
                 epsilon,
                 lazy_mode,
                 min_row_size_to_use_multithread,
                 multi_precision,
                 use_global_beta_pow,
                 kernel_out_0,
                 kernel_out_1,
                 kernel_out_2,
                 kernel_out_3,
                 kernel_out_4,
                 kernel_out_5);
  } else {
    // Sparse-gradient path: grad participates as a phi::SelectedRows.
    auto input_grad = TensorToSelectedRows(grad);
    phi::AdamInferMeta(MakeMetaTensor(*input_param),
                       MakeMetaTensor(*input_grad),
                       MakeMetaTensor(*input_lr),
                       MakeMetaTensor(*input_moment1),
                       MakeMetaTensor(*input_moment2),
                       MakeMetaTensor(*input_beta1_pow),
                       MakeMetaTensor(*input_beta2_pow),
                       input_meta_ref_master_param,
                       input_meta_ref_skip_update,
                       beta1,
                       beta2,
                       epsilon,
                       lazy_mode,
                       min_row_size_to_use_multithread,
                       multi_precision,
                       use_global_beta_pow,
                       &meta_out_0,
                       &meta_out_1,
                       &meta_out_2,
                       &meta_out_3,
                       &meta_out_4,
                       &meta_out_5);
    // Same signature as above except `grad` is a SelectedRows.
    using kernel_signature = void (*)(const platform::DeviceContext&,
                                      const phi::DenseTensor&,
                                      const phi::SelectedRows&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const phi::DenseTensor&,
                                      const paddle::optional<phi::DenseTensor>&,
                                      const paddle::optional<phi::DenseTensor>&,
                                      const Scalar&,
                                      const Scalar&,
                                      const Scalar&,
                                      bool,
                                      int64_t,
                                      bool,
                                      bool,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*,
                                      phi::DenseTensor*);
    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
    (*kernel_fn)(*dev_ctx,
                 *input_param,
                 *input_grad,
                 *input_lr,
                 *input_moment1,
                 *input_moment2,
                 *input_beta1_pow,
                 *input_beta2_pow,
                 input_master_param,
                 input_skip_update,
                 beta1,
                 beta2,
                 epsilon,
                 lazy_mode,
                 min_row_size_to_use_multithread,
                 multi_precision,
                 use_global_beta_pow,
                 kernel_out_0,
                 kernel_out_1,
                 kernel_out_2,
                 kernel_out_3,
                 kernel_out_4,
                 kernel_out_5);
  }
  // NOTE(review): api_output is returned default-constructed; the results are
  // written through the prepared input tensors above. Confirm callers rely on
  // the in-place update rather than on the returned tuple's contents.
  return api_output;
}
////////////////// Forward api impls ////////////////////// ////////////////// Forward api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl( std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
...@@ -1100,159 +869,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl( ...@@ -1100,159 +869,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
return api_output; return api_output;
} }
// Forward implementation of the SGD optimizer API.
// Dispatches among three kernels based on the tensor types of param/grad:
//   dense param + dense grad          -> "sgd"
//   dense param + selected-rows grad  -> "sgd_dense_param_sparse_grad"
//   selected-rows param (+ grad)      -> "sgd_sparse_param_sparse_grad"
// The returned tuple aliases (param, master_param), which are updated in
// place by the selected kernel.
std::tuple<Tensor, Tensor> sgd_impl(
    const Tensor& param,
    const Tensor& learning_rate,
    const Tensor& grad,
    const paddle::optional<Tensor>& master_param,
    bool multi_precision) {
  // dtype comes from `param`; backend/layout from the highest-priority key
  // over (param, learning_rate, grad).
  DataType kernel_data_type = ParseDataType(param);
  auto kernel_key_set = ParseKernelKeyByInputArgs(param, learning_rate, grad);
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
  VLOG(6) << "sgd API kernel key: [" << kernel_key.backend() << ", "
          << kernel_key.layout() << ", " << kernel_data_type << "]";
  // Pick the kernel variant from the param/grad tensor types.
  const auto& param_tensor = param.impl();
  std::string kernel_name = "sgd";
  if (phi::DenseTensor::classof(param_tensor.get())) {
    if (!phi::DenseTensor::classof(grad.impl().get())) {
      kernel_name = "sgd_dense_param_sparse_grad";
    }
  } else {
    kernel_name = "sgd_sparse_param_sparse_grad";
  }
  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name,
      {kernel_key.backend(), kernel_key.layout(), kernel_data_type});
  VLOG(6) << kernel_name << " API kernel: " << kernel;
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  // NOTE(review): the {false, true, true, true} transform-flag set for
  // learning_rate is generator-produced; see PrepareData for flag semantics.
  auto in_learning_rate =
      PrepareData(learning_rate, kernel.InputAt(1), {false, true, true, true});
  // Outputs alias the inputs: out<0> is param, out<1> is master_param (if
  // present), so the kernel updates them in place.
  std::tuple<Tensor, Tensor> out;
  std::get<0>(out) = param;
  if (master_param) {
    std::get<1>(out) = *master_param;
  }
  phi::MetaTensor meta_out_0(std::get<0>(out).impl().get());
  phi::MetaTensor meta_out_1(master_param ? std::get<1>(out).impl().get()
                                          : nullptr);
  if (phi::DenseTensor::classof(param_tensor.get())) {
    // Dense-param paths.
    auto in_param = PrepareData(param, kernel.InputAt(0), {});
    auto in_master_param_opt = PrepareData(master_param, kernel.InputAt(3), {});
    auto master_param_meta_opt = MakeMetaTensor(in_master_param_opt);
    phi::DenseTensor* kernel_out_0 =
        SetKernelOutput(kernel_key.backend(), &std::get<0>(out));
    phi::DenseTensor* kernel_out_1 =
        master_param
            ? static_cast<phi::DenseTensor*>(std::get<1>(out).impl().get())
            : nullptr;
    if (phi::DenseTensor::classof(grad.impl().get())) {
      // Dense param + dense grad.
      auto in_grad = PrepareData(grad, kernel.InputAt(2), {});
      SgdInferMeta(MakeMetaTensor(*in_param),
                   MakeMetaTensor(*in_learning_rate),
                   MakeMetaTensor(*in_grad),
                   master_param_meta_opt,
                   multi_precision,
                   &meta_out_0,
                   &meta_out_1);
      // Signature must match the registered "sgd" kernel.
      using kernel_signature =
          void (*)(const platform::DeviceContext&,
                   const phi::DenseTensor&,
                   const phi::DenseTensor&,
                   const phi::DenseTensor&,
                   const paddle::optional<phi::DenseTensor>&,
                   bool,
                   phi::DenseTensor*,
                   phi::DenseTensor*);
      auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
      (*kernel_fn)(*dev_ctx,
                   *in_param,
                   *in_learning_rate,
                   *in_grad,
                   in_master_param_opt,
                   multi_precision,
                   kernel_out_0,
                   kernel_out_1);
    } else {
      // Dense param + selected-rows grad.
      auto in_grad = TensorToSelectedRows(grad);
      SgdInferMeta(MakeMetaTensor(*in_param),
                   MakeMetaTensor(*in_learning_rate),
                   MakeMetaTensor(*in_grad),
                   master_param_meta_opt,
                   multi_precision,
                   &meta_out_0,
                   &meta_out_1);
      using kernel_signature =
          void (*)(const platform::DeviceContext&,
                   const phi::DenseTensor&,
                   const phi::DenseTensor&,
                   const phi::SelectedRows&,
                   const paddle::optional<phi::DenseTensor>&,
                   bool,
                   phi::DenseTensor*,
                   phi::DenseTensor*);
      auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
      (*kernel_fn)(*dev_ctx,
                   *in_param,
                   *in_learning_rate,
                   *in_grad,
                   in_master_param_opt,
                   multi_precision,
                   kernel_out_0,
                   kernel_out_1);
    }
  } else {
    // Selected-rows param path: param, grad, and master_param all participate
    // as SelectedRows.
    auto in_param = TensorToSelectedRows(param);
    auto in_grad = TensorToSelectedRows(grad);
    auto in_master_param_opt = TensorToSelectedRows(master_param);
    auto master_param_meta = MakeMetaTensor(in_master_param_opt);
    phi::SelectedRows* kernel_out_0 =
        SetSelectedRowsKernelOutput(kernel_key.backend(), &std::get<0>(out));
    phi::SelectedRows* kernel_out_1 =
        master_param
            ? static_cast<phi::SelectedRows*>(std::get<1>(out).impl().get())
            : nullptr;
    SgdInferMeta(MakeMetaTensor(*in_param),
                 MakeMetaTensor(*in_learning_rate),
                 MakeMetaTensor(*in_grad),
                 master_param_meta,
                 multi_precision,
                 &meta_out_0,
                 &meta_out_1);
    using kernel_signature =
        void (*)(const platform::DeviceContext&,
                 const phi::SelectedRows&,
                 const phi::DenseTensor&,
                 const phi::SelectedRows&,
                 const paddle::optional<phi::SelectedRows>&,
                 bool,
                 phi::SelectedRows*,
                 phi::SelectedRows*);
    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
    (*kernel_fn)(*dev_ctx,
                 *in_param,
                 *in_learning_rate,
                 *in_grad,
                 in_master_param_opt,
                 multi_precision,
                 kernel_out_0,
                 kernel_out_1);
  }
  return out;
}
////////////////// Backward(grad) api impls ////////////////////// ////////////////// Backward(grad) api impls //////////////////////
// TODO(chenweihang): the original sum grad op can support higher-level // TODO(chenweihang): the original sum grad op can support higher-level
......
...@@ -31,24 +31,6 @@ namespace experimental { ...@@ -31,24 +31,6 @@ namespace experimental {
////////////////// Forward api impls ////////////////////// ////////////////// Forward api impls //////////////////////
// Manually implemented Adam optimizer forward API. The implementation picks a
// dense- or sparse-gradient kernel depending on whether `grad` holds a
// DenseTensor or a SelectedRows, and updates the optimizer state in place.
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
    const Tensor& param,
    const Tensor& grad,
    const Tensor& learning_rate,
    const Tensor& moment1,
    const Tensor& moment2,
    const Tensor& beta1_pow,
    const Tensor& beta2_pow,
    const paddle::optional<Tensor>& master_param,
    const paddle::optional<Tensor>& skip_update,
    const Scalar& beta1,
    const Scalar& beta2,
    const Scalar& epsilon,
    bool lazy_mode,
    int64_t min_row_size_to_use_multithread,
    bool multi_precision,
    bool use_global_beta_pow);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl( std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
const Tensor& param, const Tensor& param,
const Tensor& grad, const Tensor& grad,
...@@ -132,13 +114,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl( ...@@ -132,13 +114,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
bool multi_precision, bool multi_precision,
float rescale_grad); float rescale_grad);
// Manually implemented SGD optimizer forward API. The implementation supports
// dense and selected-rows param/grad combinations and updates param (and the
// optional master_param) in place.
std::tuple<Tensor, Tensor> sgd_impl(
    const Tensor& param,
    const Tensor& learning_rate,
    const Tensor& grad,
    const paddle::optional<Tensor>& master_param,
    bool multi_precision);
////////////////// Backward(grad) api impls ////////////////////// ////////////////// Backward(grad) api impls //////////////////////
void add_n_grad_impl(const std::vector<Tensor>& x, void add_n_grad_impl(const std::vector<Tensor>& x,
......
...@@ -62,7 +62,7 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor) { ...@@ -62,7 +62,7 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor) {
/* ----------------- for infer_meta --------------------- */ /* ----------------- for infer_meta --------------------- */
phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) { phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
return phi::MetaTensor(tensor); return phi::MetaTensor(tensor);
} }
...@@ -94,10 +94,6 @@ std::vector<phi::MetaTensor> MakeMetaTensor( ...@@ -94,10 +94,6 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
return meta_tensors; return meta_tensors;
} }
// Wraps a SelectedRows input in a MetaTensor view for infer_meta routines.
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
  phi::MetaTensor meta_view(tensor);
  return meta_view;
}
phi::MetaTensor MakeMetaTensor( phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SelectedRows>& tensor) { const paddle::optional<phi::SelectedRows>& tensor) {
if (tensor) { if (tensor) {
...@@ -106,10 +102,6 @@ phi::MetaTensor MakeMetaTensor( ...@@ -106,10 +102,6 @@ phi::MetaTensor MakeMetaTensor(
return phi::MetaTensor(); return phi::MetaTensor();
} }
// Wraps a StringTensor input in a MetaTensor view for infer_meta routines.
phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor) {
  phi::MetaTensor meta_view(tensor);
  return meta_view;
}
/* ------------------ for output ----------------------- */ /* ------------------ for output ----------------------- */
phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
......
...@@ -47,7 +47,7 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor); ...@@ -47,7 +47,7 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor);
/* ----------------- for infer_meta --------------------- */ /* ----------------- for infer_meta --------------------- */
phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor); phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor);
phi::MetaTensor MakeMetaTensor( phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::DenseTensor>& tensor); const paddle::optional<phi::DenseTensor>& tensor);
...@@ -58,13 +58,9 @@ std::vector<phi::MetaTensor> MakeMetaTensor( ...@@ -58,13 +58,9 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
std::vector<phi::MetaTensor> MakeMetaTensor( std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors); const std::vector<phi::DenseTensor*>& tensors);
// MetaTensor view over a SelectedRows input, used by infer_meta routines.
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
phi::MetaTensor MakeMetaTensor( phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SelectedRows>& tensor); const paddle::optional<phi::SelectedRows>& tensor);
// MetaTensor view over a StringTensor input, used by infer_meta routines.
phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor);
/* ------------------ for output ----------------------- */ /* ------------------ for output ----------------------- */
phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out); phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out);
......
...@@ -1370,8 +1370,8 @@ class SGDOptimizer(Optimizer): ...@@ -1370,8 +1370,8 @@ class SGDOptimizer(Optimizer):
lr = self._create_param_lr(param_and_grad) lr = self._create_param_lr(param_and_grad)
if in_dygraph_mode(): if in_dygraph_mode():
_C_ops.final_state_sgd(param_and_grad[0], lr, param_and_grad[1], _C_ops.final_state_sgd_(param_and_grad[0], lr, param_and_grad[1],
master_weight, find_master) master_weight, find_master)
return None return None
if _in_legacy_dygraph(): if _in_legacy_dygraph():
_C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight, _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
......
...@@ -342,7 +342,7 @@ class Adam(Optimizer): ...@@ -342,7 +342,7 @@ class Adam(Optimizer):
_beta2 = self._beta2 if not isinstance( _beta2 = self._beta2 if not isinstance(
self._beta2, Variable) else self._beta2.numpy().item(0) self._beta2, Variable) else self._beta2.numpy().item(0)
_, _, _, _, _, _ = _C_ops.final_state_adam( _, _, _, _, _, _ = _C_ops.final_state_adam_(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2, param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, _beta1, beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, _beta1,
_beta2, self._epsilon, self._lazy_mode, 1000, find_master, _beta2, self._epsilon, self._lazy_mode, 1000, find_master,
......
...@@ -143,8 +143,8 @@ class SGD(Optimizer): ...@@ -143,8 +143,8 @@ class SGD(Optimizer):
lr = self._create_param_lr(param_and_grad) lr = self._create_param_lr(param_and_grad)
if in_dygraph_mode(): if in_dygraph_mode():
_C_ops.final_state_sgd(param_and_grad[0], lr, param_and_grad[1], _C_ops.final_state_sgd_(param_and_grad[0], lr, param_and_grad[1],
master_weight, find_master) master_weight, find_master)
return None return None
if _in_legacy_dygraph(): if _in_legacy_dygraph():
_C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight, _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
......
...@@ -45,9 +45,6 @@ class BaseAPI(object): ...@@ -45,9 +45,6 @@ class BaseAPI(object):
self.infer_meta = self.parse_infer_meta( self.infer_meta = self.parse_infer_meta(
api_item_yaml['infer_meta']) api_item_yaml['infer_meta'])
self.kernel = self.parse_kernel(api_item_yaml['kernel']) self.kernel = self.parse_kernel(api_item_yaml['kernel'])
self.support_selected_rows_kernel = False if len(
self.kernel['func']
) == 1 or not self.kernel['func'][1].endswith('_sr') else True
self.data_transform = self.parse_data_transform(api_item_yaml) self.data_transform = self.parse_data_transform(api_item_yaml)
self.inplace_map, self.view_map = {}, {} self.inplace_map, self.view_map = {}, {}
...@@ -61,6 +58,7 @@ class BaseAPI(object): ...@@ -61,6 +58,7 @@ class BaseAPI(object):
input_args = [] input_args = []
inplace_type_map = { inplace_type_map = {
"const Tensor&": "Tensor&", "const Tensor&": "Tensor&",
"const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
"const std::vector<Tensor>&": "std::vector<Tensor>&" "const std::vector<Tensor>&": "std::vector<Tensor>&"
} }
for name in self.inputs['names']: for name in self.inputs['names']:
...@@ -285,6 +283,17 @@ class BaseAPI(object): ...@@ -285,6 +283,17 @@ class BaseAPI(object):
tmp_in_out_list = in_out_str[1:-1].split('->') tmp_in_out_list = in_out_str[1:-1].split('->')
inputs = [item.strip() for item in tmp_in_out_list[0].split(',')] inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
outputs = [item.strip() for item in tmp_in_out_list[1].split(',')] outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]
# check the tensor type
for item in inputs:
assert item in [
'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
], f"{self.api} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
for item in outputs:
assert item in [
'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
], f"{self.api} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
return (inputs, outputs) return (inputs, outputs)
for func_item in kernel_funcs: for func_item in kernel_funcs:
...@@ -440,11 +449,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -440,11 +449,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
kernel_select_code = kernel_key_item_init + kernel_select_code kernel_select_code = kernel_key_item_init + kernel_select_code
if len(input_names) > 0: if len(input_names) > 0:
if self.support_selected_rows_kernel:
kernel_select_code = kernel_select_code + f"""
KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)});
"""
kernel_select_code = kernel_select_code + f""" kernel_select_code = kernel_select_code + f"""
if (kernel_backend == Backend::UNDEFINED if (kernel_backend == Backend::UNDEFINED
|| kernel_layout == DataLayout::UNDEFINED || kernel_layout == DataLayout::UNDEFINED
...@@ -528,8 +532,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -528,8 +532,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} phi::{infer_meta['func']}({param_code}); {code_indent} phi::{infer_meta['func']}({param_code});
""" """
def get_kernel_args(self, code_indent): def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
input_trans_map = { dense_input_trans_map = {
'const Tensor&': 'const Tensor&':
'const phi::DenseTensor&', 'const phi::DenseTensor&',
'const std::vector<Tensor>&': 'const std::vector<Tensor>&':
...@@ -541,10 +545,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -541,10 +545,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
'const paddle::optional<std::vector<Tensor>>&': 'const paddle::optional<std::vector<Tensor>>&':
'paddle::optional<const std::vector<phi::DenseTensor>&>' 'paddle::optional<const std::vector<phi::DenseTensor>&>'
} }
out_trans_map = { dense_out_trans_map = {
'Tensor': 'phi::DenseTensor*', 'Tensor': 'phi::DenseTensor*',
'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&' 'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&'
} }
sr_input_trans_map = {
'const Tensor&':
'const phi::SelectedRows&',
'const paddle::optional<Tensor>&':
'const paddle::optional<phi::SelectedRows>&'
}
sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
input_names = self.inputs['names'] input_names = self.inputs['names']
input_infos = self.inputs['input_info'] input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&'] kernel_args_type_list = ['const platform::DeviceContext&']
...@@ -558,127 +569,72 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -558,127 +569,72 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
for i, input_name in enumerate(input_names): for i, input_name in enumerate(input_names):
# set input code # set input code
if input_name in kernel_param: if input_name in kernel_param:
trans_flag = "{}" # input is dense tensor
if input_name in self.data_transform['skip_transform']: if kernel_tensor_type is None or kernel_tensor_type[0][
trans_flag = "{true}" kernel_param.index(input_name)] == 'dense':
elif input_name in self.data_transform['support_trans_dtype']: trans_flag = "{}"
trans_flag = "{false, true}" if input_name in self.data_transform['skip_transform']:
if input_name in self.optional_vars: trans_flag = "{true}"
input_tensor_code = input_tensor_code + f""" elif input_name in self.data_transform[
'support_trans_dtype']:
trans_flag = "{false, true}"
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
else: else:
if self.inputs['input_info'][input_name] == "const Tensor&": if self.inputs['input_info'][
input_tensor_code = input_tensor_code + f""" input_name] == "const Tensor&":
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
elif self.inputs['input_info'][ elif self.inputs['input_info'][
input_name] == "const std::vector<Tensor>&": input_name] == "const std::vector<Tensor>&":
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag}); {code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});
{code_indent} std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size()); {code_indent} std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{ {code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); {code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent} }}""" {code_indent} }}"""
else: else:
# do nothing # do nothing
pass pass
else: else: # input is selected_rows
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} {input_trans_map[input_infos[input_name]]} {PREFIX_TENSOR_NAME}{input_name}(paddle::none); {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_ptr = TensorToDenseTensor({input_name}); else:
{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_ptr) {{ if input_name in self.infer_meta['param']:
{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::make_optional<const phi::DenseTensor&>(*{PREFIX_TENSOR_NAME}{input_name}_ptr); if input_name in self.optional_vars:
{code_indent} }}""" input_tensor_code = input_tensor_code + f"""
{code_indent} paddle::optional<phi::TensorBase> {PREFIX_TENSOR_NAME}{input_name} = {input_name} ? paddle::optional<phi::TensorBase>(*{input_name}->impl()) : paddle::none;"""
else: else:
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""
kernel_args = "*dev_ctx, " kernel_args = ["*dev_ctx"]
for param in kernel_param: for param in kernel_param:
if param in input_names: if param in input_names:
if param in self.optional_vars: if param in self.optional_vars:
kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", " kernel_args.append(PREFIX_TENSOR_NAME + param)
else: else:
if self.inputs['input_info'][param] == "const Tensor&": if self.inputs['input_info'][param] == "const Tensor&":
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", " kernel_args.append("*" + PREFIX_TENSOR_NAME + param)
elif self.inputs['input_info'][ elif self.inputs['input_info'][
param] == "const std::vector<Tensor>&": param] == "const std::vector<Tensor>&":
kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", " kernel_args.append(PREFIX_TENSOR_NAME + param)
else: else:
# do nothing # do nothing
pass pass
kernel_in_type = input_trans_map[input_infos[param]] # input is dense tensor
kernel_args_type_list.append(kernel_in_type) if kernel_tensor_type is None or kernel_tensor_type[0][
elif param in attr_names: kernel_param.index(param)] == 'dense':
# set attr for kernel_context
if 'IntArray' in self.attrs['attr_info'][param][0]:
kernel_args_type_list.append('const phi::IntArray&')
param = 'phi::IntArray(' + param + ')'
elif 'Scalar' in self.attrs['attr_info'][param][0]:
kernel_args_type_list.append('const phi::Scalar&')
param = 'phi::Scalar(' + param + ')'
else:
kernel_args_type_list.append( kernel_args_type_list.append(
self.attrs['attr_info'][param][0]) dense_input_trans_map[input_infos[param]])
kernel_args = kernel_args + param + ", " else: # input is selected_rows
elif isinstance(param, bool): kernel_args_type_list.append(
kernel_args = kernel_args + str(param).lower() + ", " sr_input_trans_map[input_infos[param]])
else:
kernel_args = kernel_args + str(param) + ", "
for out_type in self.outputs['types']:
kernel_args_type_list.append(out_trans_map[out_type])
kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
return input_tensor_code, kernel_args[:-2], kernel_signature
def get_selected_rows_kernel_args(self, code_indent):
input_trans_map = {
'const Tensor&':
'const phi::SelectedRows&',
'const paddle::optional<Tensor>&':
'const paddle::optional<phi::SelectedRows>&'
}
out_trans_map = {'Tensor': 'phi::SelectedRows*'}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {input_trans_map[input_infos[input_name]]} {PREFIX_TENSOR_NAME}{input_name}(paddle::none);
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_ptr = TensorToSelectedRows({input_name});
{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_ptr) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::make_optional<const phi::SelectedRows&>(*{PREFIX_TENSOR_NAME}{input_name}_ptr);
{code_indent} }}"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
if param in input_names:
if param in self.optional_vars:
kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
else:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_in_type = input_trans_map[input_infos[param]]
kernel_args_type_list.append(kernel_in_type)
elif param in attr_names: elif param in attr_names:
# set attr for kernel_context # set attr for kernel_context
if 'IntArray' in self.attrs['attr_info'][param][0]: if 'IntArray' in self.attrs['attr_info'][param][0]:
...@@ -690,18 +646,22 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -690,18 +646,22 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
else: else:
kernel_args_type_list.append( kernel_args_type_list.append(
self.attrs['attr_info'][param][0]) self.attrs['attr_info'][param][0])
kernel_args = kernel_args + param + ", " kernel_args.append(param)
elif isinstance(param, bool): elif isinstance(param, bool):
kernel_args = kernel_args + str(param).lower() + ", " kernel_args.append(str(param).lower())
else: else:
kernel_args = kernel_args + str(param) + ", " kernel_args.append(str(param))
for out_type in self.outputs['types']: for i, out_type in enumerate(self.outputs['types']):
kernel_args_type_list.append(out_trans_map[out_type]) # output is dense tensor
if kernel_tensor_type is None or kernel_tensor_type[1][i] == 'dense':
kernel_args_type_list.append(dense_out_trans_map[out_type])
else: # output is selected_rows
kernel_args_type_list.append(sr_out_trans_map[out_type])
kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")" kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
return input_tensor_code, kernel_args[:-2], kernel_signature return input_tensor_code, ", ".join(kernel_args), kernel_signature
# Override by child class # Override by child class
def gene_return_code(self): def gene_return_code(self):
...@@ -709,25 +669,27 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -709,25 +669,27 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
# Override by child class # Override by child class
def gene_output(self, def gene_output(self,
output_type_list, out_dtype_list,
set_out_func, out_tensor_type_list=None,
code_indent, code_indent='',
inplace_flag=False): inplace_flag=False):
return None, None, None return None, None, None
def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False): def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
kernel_dispatch = self.kernel['dispatch'][kernel_name]
input_tensors, kernel_args, kernel_signature = self.get_kernel_args( input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
code_indent) kernel_dispatch, code_indent)
out_tensor_type_list = kernel_dispatch[1] if kernel_dispatch else None
outputs_args, kernel_output_names, output_create = self.gene_output( outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag) self.outputs['types'], out_tensor_type_list, code_indent,
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') inplace_flag)
cudnn_args = '' if self.kernel[ cudnn_args = '' if self.kernel[
'use_gpudnn'] == 'false' else ', ' + self.kernel['use_gpudnn'] 'use_gpudnn'] == 'false' else ', ' + self.kernel['use_gpudnn']
return f""" return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; {code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args}); {code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent} VLOG(6) << "{self.api} API kernel: " << kernel; {code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); {code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors} {input_tensors}
...@@ -737,38 +699,42 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -737,38 +699,42 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} using kernel_signature = {kernel_signature}; {code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} {{ {code_indent} {{
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1); {code_indent} paddle::platform::RecordEvent kernel_record_event(\"{kernel_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); {code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} }} {code_indent} }}
{code_indent} {self.gene_return_code()}""" {code_indent} {self.gene_return_code()}"""
def gen_selected_rows_kernel_code(self, code_indent, inplace_flag=False): def get_condition_code(self, kernel_name):
input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args( assert self.kernel['dispatch'][kernel_name], \
code_indent) f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'scale' in api.yaml."
outputs_args, kernel_output_names, output_create = self.gene_output( input_types = self.kernel['dispatch'][kernel_name][0]
self.outputs['types'], 'SetSelectedRowsKernelOutput', code_indent, condition_list = []
inplace_flag) for i, in_type in enumerate(input_types):
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') if in_type == "dense":
return f""" if self.inputs['names'][i] in self.optional_vars:
{code_indent} auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( condition_list.append(
{code_indent} "{self.kernel['func'][1]}", {{kernel_backend, kernel_layout, kernel_data_type}}); f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_dense_tensor())"
{code_indent} VLOG(6) << "{self.api} API SelectedRows kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; )
{code_indent} VLOG(6) << "{self.api} API SelectedRows kernel: " << kernel; else:
condition_list.append(
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); f"{self.inputs['names'][i]}.is_dense_tensor()")
{input_tensors} else:
{output_create} if self.inputs['names'][i] in self.optional_vars:
{self.gene_infer_meta(kernel_output_names, code_indent)} condition_list.append(
f"(!{self.inputs['names'][i]} || {self.inputs['names'][i]}->is_selected_rows())"
{code_indent} using kernel_signature = {kernel_signature}; )
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); else:
{code_indent} {{ condition_list.append(
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1); f"{self.inputs['names'][i]}.is_selected_rows()")
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); return " && ".join(condition_list)
{code_indent} }}
{code_indent} {self.gene_return_code()}""" def gene_dispatch_code(self, kernel_name, inplace_flag=False):
return f"""
if ({self.get_condition_code(kernel_name)}) {{
{self.gen_kernel_code(kernel_name, ' ', inplace_flag)}
}}
"""
def gene_base_api_code(self, inplace_flag=False): def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() api_func_name = self.get_api_func_name()
...@@ -779,21 +745,20 @@ PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define ...@@ -779,21 +745,20 @@ PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define
{self.gene_kernel_select()} {self.gene_kernel_select()}
""" """
if self.support_selected_rows_kernel: if len(self.kernel['func']) > 1:
code_indent = ' ' kernel_dispatch_code = ''
for kernel_name in self.kernel['func']:
kernel_dispatch_code += self.gene_dispatch_code(
kernel_name, inplace_flag)
return api_code + f""" return api_code + f"""
if(kernel_type == KernelType::DENSE_TENSOR_KENREL){{ {kernel_dispatch_code}
{self.gen_dense_tensor_kernel_code(code_indent, inplace_flag)} PADDLE_THROW(phi::errors::Unimplemented(
}} else {{ "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
{self.gen_selected_rows_kernel_code(code_indent, inplace_flag)}
}}
}} }}
""" """
else: else:
code_indent = '' return api_code + self.gen_kernel_code(self.kernel['func'][0], '',
return api_code + self.gen_dense_tensor_kernel_code( inplace_flag) + """
code_indent, inplace_flag) + """
} }
""" """
......
...@@ -24,6 +24,11 @@ inplace_out_type_map = { ...@@ -24,6 +24,11 @@ inplace_out_type_map = {
"std::vector<Tensor>": "std::vector<Tensor>&" "std::vector<Tensor>": "std::vector<Tensor>&"
} }
inplace_optional_out_type_map = {
"Tensor": "paddle::optional<Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&"
}
class ForwardAPI(BaseAPI): class ForwardAPI(BaseAPI):
...@@ -80,7 +85,11 @@ class ForwardAPI(BaseAPI): ...@@ -80,7 +85,11 @@ class ForwardAPI(BaseAPI):
for i, out_type in enumerate(self.outputs['types']): for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0] out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map: if inplace_flag and out_name in self.inplace_map:
out_type_list.append(inplace_out_type_map[out_type]) if self.inplace_map[out_name] in self.optional_vars:
out_type_list.append(
inplace_optional_out_type_map[out_type])
else:
out_type_list.append(inplace_out_type_map[out_type])
else: else:
out_type_list.append(out_type) out_type_list.append(out_type)
...@@ -94,7 +103,11 @@ class ForwardAPI(BaseAPI): ...@@ -94,7 +103,11 @@ class ForwardAPI(BaseAPI):
for i, out_type in enumerate(self.outputs['types']): for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0] out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map: if inplace_flag and out_name in self.inplace_map:
out_type_list.append(inplace_out_type_map[out_type]) if self.inplace_map[out_name] in self.optional_vars:
out_type_list.append(
inplace_optional_out_type_map[out_type])
else:
out_type_list.append(inplace_out_type_map[out_type])
elif self.is_dygraph_api or out_name not in self.intermediate_outs: elif self.is_dygraph_api or out_name not in self.intermediate_outs:
out_type_list.append(out_type) out_type_list.append(out_type)
...@@ -120,16 +133,16 @@ class ForwardAPI(BaseAPI): ...@@ -120,16 +133,16 @@ class ForwardAPI(BaseAPI):
return 'return {' + ", ".join(selected_code) + '};' return 'return {' + ", ".join(selected_code) + '};'
def gene_output(self, def gene_output(self,
output_type_list, out_dtype_list,
set_out_func, out_tensor_type_list=None,
code_indent, code_indent='',
inplace_flag=False): inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag) return_type = self.get_return_type_with_intermediate(inplace_flag)
if len(output_type_list) == 1: if len(out_dtype_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[ inplace_assign = " = " + self.inplace_map[
...@@ -137,7 +150,8 @@ class ForwardAPI(BaseAPI): ...@@ -137,7 +150,8 @@ class ForwardAPI(BaseAPI):
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{code_indent} {return_type} api_output{inplace_assign};""" {code_indent} {return_type} api_output{inplace_assign};"""
set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
0] == 'dense' else 'SetSelectedRowsKernelOutput'
if return_type == 'std::vector<Tensor>': if return_type == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'][0] is not None, \ assert self.outputs['out_size_expr'][0] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
...@@ -155,7 +169,7 @@ class ForwardAPI(BaseAPI): ...@@ -155,7 +169,7 @@ class ForwardAPI(BaseAPI):
{code_indent} kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]}); {code_indent} kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";""" {code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
elif len(output_type_list) > 1: elif len(out_dtype_list) > 1:
output_create = f""" output_create = f"""
{code_indent} {return_type} api_output;""" {code_indent} {return_type} api_output;"""
...@@ -171,19 +185,27 @@ class ForwardAPI(BaseAPI): ...@@ -171,19 +185,27 @@ class ForwardAPI(BaseAPI):
output_create += 'Tensor(), ' output_create += 'Tensor(), '
output_create = output_create[:-2] + '};' output_create = output_create[:-2] + '};'
for i in range(len(output_type_list)): for i in range(len(out_dtype_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
i] == 'dense' else 'SetSelectedRowsKernelOutput'
get_out_code = f"&std::get<{i}>(api_output)"
if self.outputs['names'][
i] in self.inplace_map and self.inplace_map[
self.outputs['names'][i]] in self.optional_vars:
get_out_code = f"std::get<{i}>(api_output).get_ptr()"
if output_type_list[i] == 'std::vector<Tensor>': if out_dtype_list[i] == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'][i] is not None, \ assert self.outputs['out_size_expr'][i] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, &std::get<{i}>(api_output));""" {code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, {get_out_code});"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""
if not inplace_flag and self.view_map is not None and self.outputs[ if not inplace_flag and self.view_map is not None and self.outputs[
'names'][i] in self.view_map: 'names'][i] in self.view_map:
......
...@@ -114,22 +114,24 @@ class BackwardAPI(BaseAPI): ...@@ -114,22 +114,24 @@ class BackwardAPI(BaseAPI):
return 'void' return 'void'
def gene_output(self, def gene_output(self,
output_type_list, out_dtype_list,
set_out_func, out_tensor_type_list=None,
code_indent, code_indent='',
inplace_flag=False): inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
if len(output_type_list) == 1: if len(out_dtype_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][ inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = "" output_create = ""
if output_type_list[0] == 'std::vector<Tensor>': set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
0] == 'dense' else 'SetSelectedRowsKernelOutput'
if out_dtype_list[0] == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'] is not None, \ assert self.outputs['out_size_expr'] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f""" output_create = output_create + f"""
...@@ -139,11 +141,13 @@ class BackwardAPI(BaseAPI): ...@@ -139,11 +141,13 @@ class BackwardAPI(BaseAPI):
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, {self.outputs['names'][0]});""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, {self.outputs['names'][0]});"""
elif len(output_type_list) > 1: elif len(out_dtype_list) > 1:
output_create = "" output_create = ""
for i, out_type_item in enumerate(output_type_list): for i, out_type_item in enumerate(out_dtype_list):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
i] == 'dense' else 'SetSelectedRowsKernelOutput'
if out_type_item == 'Tensor': if out_type_item == 'Tensor':
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
......
...@@ -48,11 +48,17 @@ ...@@ -48,11 +48,17 @@
kernel : kernel :
func : adadelta func : adadelta
- api : adam - api : adam_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow) args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
infer_meta :
func : AdamInferMeta
kernel :
func : adam {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense},
adam_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}
data_type : param
optional : master_param, skip_update optional : master_param, skip_update
invoke : adam_impl(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, skip_update, beta1, beta2, epsilon, lazy_mode, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow) inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs)
- api : adamax - api : adamax
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment, Tensor inf_norm, Tensor beta1_pow, float beta1, float beta2, float epsilon) args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment, Tensor inf_norm, Tensor beta1_pow, float beta1, float beta2, float epsilon)
...@@ -1015,7 +1021,8 @@ ...@@ -1015,7 +1021,8 @@
infer_meta : infer_meta :
func : IsfiniteInferMeta func : IsfiniteInferMeta
kernel : kernel :
func : isfinite, infinite_sr func : isfinite {dense -> dense},
infinite_sr {selected_rows -> selected_rows}
# isinf # isinf
- api : isinf - api : isinf
...@@ -1024,7 +1031,8 @@ ...@@ -1024,7 +1031,8 @@
infer_meta : infer_meta :
func : IsfiniteInferMeta func : IsfiniteInferMeta
kernel : kernel :
func : isinf, isinf_sr func : isinf {dense -> dense},
isinf_sr {selected_rows -> selected_rows}
# isnan # isnan
- api : isnan - api : isnan
...@@ -1033,7 +1041,8 @@ ...@@ -1033,7 +1041,8 @@
infer_meta : infer_meta :
func : IsfiniteInferMeta func : IsfiniteInferMeta
kernel : kernel :
func : isnan, isnan_sr func : isnan {dense -> dense},
isnan_sr {selected_rows -> selected_rows}
- api : kldiv_loss - api : kldiv_loss
args : (Tensor x, Tensor label, str reduction) args : (Tensor x, Tensor label, str reduction)
...@@ -1774,7 +1783,8 @@ ...@@ -1774,7 +1783,8 @@
func : UnchangedInferMeta func : UnchangedInferMeta
param : [x] param : [x]
kernel : kernel :
func : scale, scale_sr func : scale {dense -> dense},
scale_sr {selected_rows -> selected_rows}
inplace : (x -> out) inplace : (x -> out)
backward : scale_grad backward : scale_grad
...@@ -1829,11 +1839,20 @@ ...@@ -1829,11 +1839,20 @@
func : selu func : selu
backward : selu_grad backward : selu_grad
- api : sgd - api : sgd_
args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision) args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision)
output : Tensor(param_out), Tensor(master_param_out) output : Tensor(param_out), Tensor(master_param_out)
invoke : sgd_impl(param, learning_rate, grad, master_param, multi_precision) infer_meta :
func : SgdInferMeta
kernel :
func : sgd {dense, dense, dense, dense -> dense, dense},
sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense},
sgd_sparse_param_sparse_grad {selected_rows, dense, selected_rows, selected_rows -> selected_rows, selected_rows}
data_type : param
data_transform :
support_trans_dtype : learning_rate
optional : master_param optional : master_param
inplace : (param -> param_out), (master_param -> master_param_out)
- api : shape - api : shape
args : (Tensor input) args : (Tensor input)
...@@ -1841,7 +1860,8 @@ ...@@ -1841,7 +1860,8 @@
infer_meta : infer_meta :
func : ShapeInferMeta func : ShapeInferMeta
kernel : kernel :
func : shape, shape_sr func : shape {dense -> dense},
shape_sr {selected_rows -> selected_rows}
data_transform: data_transform:
skip_transform : input skip_transform : input
......
...@@ -31,18 +31,10 @@ class SparseAPI(ForwardAPI): ...@@ -31,18 +31,10 @@ class SparseAPI(ForwardAPI):
{super(SparseAPI, self).gene_api_declaration()} {super(SparseAPI, self).gene_api_declaration()}
""" """
def get_kernel_tensor_out_type(self, output_name):
sparse_type = 'TensorType::DENSE_TENSOR'
if output_name.endswith('@SparseCooTensor'):
sparse_type = 'TensorType::SPARSE_COO'
elif output_name.endswith('@SparseCsrTensor'):
sparse_type = 'TensorType::SPARSE_CSR'
return sparse_type
def gene_output(self, def gene_output(self,
output_type_list, out_dtype_list,
set_out_func, out_tensor_type_list=None,
code_indent, code_indent='',
inplace_flag=False): inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
...@@ -54,7 +46,7 @@ class SparseAPI(ForwardAPI): ...@@ -54,7 +46,7 @@ class SparseAPI(ForwardAPI):
'sparse_csr': 'TensorType::SPARSE_CSR' 'sparse_csr': 'TensorType::SPARSE_CSR'
} }
if len(output_type_list) == 1: if len(out_dtype_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][ inplace_assign = " = " + self.inplace_map[self.outputs['names'][
...@@ -62,9 +54,9 @@ class SparseAPI(ForwardAPI): ...@@ -62,9 +54,9 @@ class SparseAPI(ForwardAPI):
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{return_type} api_output{inplace_assign}; {return_type} api_output{inplace_assign};
auto* kernel_out = {set_out_func}(&api_output, {output_type_map[output_type_list[0]]});""" auto* kernel_out = SetSparseKernelOutput(&api_output, {output_type_map[out_dtype_list[0]]});"""
elif len(output_type_list) > 1: elif len(out_dtype_list) > 1:
output_create = f""" output_create = f"""
{return_type} api_output;""" {return_type} api_output;"""
...@@ -80,11 +72,11 @@ class SparseAPI(ForwardAPI): ...@@ -80,11 +72,11 @@ class SparseAPI(ForwardAPI):
output_create += 'Tensor(), ' output_create += 'Tensor(), '
output_create = output_create[:-2] + '};' output_create = output_create[:-2] + '};'
for i in range(len(output_type_list)): for i in range(len(out_dtype_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
output_create = output_create + f""" output_create = output_create + f"""
auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {output_type_map[output_type_list[i]]});""" auto* kernel_out_{i} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype_list[i]]});"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
...@@ -148,8 +140,7 @@ class SparseAPI(ForwardAPI): ...@@ -148,8 +140,7 @@ class SparseAPI(ForwardAPI):
def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False): def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
_, kernel_output_names, output_create = self.gene_output( _, kernel_output_names, output_create = self.gene_output(
self.kernel['dispatch'][kernel_name][1], 'SetSparseKernelOutput', self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag)
'', inplace_flag)
kernel_context_code = self.gen_sparse_kernel_context( kernel_context_code = self.gen_sparse_kernel_context(
kernel_output_names) kernel_output_names)
...@@ -189,7 +180,6 @@ class SparseAPI(ForwardAPI): ...@@ -189,7 +180,6 @@ class SparseAPI(ForwardAPI):
return " && ".join(condition_list) return " && ".join(condition_list)
def gene_dispatch_code(self, kernel_name, inplace_flag=False): def gene_dispatch_code(self, kernel_name, inplace_flag=False):
dispatch_code = ""
return f""" return f"""
if ({self.get_condition_code(kernel_name)}) {{ if ({self.get_condition_code(kernel_name)}) {{
{self.gen_sparse_kernel_code(kernel_name, inplace_flag)} {self.gen_sparse_kernel_code(kernel_name, inplace_flag)}
......
...@@ -48,9 +48,9 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -48,9 +48,9 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
return BackwardAPI.get_define_args(self) return BackwardAPI.get_define_args(self)
def gene_output(self, def gene_output(self,
output_type_list, out_dtype_list,
set_out_func, out_tensor_type_list=None,
code_indent, code_indent='',
inplace_flag=False): inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
...@@ -61,19 +61,19 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -61,19 +61,19 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
'sparse_csr': 'TensorType::SPARSE_CSR' 'sparse_csr': 'TensorType::SPARSE_CSR'
} }
if len(output_type_list) == 1: if len(out_dtype_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][ inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
auto kernel_out = {set_out_func}({self.outputs['names'][0]}, {output_type_map[output_type_list[0]]});""" auto kernel_out = SetSparseKernelOutput({self.outputs['names'][0]}, {output_type_map[out_dtype_list[0]]});"""
elif len(output_type_list) > 1: elif len(out_dtype_list) > 1:
output_create = "" output_create = ""
for i, out_type_item in enumerate(output_type_list): for i, out_type_item in enumerate(out_dtype_list):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
...@@ -82,7 +82,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -82,7 +82,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
*{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};""" *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i]}, {output_type_map[output_type_list[i]]});""" auto kernel_out_{i} = SetSparseKernelOutput({self.outputs['names'][i]}, {output_type_map[out_dtype_list[i]]});"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
......
...@@ -51,16 +51,16 @@ class StringsAPI(ForwardAPI): ...@@ -51,16 +51,16 @@ class StringsAPI(ForwardAPI):
return tensor_type_dict[kernel_tensor_out_type] return tensor_type_dict[kernel_tensor_out_type]
def gene_output(self, def gene_output(self,
output_type_list, out_dtype_list,
set_out_func, out_tensor_type_list=None,
code_indent, code_indent='',
inplace_flag=False): inplace_flag=False):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
return_type = self.get_return_type(inplace_flag) return_type = self.get_return_type(inplace_flag)
if len(output_type_list) == 1: if len(out_dtype_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
kernel_tensor_out_type = self.get_kernel_tensor_out_type( kernel_tensor_out_type = self.get_kernel_tensor_out_type(
...@@ -71,13 +71,13 @@ class StringsAPI(ForwardAPI): ...@@ -71,13 +71,13 @@ class StringsAPI(ForwardAPI):
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{return_type} api_output{inplace_assign}; {return_type} api_output{inplace_assign};
{tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>({set_out_func}(kernel_backend, &api_output, {kernel_tensor_out_type}));""" {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(kernel_backend, &api_output, {kernel_tensor_out_type}));"""
elif len(output_type_list) > 1: elif len(out_dtype_list) > 1:
output_create = f""" output_create = f"""
{return_type} api_output;""" {return_type} api_output;"""
for i in range(len(output_type_list)): for i in range(len(out_dtype_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
kernel_tensor_out_type = self.get_kernel_tensor_out_type( kernel_tensor_out_type = self.get_kernel_tensor_out_type(
...@@ -89,7 +89,7 @@ class StringsAPI(ForwardAPI): ...@@ -89,7 +89,7 @@ class StringsAPI(ForwardAPI):
std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};""" std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
{tensor_type}* kernel_out_{i} = dynamic_cast<{tensor_type}*>({set_out_func}(&std::get<{i}>(api_output), {kernel_tensor_out_type}));""" {tensor_type}* kernel_out_{i} = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(&std::get<{i}>(api_output), {kernel_tensor_out_type}));"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
...@@ -174,7 +174,7 @@ class StringsAPI(ForwardAPI): ...@@ -174,7 +174,7 @@ class StringsAPI(ForwardAPI):
input_tensors, kernel_args, kernel_signature = self.get_kernel_args( input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
code_indent) code_indent)
outputs_args, kernel_output_names, output_create = self.gene_output( outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetStringsKernelOutput', '', inplace_flag) self.outputs['types'], None, '', inplace_flag)
return f""" return f"""
// 1. Get kernel signature and kernel // 1. Get kernel signature and kernel
...@@ -252,11 +252,6 @@ class StringsAPI(ForwardAPI): ...@@ -252,11 +252,6 @@ class StringsAPI(ForwardAPI):
kernel_select_code = kernel_key_item_init + kernel_select_code kernel_select_code = kernel_key_item_init + kernel_select_code
if len(input_names) > 0: if len(input_names) > 0:
if self.support_selected_rows_kernel:
kernel_select_code = kernel_select_code + f"""
KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)});
"""
kernel_select_code = kernel_select_code + f""" kernel_select_code = kernel_select_code + f"""
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args}); auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册