未验证 提交 55aaeb39 编写于 作者: Z zyfncg 提交者: GitHub

Fix some problem of kernel fallback in C++ API (#44681)

* support auto fallback to  cpu kernel for cusom device

* fix some problem of kernel fallback
上级 2cec4c88
......@@ -73,8 +73,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
}
}
std::string kernel_name = "adamw";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;
......@@ -232,8 +233,9 @@ Tensor conv2d_impl(const Tensor& input,
VLOG(6) << "conv2d API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"conv2d", {kernel_backend, kernel_layout, kernel_data_type}, true);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "conv2d API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
......@@ -334,8 +336,9 @@ Tensor conv3d_impl(const Tensor& input,
VLOG(6) << "conv3d API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"conv3d", {kernel_backend, kernel_layout, kernel_data_type}, true);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "conv3d API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
......@@ -437,8 +440,9 @@ void conv2d_grad_impl(const Tensor& input,
VLOG(6) << "conv2d_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"conv2d_grad", {kernel_backend, kernel_layout, kernel_data_type}, true);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "conv2d_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
......@@ -538,8 +542,9 @@ void conv3d_grad_impl(const Tensor& input,
VLOG(6) << "conv3d_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"conv3d_grad", {kernel_backend, kernel_layout, kernel_data_type}, true);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "conv3d_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
......@@ -624,10 +629,11 @@ Tensor embedding_impl(const Tensor& x,
Tensor api_output;
if (phi::DenseTensor::classof(weight.impl().get())) {
const auto& kernel =
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"embedding",
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "embedding API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
......@@ -652,10 +658,11 @@ Tensor embedding_impl(const Tensor& x,
(*kernel_fn)(*dev_ctx, *input_x, *input_weight, padding_idx, kernel_out);
}
} else {
const auto& kernel =
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"sparse_weight_embedding",
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "sparse_weight_embedding API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
......@@ -693,8 +700,9 @@ std::vector<Tensor> split_impl(const Tensor& x,
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;
......@@ -774,8 +782,9 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
if (grad.is_selected_rows()) {
kernel_name = "momentum_dense_param_sparse_grad";
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;
......@@ -906,8 +915,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"batch_norm", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "batch_norm API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "batch_norm API kernel: " << kernel;
......@@ -1004,8 +1014,9 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
phi::dtype::ToComplex(out_grad.dtype())};
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"imag_grad", kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "imag_grad API kernel key: " << kernel_key;
VLOG(6) << "imag_grad API kernel: " << kernel;
......@@ -1042,10 +1053,11 @@ void embedding_grad_impl(const Tensor& x,
if (phi::DenseTensor::classof(weight.impl().get())) {
std::string kernel_name =
sparse ? "embedding_sparse_grad" : "embedding_grad";
const auto& kernel =
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name,
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
......@@ -1094,10 +1106,11 @@ void embedding_grad_impl(const Tensor& x,
} else {
std::string kernel_name = sparse ? "sparse_weight_embedding_sparse_grad"
: "sparse_weight_embedding_grad";
const auto& kernel =
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name,
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
......@@ -1148,8 +1161,9 @@ void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
phi::dtype::ToComplex(out_grad.dtype())};
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"real_grad", kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "real_grad API kernel key: " << kernel_key;
VLOG(6) << "real_grad API kernel: " << kernel;
......
......@@ -294,5 +294,30 @@ paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
return paddle::none;
}
void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend,
phi::DenseTensor* out) {
if (tensor) {
*out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend));
}
}
void TransDataBackend(const std::vector<phi::DenseTensor*>& tensors,
Backend target_backend,
std::vector<phi::DenseTensor*> outs) {
size_t n = tensors.size();
for (size_t i = 0; i < n; ++i) {
TransDataBackend(tensors[i], target_backend, outs[i]);
}
}
void TransDataBackend(const phi::SelectedRows* tensor,
Backend target_backend,
phi::SelectedRows* out) {
if (tensor) {
TransDataBackend(&tensor->value(), target_backend, out->mutable_value());
}
}
} // namespace experimental
} // namespace paddle
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/selected_rows.h"
namespace paddle {
namespace experimental {
......@@ -81,5 +82,17 @@ paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
void TransDataBackend(const phi::DenseTensor* tensor,
Backend target_backend,
phi::DenseTensor* out);
void TransDataBackend(const std::vector<phi::DenseTensor*>& tensor,
Backend target_backend,
std::vector<phi::DenseTensor*> out);
void TransDataBackend(const phi::SelectedRows* tensor,
Backend target_backend,
phi::SelectedRows* out);
} // namespace experimental
} // namespace paddle
......@@ -38,8 +38,9 @@ Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "add API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
......@@ -95,8 +96,9 @@ Tensor to_sparse_csr_impl(const Tensor& x) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "add API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
......@@ -157,8 +159,9 @@ Tensor to_dense_impl(const Tensor& x) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
const auto& kernel = kernel_result.kernel;
VLOG(6) << "add API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
......
......@@ -691,15 +691,21 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], out_tensor_type_list, code_indent,
inplace_flag)
fallback_kernel_output_trans = ""
for kernel_out in outputs_args:
fallback_kernel_output_trans += (f"""
{code_indent} TransDataBackend({kernel_out}, kernel_backend, {kernel_out});"""
)
cudnn_args = '' if self.kernel[
'use_gpudnn'] == 'false' else ', ' + self.kernel['use_gpudnn']
return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent} const auto& kernel = kernel_result.kernel;
{code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
{output_create}
{self.gene_infer_meta(kernel_output_names, code_indent)}
......@@ -708,7 +714,10 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} {{
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{kernel_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} (*kernel_fn)({kernel_args}, {", ".join(outputs_args)});
{code_indent} }}
{code_indent} if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans}
{code_indent} }}
{code_indent} {self.gene_return_code()}"""
......
......@@ -137,13 +137,13 @@ class ForwardAPI(BaseAPI):
out_tensor_type_list=None,
code_indent='',
inplace_flag=False):
kernel_output = ""
kernel_output = []
output_names = []
output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag)
if len(out_dtype_list) == 1:
kernel_output = 'kernel_out'
kernel_output.append('kernel_out')
output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[
self.outputs['names'][0]] if inplace_flag and self.outputs[
......@@ -186,7 +186,7 @@ class ForwardAPI(BaseAPI):
output_create = output_create[:-2] + '};'
for i in range(len(out_dtype_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
kernel_output.append(f'kernel_out_{i}')
output_names.append(f'kernel_out_{i}')
set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
i] == 'dense' else 'SetSelectedRowsKernelOutput'
......@@ -214,7 +214,6 @@ class ForwardAPI(BaseAPI):
{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
......
......@@ -118,12 +118,12 @@ class BackwardAPI(BaseAPI):
out_tensor_type_list=None,
code_indent='',
inplace_flag=False):
kernel_output = ""
kernel_output = []
output_names = []
output_create = ""
if len(out_dtype_list) == 1:
kernel_output = 'kernel_out'
kernel_output.append('kernel_out')
output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
......@@ -144,7 +144,7 @@ class BackwardAPI(BaseAPI):
elif len(out_dtype_list) > 1:
output_create = ""
for i, out_type_item in enumerate(out_dtype_list):
kernel_output = kernel_output + f'kernel_out_{i}, '
kernel_output.append(f'kernel_out_{i}')
output_names.append(f'kernel_out_{i}')
set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[
i] == 'dense' else 'SetSelectedRowsKernelOutput'
......@@ -168,7 +168,6 @@ class BackwardAPI(BaseAPI):
output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(&{self.outputs['names'][i]});"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
......
......@@ -36,7 +36,7 @@ class SparseAPI(ForwardAPI):
out_tensor_type_list=None,
code_indent='',
inplace_flag=False):
kernel_output = ""
kernel_output = []
output_names = []
output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag)
......@@ -47,7 +47,7 @@ class SparseAPI(ForwardAPI):
}
if len(out_dtype_list) == 1:
kernel_output = 'kernel_out'
kernel_output.append('kernel_out')
output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
......@@ -73,12 +73,11 @@ class SparseAPI(ForwardAPI):
output_create = output_create[:-2] + '};'
for i in range(len(out_dtype_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
kernel_output.append(f'kernel_out_{i}')
output_names.append(f'kernel_out_{i}')
output_create = output_create + f"""
auto* kernel_out_{i} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype_list[i]]});"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
......@@ -147,11 +146,12 @@ class SparseAPI(ForwardAPI):
self.gene_return_code()) == 0 else " " + self.gene_return_code()
return f"""
VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto phi_kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
const auto& phi_kernel = kernel_result.kernel;
VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
auto kernel_context = phi::KernelContext(dev_ctx);
{output_create}
{kernel_context_code}
......
......@@ -52,7 +52,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
out_tensor_type_list=None,
code_indent='',
inplace_flag=False):
kernel_output = ""
kernel_output = []
output_names = []
output_create = ""
output_type_map = {
......@@ -62,7 +62,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
}
if len(out_dtype_list) == 1:
kernel_output = 'kernel_out'
kernel_output.append('kernel_out')
output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
......@@ -74,7 +74,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
output_create = ""
for i, out_type_item in enumerate(out_dtype_list):
kernel_output = kernel_output + f'kernel_out_{i}, '
kernel_output.append(f'kernel_out_{i}')
output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
......@@ -84,7 +84,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
output_create = output_create + f"""
auto kernel_out_{i} = SetSparseKernelOutput({self.outputs['names'][i]}, {output_type_map[out_dtype_list[i]]});"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
......
......@@ -55,13 +55,13 @@ class StringsAPI(ForwardAPI):
out_tensor_type_list=None,
code_indent='',
inplace_flag=False):
kernel_output = ""
kernel_output = []
output_names = []
output_create = ""
return_type = self.get_return_type(inplace_flag)
if len(out_dtype_list) == 1:
kernel_output = 'kernel_out'
kernel_output.append('kernel_out')
output_names.append('kernel_out')
kernel_tensor_out_type = self.get_kernel_tensor_out_type(
self.outputs['names'][0])
......@@ -78,7 +78,7 @@ class StringsAPI(ForwardAPI):
{return_type} api_output;"""
for i in range(len(out_dtype_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
kernel_output.append(f'kernel_out_{i}')
output_names.append(f'kernel_out_{i}')
kernel_tensor_out_type = self.get_kernel_tensor_out_type(
self.outputs['names'][i])
......@@ -91,7 +91,6 @@ class StringsAPI(ForwardAPI):
output_create = output_create + f"""
{tensor_type}* kernel_out_{i} = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(&std::get<{i}>(api_output), {kernel_tensor_out_type}));"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
......@@ -178,13 +177,14 @@ class StringsAPI(ForwardAPI):
return f"""
// 1. Get kernel signature and kernel
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
VLOG(6) << "{self.api} api strings kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "{self.api} api strings kernel: " << kernel;
// 2. Get Device Context and input
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
// 3. Set output
......@@ -195,7 +195,7 @@ class StringsAPI(ForwardAPI):
{code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} (*kernel_fn)({kernel_args}, {", ".join(outputs_args)});
{code_indent} {self.gene_return_code()}"""
......
......@@ -84,7 +84,7 @@ bool KernelFactory::HasKernel(const std::string& kernel_name,
return true;
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
KernelResult KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name,
const KernelKey& kernel_key,
bool use_gpudnn) const {
......@@ -104,7 +104,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
}
if (kernel_iter != iter->second.end()) {
return kernel_iter->second;
return {kernel_iter->second, false};
}
LOG(WARNING) << "The cudnn kernel for [" << kernel_name
<< "] is not registered.";
......@@ -120,7 +120,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
kernel_iter = iter->second.find(any_layout_kernel_key);
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
bool has_fallback_cpu = false;
if (kernel_iter == iter->second.end()) {
// Fallback CPU backend
phi::KernelKey cpu_kernel_key(
......@@ -132,8 +132,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}
has_fallback_cpu = true;
}
#endif
PADDLE_ENFORCE_NE(
kernel_iter,
......@@ -143,16 +143,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
kernel_key,
kernel_name));
return kernel_iter->second;
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name,
Backend backend,
DataLayout layout,
DataType dtype) const {
return SelectKernelOrThrowError(kernel_name,
KernelKey(backend, layout, dtype));
return {kernel_iter->second, has_fallback_cpu};
}
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
......
......@@ -252,6 +252,14 @@ using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;
struct KernelResult {
KernelResult(const Kernel& kernel, bool fallback_cpu)
: kernel(kernel), has_fallback_cpu(fallback_cpu) {}
const Kernel& kernel;
bool has_fallback_cpu = false;
};
/**
* Note: Each Computation need a basic kernel map that named by kernel_name.
* Such as for scale op, KernelMap contains a `scale` kernel map,
......@@ -268,14 +276,9 @@ class KernelFactory {
return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end();
}
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key,
bool use_gpudnn = false) const;
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
Backend backend,
DataLayout layout,
DataType dtype) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key,
bool use_gpudnn = false) const;
bool HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
......
......@@ -51,8 +51,9 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x,
kernel_data_type = kernel_key.dtype();
}
}
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "scale API kernel: " << kernel;
......@@ -221,8 +222,9 @@ Tensor scale_switch_case(const Tensor& x,
kernel_data_type = kernel_key.dtype();
}
}
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "scale API kernel: " << kernel;
......
......@@ -192,8 +192,9 @@ TEST(CustomKernel, custom_kernel_dot) {
fake_dot_kernels.end());
// 4.kernel select
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
op_name, phi::KernelKey(backend, layout, phi::DataType::UINT8));
const auto& kernel = kernel_result.kernel;
// 5.prepare parameters for kernel
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册