Unverified commit aa45f931 authored by zyfncg, committed by GitHub

Support code auto-generation for optimizer api in yaml (#43915)

* support configuring selected_rows kernel in yaml

* support configuring optimizer api in yaml

* fix data transform bug
Parent 78023658
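For context, the new-style YAML entry that this change enables looks roughly like the sketch below. The field names and the {dense, selected_rows -> ...} dispatch syntax are taken from the sgd_ and adam_ entries added later in this diff; the snippet is an illustration, not the complete entry.

    # sketch of an auto-generated optimizer API entry: the kernel line maps the
    # tensor type of every input -> output, so one API can dispatch to a dense
    # kernel or a selected_rows kernel, and inplace outputs are declared in yaml
    # instead of calling a hand-written *_impl function.
    - api : sgd_
      args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision)
      output : Tensor(param_out), Tensor(master_param_out)
      infer_meta :
        func : SgdInferMeta
      kernel :
        func : sgd {dense, dense, dense, dense -> dense, dense},
               sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense}
      data_type : param
      optional : master_param
      inplace : (param -> param_out), (master_param -> master_param_out)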
...@@ -224,7 +224,8 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp}
           ${api_source_file}
   COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
-  DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base}
+  DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${api_gen_file}
+          ${api_gen_base}
   VERBATIM)

# generate backward api
...@@ -240,7 +241,8 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_source_file_tmp}
           ${bw_api_source_file}
   COMMENT "copy_if_different ${bw_api_header_file} ${bw_api_source_file}"
-  DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base}
+  DEPENDS ${bw_api_yaml_file} ${legacy_bw_api_yaml_file} ${bw_api_gen_file}
+          ${api_gen_base}
   VERBATIM)

# generate sparse api
......
...@@ -32,237 +32,6 @@ limitations under the License. */
namespace paddle {
namespace experimental {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& learning_rate,
const Tensor& moment1,
const Tensor& moment2,
const Tensor& beta1_pow,
const Tensor& beta2_pow,
const paddle::optional<Tensor>& master_param,
const paddle::optional<Tensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(param);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
std::string kernel_name = "adam";
if (!phi::DenseTensor::classof(grad.impl().get())) {
kernel_name = "adam_dense_param_sparse_grad";
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_param = PrepareData(param, kernel.InputAt(0), {});
auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {});
auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {});
auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {});
auto input_beta1_pow = PrepareData(beta1_pow, kernel.InputAt(5), {});
auto input_beta2_pow = PrepareData(beta2_pow, kernel.InputAt(6), {});
auto input_master_param = PrepareData(master_param, kernel.InputAt(7), {});
auto input_skip_update = PrepareData(skip_update, kernel.InputAt(8), {});
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = input_param.get();
auto kernel_out_1 = input_moment1.get();
auto kernel_out_2 = input_moment2.get();
auto kernel_out_3 = input_beta1_pow.get();
auto kernel_out_4 = input_beta2_pow.get();
phi::DenseTensor* kernel_out_5 = nullptr;
if (input_master_param) {
kernel_out_5 = input_master_param.get_ptr();
}
auto input_meta_ref_master_param = MakeMetaTensor(input_master_param);
auto input_meta_ref_skip_update = MakeMetaTensor(input_skip_update);
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MetaTensor meta_out_3(kernel_out_3);
phi::MetaTensor meta_out_4(kernel_out_4);
phi::MetaTensor meta_out_5(kernel_out_5);
if (phi::DenseTensor::classof(grad.impl().get())) {
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
phi::AdamInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_lr),
MakeMetaTensor(*input_moment1),
MakeMetaTensor(*input_moment2),
MakeMetaTensor(*input_beta1_pow),
MakeMetaTensor(*input_beta2_pow),
input_meta_ref_master_param,
input_meta_ref_skip_update,
beta1,
beta2,
epsilon,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
&meta_out_0,
&meta_out_1,
&meta_out_2,
&meta_out_3,
&meta_out_4,
&meta_out_5);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const paddle::optional<phi::DenseTensor>&,
const paddle::optional<phi::DenseTensor>&,
const Scalar&,
const Scalar&,
const Scalar&,
bool,
int64_t,
bool,
bool,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_lr,
*input_moment1,
*input_moment2,
*input_beta1_pow,
*input_beta2_pow,
input_master_param,
input_skip_update,
beta1,
beta2,
epsilon,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
kernel_out_0,
kernel_out_1,
kernel_out_2,
kernel_out_3,
kernel_out_4,
kernel_out_5);
} else {
auto input_grad = TensorToSelectedRows(grad);
phi::AdamInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_lr),
MakeMetaTensor(*input_moment1),
MakeMetaTensor(*input_moment2),
MakeMetaTensor(*input_beta1_pow),
MakeMetaTensor(*input_beta2_pow),
input_meta_ref_master_param,
input_meta_ref_skip_update,
beta1,
beta2,
epsilon,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
&meta_out_0,
&meta_out_1,
&meta_out_2,
&meta_out_3,
&meta_out_4,
&meta_out_5);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::SelectedRows&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const paddle::optional<phi::DenseTensor>&,
const paddle::optional<phi::DenseTensor>&,
const Scalar&,
const Scalar&,
const Scalar&,
bool,
int64_t,
bool,
bool,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_lr,
*input_moment1,
*input_moment2,
*input_beta1_pow,
*input_beta2_pow,
input_master_param,
input_skip_update,
beta1,
beta2,
epsilon,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
kernel_out_0,
kernel_out_1,
kernel_out_2,
kernel_out_3,
kernel_out_4,
kernel_out_5);
}
return api_output;
}
////////////////// Forward api impls //////////////////////

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
...@@ -1100,159 +869,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
  return api_output;
}
std::tuple<Tensor, Tensor> sgd_impl(
const Tensor& param,
const Tensor& learning_rate,
const Tensor& grad,
const paddle::optional<Tensor>& master_param,
bool multi_precision) {
DataType kernel_data_type = ParseDataType(param);
auto kernel_key_set = ParseKernelKeyByInputArgs(param, learning_rate, grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
VLOG(6) << "sgd API kernel key: [" << kernel_key.backend() << ", "
<< kernel_key.layout() << ", " << kernel_data_type << "]";
const auto& param_tensor = param.impl();
std::string kernel_name = "sgd";
if (phi::DenseTensor::classof(param_tensor.get())) {
if (!phi::DenseTensor::classof(grad.impl().get())) {
kernel_name = "sgd_dense_param_sparse_grad";
}
} else {
kernel_name = "sgd_sparse_param_sparse_grad";
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name,
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto in_learning_rate =
PrepareData(learning_rate, kernel.InputAt(1), {false, true, true, true});
std::tuple<Tensor, Tensor> out;
std::get<0>(out) = param;
if (master_param) {
std::get<1>(out) = *master_param;
}
phi::MetaTensor meta_out_0(std::get<0>(out).impl().get());
phi::MetaTensor meta_out_1(master_param ? std::get<1>(out).impl().get()
: nullptr);
if (phi::DenseTensor::classof(param_tensor.get())) {
auto in_param = PrepareData(param, kernel.InputAt(0), {});
auto in_master_param_opt = PrepareData(master_param, kernel.InputAt(3), {});
auto master_param_meta_opt = MakeMetaTensor(in_master_param_opt);
phi::DenseTensor* kernel_out_0 =
SetKernelOutput(kernel_key.backend(), &std::get<0>(out));
phi::DenseTensor* kernel_out_1 =
master_param
? static_cast<phi::DenseTensor*>(std::get<1>(out).impl().get())
: nullptr;
if (phi::DenseTensor::classof(grad.impl().get())) {
auto in_grad = PrepareData(grad, kernel.InputAt(2), {});
SgdInferMeta(MakeMetaTensor(*in_param),
MakeMetaTensor(*in_learning_rate),
MakeMetaTensor(*in_grad),
master_param_meta_opt,
multi_precision,
&meta_out_0,
&meta_out_1);
using kernel_signature =
void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const paddle::optional<phi::DenseTensor>&,
bool,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*in_param,
*in_learning_rate,
*in_grad,
in_master_param_opt,
multi_precision,
kernel_out_0,
kernel_out_1);
} else {
auto in_grad = TensorToSelectedRows(grad);
SgdInferMeta(MakeMetaTensor(*in_param),
MakeMetaTensor(*in_learning_rate),
MakeMetaTensor(*in_grad),
master_param_meta_opt,
multi_precision,
&meta_out_0,
&meta_out_1);
using kernel_signature =
void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::SelectedRows&,
const paddle::optional<phi::DenseTensor>&,
bool,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*in_param,
*in_learning_rate,
*in_grad,
in_master_param_opt,
multi_precision,
kernel_out_0,
kernel_out_1);
}
} else {
auto in_param = TensorToSelectedRows(param);
auto in_grad = TensorToSelectedRows(grad);
auto in_master_param_opt = TensorToSelectedRows(master_param);
auto master_param_meta = MakeMetaTensor(in_master_param_opt);
phi::SelectedRows* kernel_out_0 =
SetSelectedRowsKernelOutput(kernel_key.backend(), &std::get<0>(out));
phi::SelectedRows* kernel_out_1 =
master_param
? static_cast<phi::SelectedRows*>(std::get<1>(out).impl().get())
: nullptr;
SgdInferMeta(MakeMetaTensor(*in_param),
MakeMetaTensor(*in_learning_rate),
MakeMetaTensor(*in_grad),
master_param_meta,
multi_precision,
&meta_out_0,
&meta_out_1);
using kernel_signature =
void (*)(const platform::DeviceContext&,
const phi::SelectedRows&,
const phi::DenseTensor&,
const phi::SelectedRows&,
const paddle::optional<phi::SelectedRows>&,
bool,
phi::SelectedRows*,
phi::SelectedRows*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*in_param,
*in_learning_rate,
*in_grad,
in_master_param_opt,
multi_precision,
kernel_out_0,
kernel_out_1);
}
return out;
}
////////////////// Backward(grad) api impls //////////////////////

// TODO(chenweihang): the original sum grad op can support higher-level
......
...@@ -31,24 +31,6 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& learning_rate,
const Tensor& moment1,
const Tensor& moment2,
const Tensor& beta1_pow,
const Tensor& beta2_pow,
const paddle::optional<Tensor>& master_param,
const paddle::optional<Tensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
    const Tensor& param,
    const Tensor& grad,
...@@ -132,13 +114,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
    bool multi_precision,
    float rescale_grad);
std::tuple<Tensor, Tensor> sgd_impl(
const Tensor& param,
const Tensor& learning_rate,
const Tensor& grad,
const paddle::optional<Tensor>& master_param,
bool multi_precision);
////////////////// Backward(grad) api impls //////////////////////

void add_n_grad_impl(const std::vector<Tensor>& x,
......
...@@ -62,7 +62,7 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor) {
/* ----------------- for infer_meta --------------------- */

-phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) {
+phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
  return phi::MetaTensor(tensor);
}

...@@ -94,10 +94,6 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
  return meta_tensors;
}
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
return phi::MetaTensor(tensor);
}
phi::MetaTensor MakeMetaTensor(
    const paddle::optional<phi::SelectedRows>& tensor) {
  if (tensor) {
...@@ -106,10 +102,6 @@ phi::MetaTensor MakeMetaTensor(
  return phi::MetaTensor();
}
phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor) {
return phi::MetaTensor(tensor);
}
/* ------------------ for output ----------------------- */

phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
......
...@@ -47,7 +47,7 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor);
/* ----------------- for infer_meta --------------------- */

-phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor);
+phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor);

phi::MetaTensor MakeMetaTensor(
    const paddle::optional<phi::DenseTensor>& tensor);
...@@ -58,13 +58,9 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
std::vector<phi::MetaTensor> MakeMetaTensor(
    const std::vector<phi::DenseTensor*>& tensors);
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
phi::MetaTensor MakeMetaTensor(
    const paddle::optional<phi::SelectedRows>& tensor);
phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor);
/* ------------------ for output ----------------------- */

phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out);
......
...@@ -1370,8 +1370,8 @@ class SGDOptimizer(Optimizer):
        lr = self._create_param_lr(param_and_grad)
        if in_dygraph_mode():
-           _C_ops.final_state_sgd(param_and_grad[0], lr, param_and_grad[1],
-                                  master_weight, find_master)
+           _C_ops.final_state_sgd_(param_and_grad[0], lr, param_and_grad[1],
+                                   master_weight, find_master)
            return None
        if _in_legacy_dygraph():
            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
......
...@@ -342,7 +342,7 @@ class Adam(Optimizer):
            _beta2 = self._beta2 if not isinstance(
                self._beta2, Variable) else self._beta2.numpy().item(0)

-           _, _, _, _, _, _ = _C_ops.final_state_adam(
+           _, _, _, _, _, _ = _C_ops.final_state_adam_(
                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
                beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, _beta1,
                _beta2, self._epsilon, self._lazy_mode, 1000, find_master,
......
...@@ -143,8 +143,8 @@ class SGD(Optimizer):
        lr = self._create_param_lr(param_and_grad)
        if in_dygraph_mode():
-           _C_ops.final_state_sgd(param_and_grad[0], lr, param_and_grad[1],
-                                  master_weight, find_master)
+           _C_ops.final_state_sgd_(param_and_grad[0], lr, param_and_grad[1],
+                                   master_weight, find_master)
            return None
        if _in_legacy_dygraph():
            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
......
...@@ -24,6 +24,11 @@ inplace_out_type_map = {
    "std::vector<Tensor>": "std::vector<Tensor>&"
}
inplace_optional_out_type_map = {
"Tensor": "paddle::optional<Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&"
}
class ForwardAPI(BaseAPI):

...@@ -80,7 +85,11 @@ class ForwardAPI(BaseAPI):
        for i, out_type in enumerate(self.outputs['types']):
            out_name = self.outputs['names'][i].split('@')[0]
            if inplace_flag and out_name in self.inplace_map:
-               out_type_list.append(inplace_out_type_map[out_type])
+               if self.inplace_map[out_name] in self.optional_vars:
+                   out_type_list.append(
+                       inplace_optional_out_type_map[out_type])
+               else:
+                   out_type_list.append(inplace_out_type_map[out_type])
            else:
                out_type_list.append(out_type)

...@@ -94,7 +103,11 @@ class ForwardAPI(BaseAPI):
        for i, out_type in enumerate(self.outputs['types']):
            out_name = self.outputs['names'][i].split('@')[0]
            if inplace_flag and out_name in self.inplace_map:
-               out_type_list.append(inplace_out_type_map[out_type])
+               if self.inplace_map[out_name] in self.optional_vars:
+                   out_type_list.append(
+                       inplace_optional_out_type_map[out_type])
+               else:
+                   out_type_list.append(inplace_out_type_map[out_type])
            elif self.is_dygraph_api or out_name not in self.intermediate_outs:
                out_type_list.append(out_type)
...@@ -120,16 +133,16 @@ class ForwardAPI(BaseAPI):
        return 'return {' + ", ".join(selected_code) + '};'

    def gene_output(self,
-                   output_type_list,
-                   set_out_func,
-                   code_indent,
+                   out_dtype_list,
+                   out_tensor_type_list=None,
+                   code_indent='',
                    inplace_flag=False):
        kernel_output = ""
        output_names = []
        output_create = ""
        return_type = self.get_return_type_with_intermediate(inplace_flag)

-       if len(output_type_list) == 1:
+       if len(out_dtype_list) == 1:
            kernel_output = 'kernel_out'
            output_names.append('kernel_out')
            inplace_assign = " = " + self.inplace_map[
...@@ -137,7 +150,8 @@ class ForwardAPI(BaseAPI):
                'names'][0] in self.inplace_map else ""
            output_create = f"""
{code_indent} {return_type} api_output{inplace_assign};"""
+           set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
+               0] == 'dense' else 'SetSelectedRowsKernelOutput'

            if return_type == 'std::vector<Tensor>':
                assert self.outputs['out_size_expr'][0] is not None, \
                    f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
...@@ -155,7 +169,7 @@ class ForwardAPI(BaseAPI):
{code_indent} kernel_out->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][0]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""

-       elif len(output_type_list) > 1:
+       elif len(out_dtype_list) > 1:
            output_create = f"""
{code_indent} {return_type} api_output;"""
...@@ -171,19 +185,27 @@ class ForwardAPI(BaseAPI):
                    output_create += 'Tensor(), '
                output_create = output_create[:-2] + '};'

-           for i in range(len(output_type_list)):
+           for i in range(len(out_dtype_list)):
                kernel_output = kernel_output + f'kernel_out_{i}, '
                output_names.append(f'kernel_out_{i}')
+               set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
+                   i] == 'dense' else 'SetSelectedRowsKernelOutput'
+
+               get_out_code = f"&std::get<{i}>(api_output)"
+               if self.outputs['names'][
+                       i] in self.inplace_map and self.inplace_map[
+                           self.outputs['names'][i]] in self.optional_vars:
+                   get_out_code = f"std::get<{i}>(api_output).get_ptr()"

-               if output_type_list[i] == 'std::vector<Tensor>':
+               if out_dtype_list[i] == 'std::vector<Tensor>':
                    assert self.outputs['out_size_expr'][i] is not None, \
                        f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                    output_create = output_create + f"""
-{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, &std::get<{i}>(api_output));"""
+{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, {get_out_code});"""
                else:
                    output_create = output_create + f"""
-{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""
+{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""

                if not inplace_flag and self.view_map is not None and self.outputs[
                        'names'][i] in self.view_map:
......
...@@ -114,22 +114,24 @@ class BackwardAPI(BaseAPI):
        return 'void'

    def gene_output(self,
-                   output_type_list,
-                   set_out_func,
-                   code_indent,
+                   out_dtype_list,
+                   out_tensor_type_list=None,
+                   code_indent='',
                    inplace_flag=False):
        kernel_output = ""
        output_names = []
        output_create = ""

-       if len(output_type_list) == 1:
+       if len(out_dtype_list) == 1:
            kernel_output = 'kernel_out'
            output_names.append('kernel_out')
            inplace_assign = " = " + self.inplace_map[self.outputs['names'][
                0]] if inplace_flag and self.inplace_map is not None and self.outputs[
                    'names'][0] in self.inplace_map else ""
            output_create = ""
-           if output_type_list[0] == 'std::vector<Tensor>':
+           set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
+               0] == 'dense' else 'SetSelectedRowsKernelOutput'
+
+           if out_dtype_list[0] == 'std::vector<Tensor>':
                assert self.outputs['out_size_expr'] is not None, \
                    f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                output_create = output_create + f"""
...@@ -139,11 +141,13 @@ class BackwardAPI(BaseAPI):
                output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, {self.outputs['names'][0]});"""

-       elif len(output_type_list) > 1:
+       elif len(out_dtype_list) > 1:
            output_create = ""
-           for i, out_type_item in enumerate(output_type_list):
+           for i, out_type_item in enumerate(out_dtype_list):
                kernel_output = kernel_output + f'kernel_out_{i}, '
                output_names.append(f'kernel_out_{i}')
+               set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
+                   i] == 'dense' else 'SetSelectedRowsKernelOutput'

                if out_type_item == 'Tensor':
                    if inplace_flag and self.inplace_map is not None and self.outputs[
                            'names'][i] in self.inplace_map:
......
...@@ -48,11 +48,17 @@
  kernel :
    func : adadelta

-- api : adam
+- api : adam_
  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
  output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
+ infer_meta :
+   func : AdamInferMeta
+ kernel :
+   func : adam {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense},
+          adam_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}
+ data_type : param
  optional : master_param, skip_update
- invoke : adam_impl(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, skip_update, beta1, beta2, epsilon, lazy_mode, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow)
+ inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs)

- api : adamax
  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment, Tensor inf_norm, Tensor beta1_pow, float beta1, float beta2, float epsilon)
...@@ -1015,7 +1021,8 @@
  infer_meta :
    func : IsfiniteInferMeta
  kernel :
-   func : isfinite, infinite_sr
+   func : isfinite {dense -> dense},
+          infinite_sr {selected_rows -> selected_rows}

# isinf
- api : isinf
...@@ -1024,7 +1031,8 @@
  infer_meta :
    func : IsfiniteInferMeta
  kernel :
-   func : isinf, isinf_sr
+   func : isinf {dense -> dense},
+          isinf_sr {selected_rows -> selected_rows}

# isnan
- api : isnan
...@@ -1033,7 +1041,8 @@
  infer_meta :
    func : IsfiniteInferMeta
  kernel :
-   func : isnan, isnan_sr
+   func : isnan {dense -> dense},
+          isnan_sr {selected_rows -> selected_rows}

- api : kldiv_loss
  args : (Tensor x, Tensor label, str reduction)
...@@ -1774,7 +1783,8 @@
    func : UnchangedInferMeta
    param : [x]
  kernel :
-   func : scale, scale_sr
+   func : scale {dense -> dense},
+          scale_sr {selected_rows -> selected_rows}
  inplace : (x -> out)
  backward : scale_grad
...@@ -1829,11 +1839,20 @@
    func : selu
  backward : selu_grad

-- api : sgd
+- api : sgd_
  args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision)
  output : Tensor(param_out), Tensor(master_param_out)
- invoke : sgd_impl(param, learning_rate, grad, master_param, multi_precision)
+ infer_meta :
+   func : SgdInferMeta
+ kernel :
+   func : sgd {dense, dense, dense, dense -> dense, dense},
+          sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense},
+          sgd_sparse_param_sparse_grad {selected_rows, dense, selected_rows, selected_rows -> selected_rows, selected_rows}
+ data_type : param
+ data_transform :
+   support_trans_dtype : learning_rate
  optional : master_param
+ inplace : (param -> param_out), (master_param -> master_param_out)

- api : shape
  args : (Tensor input)
...@@ -1841,7 +1860,8 @@
  infer_meta :
    func : ShapeInferMeta
  kernel :
-   func : shape, shape_sr
+   func : shape {dense -> dense},
+          shape_sr {selected_rows -> selected_rows}
  data_transform:
    skip_transform : input
......
...@@ -31,18 +31,10 @@ class SparseAPI(ForwardAPI):
{super(SparseAPI, self).gene_api_declaration()}
"""
def get_kernel_tensor_out_type(self, output_name):
sparse_type = 'TensorType::DENSE_TENSOR'
if output_name.endswith('@SparseCooTensor'):
sparse_type = 'TensorType::SPARSE_COO'
elif output_name.endswith('@SparseCsrTensor'):
sparse_type = 'TensorType::SPARSE_CSR'
return sparse_type
    def gene_output(self,
-                   output_type_list,
-                   set_out_func,
-                   code_indent,
+                   out_dtype_list,
+                   out_tensor_type_list=None,
+                   code_indent='',
                    inplace_flag=False):
        kernel_output = ""
        output_names = []
...@@ -54,7 +46,7 @@ class SparseAPI(ForwardAPI):
            'sparse_csr': 'TensorType::SPARSE_CSR'
        }

-       if len(output_type_list) == 1:
+       if len(out_dtype_list) == 1:
            kernel_output = 'kernel_out'
            output_names.append('kernel_out')
            inplace_assign = " = " + self.inplace_map[self.outputs['names'][
...@@ -62,9 +54,9 @@ class SparseAPI(ForwardAPI):
                'names'][0] in self.inplace_map else ""
            output_create = f"""
    {return_type} api_output{inplace_assign};
-   auto* kernel_out = {set_out_func}(&api_output, {output_type_map[output_type_list[0]]});"""
+   auto* kernel_out = SetSparseKernelOutput(&api_output, {output_type_map[out_dtype_list[0]]});"""

-       elif len(output_type_list) > 1:
+       elif len(out_dtype_list) > 1:
            output_create = f"""
    {return_type} api_output;"""
...@@ -80,11 +72,11 @@ class SparseAPI(ForwardAPI):
                    output_create += 'Tensor(), '
                output_create = output_create[:-2] + '};'

-           for i in range(len(output_type_list)):
+           for i in range(len(out_dtype_list)):
                kernel_output = kernel_output + f'kernel_out_{i}, '
                output_names.append(f'kernel_out_{i}')
                output_create = output_create + f"""
-   auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {output_type_map[output_type_list[i]]});"""
+   auto* kernel_out_{i} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype_list[i]]});"""
            kernel_output = kernel_output[:-2]

        else:
...@@ -148,8 +140,7 @@ class SparseAPI(ForwardAPI):
    def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
        _, kernel_output_names, output_create = self.gene_output(
-           self.kernel['dispatch'][kernel_name][1], 'SetSparseKernelOutput',
-           '', inplace_flag)
+           self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag)

        kernel_context_code = self.gen_sparse_kernel_context(
            kernel_output_names)
...@@ -189,7 +180,6 @@ class SparseAPI(ForwardAPI):
        return " && ".join(condition_list)

    def gene_dispatch_code(self, kernel_name, inplace_flag=False):
-       dispatch_code = ""
        return f"""
  if ({self.get_condition_code(kernel_name)}) {{
    {self.gen_sparse_kernel_code(kernel_name, inplace_flag)}
......
...@@ -48,9 +48,9 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
        return BackwardAPI.get_define_args(self)

    def gene_output(self,
-                   output_type_list,
-                   set_out_func,
-                   code_indent,
+                   out_dtype_list,
+                   out_tensor_type_list=None,
+                   code_indent='',
                    inplace_flag=False):
        kernel_output = ""
        output_names = []
...@@ -61,19 +61,19 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
            'sparse_csr': 'TensorType::SPARSE_CSR'
        }

-       if len(output_type_list) == 1:
+       if len(out_dtype_list) == 1:
            kernel_output = 'kernel_out'
            output_names.append('kernel_out')
            inplace_assign = " = " + self.inplace_map[self.outputs['names'][
                0]] if inplace_flag and self.inplace_map is not None and self.outputs[
                    'names'][0] in self.inplace_map else ""
            output_create = f"""
-   auto kernel_out = {set_out_func}({self.outputs['names'][0]}, {output_type_map[output_type_list[0]]});"""
+   auto kernel_out = SetSparseKernelOutput({self.outputs['names'][0]}, {output_type_map[out_dtype_list[0]]});"""

-       elif len(output_type_list) > 1:
+       elif len(out_dtype_list) > 1:
            output_create = ""
-           for i, out_type_item in enumerate(output_type_list):
+           for i, out_type_item in enumerate(out_dtype_list):
                kernel_output = kernel_output + f'kernel_out_{i}, '
                output_names.append(f'kernel_out_{i}')
                if inplace_flag and self.inplace_map is not None and self.outputs[
...@@ -82,7 +82,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
    *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};"""
                output_create = output_create + f"""
-   auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i]}, {output_type_map[output_type_list[i]]});"""
+   auto kernel_out_{i} = SetSparseKernelOutput({self.outputs['names'][i]}, {output_type_map[out_dtype_list[i]]});"""
            kernel_output = kernel_output[:-2]

        else:
......
...@@ -51,16 +51,16 @@ class StringsAPI(ForwardAPI):
        return tensor_type_dict[kernel_tensor_out_type]

    def gene_output(self,
-                   output_type_list,
-                   set_out_func,
-                   code_indent,
+                   out_dtype_list,
+                   out_tensor_type_list=None,
+                   code_indent='',
                    inplace_flag=False):
        kernel_output = ""
        output_names = []
        output_create = ""
        return_type = self.get_return_type(inplace_flag)

-       if len(output_type_list) == 1:
+       if len(out_dtype_list) == 1:
            kernel_output = 'kernel_out'
            output_names.append('kernel_out')
            kernel_tensor_out_type = self.get_kernel_tensor_out_type(
...@@ -71,13 +71,13 @@ class StringsAPI(ForwardAPI):
                'names'][0] in self.inplace_map else ""
            output_create = f"""
  {return_type} api_output{inplace_assign};
-  {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>({set_out_func}(kernel_backend, &api_output, {kernel_tensor_out_type}));"""
+  {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(kernel_backend, &api_output, {kernel_tensor_out_type}));"""

-       elif len(output_type_list) > 1:
+       elif len(out_dtype_list) > 1:
            output_create = f"""
  {return_type} api_output;"""

-           for i in range(len(output_type_list)):
+           for i in range(len(out_dtype_list)):
                kernel_output = kernel_output + f'kernel_out_{i}, '
                output_names.append(f'kernel_out_{i}')
                kernel_tensor_out_type = self.get_kernel_tensor_out_type(
...@@ -89,7 +89,7 @@ class StringsAPI(ForwardAPI):
    std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
                output_create = output_create + f"""
-   {tensor_type}* kernel_out_{i} = dynamic_cast<{tensor_type}*>({set_out_func}(&std::get<{i}>(api_output), {kernel_tensor_out_type}));"""
+   {tensor_type}* kernel_out_{i} = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(&std::get<{i}>(api_output), {kernel_tensor_out_type}));"""
            kernel_output = kernel_output[:-2]

        else:
...@@ -174,7 +174,7 @@ class StringsAPI(ForwardAPI):
        input_tensors, kernel_args, kernel_signature = self.get_kernel_args(
            code_indent)
        outputs_args, kernel_output_names, output_create = self.gene_output(
-           self.outputs['types'], 'SetStringsKernelOutput', '', inplace_flag)
+           self.outputs['types'], None, '', inplace_flag)

        return f"""
  // 1. Get kernel signature and kernel
...@@ -252,11 +252,6 @@ class StringsAPI(ForwardAPI):
        kernel_select_code = kernel_key_item_init + kernel_select_code

        if len(input_names) > 0:
-           if self.support_selected_rows_kernel:
-               kernel_select_code = kernel_select_code + f"""
- KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)});
-"""
            kernel_select_code = kernel_select_code + f"""
  auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
......