diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 362c9606ebadf59d7f156496d757b467c16071be..88fefb8eac99da13d921811f02430c8e2a78290d 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -73,8 +73,9 @@ std::tuple adamw_impl( } } std::string kernel_name = "adamw"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, {kernel_backend, kernel_layout, kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; VLOG(6) << kernel_name << " API kernel: " << kernel; @@ -232,8 +233,9 @@ Tensor conv2d_impl(const Tensor& input, VLOG(6) << "conv2d API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "conv2d", {kernel_backend, kernel_layout, kernel_data_type}, true); + const auto& kernel = kernel_result.kernel; VLOG(6) << "conv2d API kernel: " << kernel; auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); @@ -334,8 +336,9 @@ Tensor conv3d_impl(const Tensor& input, VLOG(6) << "conv3d API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "conv3d", {kernel_backend, kernel_layout, kernel_data_type}, true); + const auto& kernel = kernel_result.kernel; VLOG(6) << "conv3d API kernel: " << kernel; auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); @@ -437,8 +440,9 @@ void conv2d_grad_impl(const Tensor& input, VLOG(6) << "conv2d_grad API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "conv2d_grad", {kernel_backend, kernel_layout, kernel_data_type}, true); + const auto& kernel = kernel_result.kernel; VLOG(6) << "conv2d_grad API kernel: " << kernel; auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); @@ -538,8 +542,9 @@ void conv3d_grad_impl(const Tensor& input, VLOG(6) << "conv3d_grad API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "conv3d_grad", {kernel_backend, kernel_layout, kernel_data_type}, true); + const auto& kernel = kernel_result.kernel; VLOG(6) << "conv3d_grad API kernel: " << kernel; auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); @@ -624,10 +629,11 @@ Tensor embedding_impl(const Tensor& x, Tensor api_output; if (phi::DenseTensor::classof(weight.impl().get())) { - const auto& kernel = + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "embedding", {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "embedding API kernel: " << kernel; auto input_x = PrepareData(x, kernel.InputAt(0), {}); @@ -652,10 +658,11 @@ Tensor embedding_impl(const Tensor& x, (*kernel_fn)(*dev_ctx, *input_x, *input_weight, padding_idx, kernel_out); } } else { - const auto& kernel = + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "sparse_weight_embedding", {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "sparse_weight_embedding API kernel: " << kernel; auto input_x = PrepareData(x, kernel.InputAt(0), {}); @@ -693,8 +700,9 @@ std::vector split_impl(const Tensor& x, DataLayout kernel_layout = kernel_key.layout(); DataType kernel_data_type = kernel_key.dtype(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "split", {kernel_backend, kernel_layout, kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "split API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; VLOG(6) << "split API kernel: " << kernel; @@ -774,8 +782,9 @@ std::tuple momentum_impl( if (grad.is_selected_rows()) { kernel_name = "momentum_dense_param_sparse_grad"; } - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, {kernel_backend, kernel_layout, kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; VLOG(6) << kernel_name << " API kernel: " << kernel; @@ -906,8 +915,9 @@ std::tuple batch_norm_impl( } } - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "batch_norm", {kernel_backend, kernel_layout, kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "batch_norm API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; VLOG(6) << "batch_norm API kernel: " << kernel; @@ -1004,8 +1014,9 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) { phi::KernelKey kernel_key{ParseBackend(out_grad), out_grad.layout(), phi::dtype::ToComplex(out_grad.dtype())}; - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "imag_grad", kernel_key); + const auto& kernel = kernel_result.kernel; VLOG(6) << "imag_grad API kernel key: " << kernel_key; VLOG(6) << "imag_grad API kernel: " << kernel; @@ -1042,10 +1053,11 @@ void embedding_grad_impl(const Tensor& x, if (phi::DenseTensor::classof(weight.impl().get())) { std::string kernel_name = sparse ? "embedding_sparse_grad" : "embedding_grad"; - const auto& kernel = + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << kernel_name << " API kernel: " << kernel; auto input_x = PrepareData(x, kernel.InputAt(0), {}); @@ -1094,10 +1106,11 @@ void embedding_grad_impl(const Tensor& x, } else { std::string kernel_name = sparse ? "sparse_weight_embedding_sparse_grad" : "sparse_weight_embedding_grad"; - const auto& kernel = + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << kernel_name << " API kernel: " << kernel; auto input_x = PrepareData(x, kernel.InputAt(0), {}); @@ -1148,8 +1161,9 @@ void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) { phi::KernelKey kernel_key{ParseBackend(out_grad), out_grad.layout(), phi::dtype::ToComplex(out_grad.dtype())}; - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "real_grad", kernel_key); + const auto& kernel = kernel_result.kernel; VLOG(6) << "real_grad API kernel key: " << kernel_key; VLOG(6) << "real_grad API kernel: " << kernel; diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 58795c0f06381dbe02c589a1edabed2997d8e570..3b44b1876e20d132118e11aeb42d49d1fcc34649 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -294,5 +294,30 @@ paddle::optional> PrepareData( return paddle::none; } +void TransDataBackend(const phi::DenseTensor* tensor, + Backend target_backend, + phi::DenseTensor* out) { + if (tensor) { + *out = TransDataPlace(*tensor, phi::TransToPhiPlace(target_backend)); + } +} + +void TransDataBackend(const std::vector& tensors, + Backend target_backend, + std::vector outs) { + size_t n = tensors.size(); + for (size_t i = 0; i < n; ++i) { + TransDataBackend(tensors[i], target_backend, outs[i]); + } +} + +void TransDataBackend(const phi::SelectedRows* tensor, + Backend target_backend, + phi::SelectedRows* out) { + if (tensor) { + TransDataBackend(&tensor->value(), target_backend, out->mutable_value()); + } +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 3feba2465f61bfbc38edd06ed1a84735cd0817fa..18ec2f639778607ac2be7078972b920970584dfe 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/core/selected_rows.h" namespace paddle { namespace experimental { @@ -81,5 +82,17 @@ paddle::optional> PrepareData( const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag); +void TransDataBackend(const phi::DenseTensor* tensor, + Backend target_backend, + phi::DenseTensor* out); + +void TransDataBackend(const std::vector& tensor, + Backend target_backend, + std::vector out); + +void TransDataBackend(const phi::SelectedRows* tensor, + Backend target_backend, + phi::SelectedRows* out); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/sparse_api_custom_impl.cc b/paddle/phi/api/lib/sparse_api_custom_impl.cc index 0b93c96e7f81d9afd1a22e6ce786c96ae3e84beb..73f5b28f459072b60ef178596a115cc5a118e84a 100644 --- a/paddle/phi/api/lib/sparse_api_custom_impl.cc +++ b/paddle/phi/api/lib/sparse_api_custom_impl.cc @@ -38,8 +38,9 @@ Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) { auto kernel_key_set = ParseKernelKeyByInputArgs(x); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, kernel_key); + const auto& kernel = kernel_result.kernel; VLOG(6) << "add API kernel key: " << kernel_key; VLOG(6) << "to API kernel: " << kernel; @@ -95,8 +96,9 @@ Tensor to_sparse_csr_impl(const Tensor& x) { auto kernel_key_set = ParseKernelKeyByInputArgs(x); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, kernel_key); + const auto& kernel = kernel_result.kernel; VLOG(6) << "add API kernel key: " << kernel_key; VLOG(6) << "to API kernel: " << kernel; @@ -157,8 +159,9 @@ Tensor to_dense_impl(const Tensor& x) { auto kernel_key_set = ParseKernelKeyByInputArgs(x); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, kernel_key); + const auto& kernel = kernel_result.kernel; VLOG(6) << "add API kernel key: " << kernel_key; VLOG(6) << "to API kernel: " << kernel; diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index f93046d26279fa3094fee9621cbe87b00b3b7d07..3c190fd7e9e8e403290bb6c8faf5cd5ea6a743e6 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -691,15 +691,21 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d outputs_args, kernel_output_names, output_create = self.gene_output( self.outputs['types'], out_tensor_type_list, code_indent, inplace_flag) + fallback_kernel_output_trans = "" + for kernel_out in outputs_args: + fallback_kernel_output_trans += (f""" +{code_indent} TransDataBackend({kernel_out}, kernel_backend, {kernel_out});""" + ) cudnn_args = '' if self.kernel[ 'use_gpudnn'] == 'false' else ', ' + self.kernel['use_gpudnn'] return f""" {code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; -{code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( +{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args}); +{code_indent} const auto& kernel = kernel_result.kernel; {code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel; -{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); +{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); {input_tensors} {output_create} {self.gene_infer_meta(kernel_output_names, code_indent)} @@ -708,7 +714,10 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn(); {code_indent} {{ {code_indent} paddle::platform::RecordEvent kernel_record_event(\"{kernel_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1); -{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); +{code_indent} (*kernel_fn)({kernel_args}, {", ".join(outputs_args)}); +{code_indent} }} +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{fallback_kernel_output_trans} {code_indent} }} {code_indent} {self.gene_return_code()}""" diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 0893d0d5578f990396146fe16dfd794348db6e21..64c2fac85cd0028e684dde86fa0a88f2671ae837 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -137,13 +137,13 @@ class ForwardAPI(BaseAPI): out_tensor_type_list=None, code_indent='', inplace_flag=False): - kernel_output = "" + kernel_output = [] output_names = [] output_create = "" return_type = self.get_return_type_with_intermediate(inplace_flag) if len(out_dtype_list) == 1: - kernel_output = 'kernel_out' + kernel_output.append('kernel_out') output_names.append('kernel_out') inplace_assign = " = " + self.inplace_map[ self.outputs['names'][0]] if inplace_flag and self.outputs[ @@ -186,7 +186,7 @@ class ForwardAPI(BaseAPI): output_create = output_create[:-2] + '};' for i in range(len(out_dtype_list)): - kernel_output = kernel_output + f'kernel_out_{i}, ' + kernel_output.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}') set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[ i] == 'dense' else 'SetSelectedRowsKernelOutput' @@ -214,7 +214,6 @@ class ForwardAPI(BaseAPI): {code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]}); {code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";""" - kernel_output = kernel_output[:-2] else: raise ValueError( "{} : Output error: the output should not be empty.".format( diff --git a/paddle/phi/api/yaml/generator/backward_api_gen.py b/paddle/phi/api/yaml/generator/backward_api_gen.py index 67d47a8ec7432c6c5f5873dfecd796d807a094b1..cb57b04459787b5ce19220ea68c1d8cf25780ddd 100644 --- a/paddle/phi/api/yaml/generator/backward_api_gen.py +++ b/paddle/phi/api/yaml/generator/backward_api_gen.py @@ -118,12 +118,12 @@ class BackwardAPI(BaseAPI): out_tensor_type_list=None, code_indent='', inplace_flag=False): - kernel_output = "" + kernel_output = [] output_names = [] output_create = "" if len(out_dtype_list) == 1: - kernel_output = 'kernel_out' + kernel_output.append('kernel_out') output_names.append('kernel_out') inplace_assign = " = " + self.inplace_map[self.outputs['names'][ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ @@ -144,7 +144,7 @@ class BackwardAPI(BaseAPI): elif len(out_dtype_list) > 1: output_create = "" for i, out_type_item in enumerate(out_dtype_list): - kernel_output = kernel_output + f'kernel_out_{i}, ' + kernel_output.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}') set_out_func = 'SetKernelOutput' if out_tensor_type_list is None or out_tensor_type_list[ i] == 'dense' else 'SetSelectedRowsKernelOutput' @@ -168,7 +168,6 @@ class BackwardAPI(BaseAPI): output_create = output_create + f""" {code_indent} auto kernel_out_{i} = {set_out_func}(&{self.outputs['names'][i]});""" - kernel_output = kernel_output[:-2] else: raise ValueError( "{} : Output error: the output should not be empty.".format( diff --git a/paddle/phi/api/yaml/generator/sparse_api_gen.py b/paddle/phi/api/yaml/generator/sparse_api_gen.py index 69bf6950cd822813f3cdcfa7128f281dea8ead6f..ac98c78f58a3f29cec38ef3e73bc4e695bef45c0 100644 --- a/paddle/phi/api/yaml/generator/sparse_api_gen.py +++ b/paddle/phi/api/yaml/generator/sparse_api_gen.py @@ -36,7 +36,7 @@ class SparseAPI(ForwardAPI): out_tensor_type_list=None, code_indent='', inplace_flag=False): - kernel_output = "" + kernel_output = [] output_names = [] output_create = "" return_type = self.get_return_type_with_intermediate(inplace_flag) @@ -47,7 +47,7 @@ class SparseAPI(ForwardAPI): } if len(out_dtype_list) == 1: - kernel_output = 'kernel_out' + kernel_output.append('kernel_out') output_names.append('kernel_out') inplace_assign = " = " + self.inplace_map[self.outputs['names'][ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ @@ -73,12 +73,11 @@ class SparseAPI(ForwardAPI): output_create = output_create[:-2] + '};' for i in range(len(out_dtype_list)): - kernel_output = kernel_output + f'kernel_out_{i}, ' + kernel_output.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}') output_create = output_create + f""" auto* kernel_out_{i} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype_list[i]]});""" - kernel_output = kernel_output[:-2] else: raise ValueError( "{} : Output error: the output should not be empty.".format( @@ -147,11 +146,12 @@ class SparseAPI(ForwardAPI): self.gene_return_code()) == 0 else " " + self.gene_return_code() return f""" VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; - auto phi_kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); + const auto& phi_kernel = kernel_result.kernel; VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); auto kernel_context = phi::KernelContext(dev_ctx); {output_create} {kernel_context_code} diff --git a/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py b/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py index f3172a23cb9916dda7403756156c2b19d7ad8ff6..4692ed64513eaee2fffde172a7f513bf67a5610c 100644 --- a/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py +++ b/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py @@ -52,7 +52,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): out_tensor_type_list=None, code_indent='', inplace_flag=False): - kernel_output = "" + kernel_output = [] output_names = [] output_create = "" output_type_map = { @@ -62,7 +62,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): } if len(out_dtype_list) == 1: - kernel_output = 'kernel_out' + kernel_output.append('kernel_out') output_names.append('kernel_out') inplace_assign = " = " + self.inplace_map[self.outputs['names'][ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ @@ -74,7 +74,7 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): output_create = "" for i, out_type_item in enumerate(out_dtype_list): - kernel_output = kernel_output + f'kernel_out_{i}, ' + kernel_output.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}') if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][i] in self.inplace_map: @@ -84,7 +84,6 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): output_create = output_create + f""" auto kernel_out_{i} = SetSparseKernelOutput({self.outputs['names'][i]}, {output_type_map[out_dtype_list[i]]});""" - kernel_output = kernel_output[:-2] else: raise ValueError( "{} : Output error: the output should not be empty.".format( diff --git a/paddle/phi/api/yaml/generator/strings_api_gen.py b/paddle/phi/api/yaml/generator/strings_api_gen.py index bb5a7a2413d8e1d39f586b43f1b7b91e38297691..18c23f10baf2df2a531c335601f64efeeed8723e 100644 --- a/paddle/phi/api/yaml/generator/strings_api_gen.py +++ b/paddle/phi/api/yaml/generator/strings_api_gen.py @@ -55,13 +55,13 @@ class StringsAPI(ForwardAPI): out_tensor_type_list=None, code_indent='', inplace_flag=False): - kernel_output = "" + kernel_output = [] output_names = [] output_create = "" return_type = self.get_return_type(inplace_flag) if len(out_dtype_list) == 1: - kernel_output = 'kernel_out' + kernel_output.append('kernel_out') output_names.append('kernel_out') kernel_tensor_out_type = self.get_kernel_tensor_out_type( self.outputs['names'][0]) @@ -78,7 +78,7 @@ class StringsAPI(ForwardAPI): {return_type} api_output;""" for i in range(len(out_dtype_list)): - kernel_output = kernel_output + f'kernel_out_{i}, ' + kernel_output.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}') kernel_tensor_out_type = self.get_kernel_tensor_out_type( self.outputs['names'][i]) @@ -91,7 +91,6 @@ class StringsAPI(ForwardAPI): output_create = output_create + f""" {tensor_type}* kernel_out_{i} = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(&std::get<{i}>(api_output), {kernel_tensor_out_type}));""" - kernel_output = kernel_output[:-2] else: raise ValueError( "{} : Output error: the output should not be empty.".format( @@ -178,13 +177,14 @@ class StringsAPI(ForwardAPI): return f""" // 1. Get kernel signature and kernel - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); VLOG(6) << "{self.api} api strings kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "{self.api} api strings kernel: " << kernel; // 2. Get Device Context and input - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); {input_tensors} // 3. Set output @@ -195,7 +195,7 @@ class StringsAPI(ForwardAPI): {code_indent} using kernel_signature = {kernel_signature}; {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn(); -{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); +{code_indent} (*kernel_fn)({kernel_args}, {", ".join(outputs_args)}); {code_indent} {self.gene_return_code()}""" diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index ae9c16e0cc7106bcd15fe24bafc0cf3954b1bc7f..3bee07f8a3f15b69e10e3f7709e3b725a3405dce 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -84,7 +84,7 @@ bool KernelFactory::HasKernel(const std::string& kernel_name, return true; } -const Kernel& KernelFactory::SelectKernelOrThrowError( +KernelResult KernelFactory::SelectKernelOrThrowError( const std::string& kernel_name, const KernelKey& kernel_key, bool use_gpudnn) const { @@ -104,7 +104,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()}); } if (kernel_iter != iter->second.end()) { - return kernel_iter->second; + return {kernel_iter->second, false}; } LOG(WARNING) << "The cudnn kernel for [" << kernel_name << "] is not registered."; @@ -120,7 +120,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( kernel_iter = iter->second.find(any_layout_kernel_key); } -#ifdef PADDLE_WITH_CUSTOM_DEVICE + bool has_fallback_cpu = false; if (kernel_iter == iter->second.end()) { // Fallback CPU backend phi::KernelKey cpu_kernel_key( @@ -132,8 +132,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); kernel_iter = iter->second.find(any_layout_kernel_key); } + has_fallback_cpu = true; } -#endif PADDLE_ENFORCE_NE( kernel_iter, @@ -143,16 +143,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( kernel_key, kernel_name)); - return kernel_iter->second; -} - -const Kernel& KernelFactory::SelectKernelOrThrowError( - const std::string& kernel_name, - Backend backend, - DataLayout layout, - DataType dtype) const { - return SelectKernelOrThrowError(kernel_name, - KernelKey(backend, layout, dtype)); + return {kernel_iter->second, has_fallback_cpu}; } const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index c4c8274db976cc40be7f084065a94f808717b874..59e91451fff750c0c0cd115259c354c410f9a0d6 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -252,6 +252,14 @@ using KernelKeyMap = paddle::flat_hash_map; using KernelNameMap = paddle::flat_hash_map; +struct KernelResult { + KernelResult(const Kernel& kernel, bool fallback_cpu) + : kernel(kernel), has_fallback_cpu(fallback_cpu) {} + + const Kernel& kernel; + bool has_fallback_cpu = false; +}; + /** * Note: Each Computation need a basic kernel map that named by kernel_name. * Such as for scale op, KernelMap contains a `scale` kernel map, @@ -268,14 +276,9 @@ class KernelFactory { return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end(); } - const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, - const KernelKey& kernel_key, - bool use_gpudnn = false) const; - - const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, - Backend backend, - DataLayout layout, - DataType dtype) const; + KernelResult SelectKernelOrThrowError(const std::string& kernel_name, + const KernelKey& kernel_key, + bool use_gpudnn = false) const; bool HasKernel(const std::string& kernel_name, const KernelKey& kernel_key) const; diff --git a/paddle/phi/tests/api/scale_api.h b/paddle/phi/tests/api/scale_api.h index 322f7b27abdb132c34f2fd2b04176a84954b7f6d..0e42b7f2a18fbb6cfd57a5841c760c815b98590c 100644 --- a/paddle/phi/tests/api/scale_api.h +++ b/paddle/phi/tests/api/scale_api.h @@ -51,8 +51,9 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x, kernel_data_type = kernel_key.dtype(); } } - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "scale", {kernel_backend, kernel_layout, kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "scale API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; VLOG(6) << "scale API kernel: " << kernel; @@ -221,8 +222,9 @@ Tensor scale_switch_case(const Tensor& x, kernel_data_type = kernel_key.dtype(); } } - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( "scale", {kernel_backend, kernel_layout, kernel_data_type}); + const auto& kernel = kernel_result.kernel; VLOG(6) << "scale API kernel key: [" << kernel_backend << ", " << kernel_layout << ", " << kernel_data_type << "]"; VLOG(6) << "scale API kernel: " << kernel; diff --git a/paddle/phi/tests/core/test_custom_kernel.cc b/paddle/phi/tests/core/test_custom_kernel.cc index abd77e286241096b4243807aaeac6e1da3162b36..8562a7930afd903f56727d2eb03060bc363c6370 100644 --- a/paddle/phi/tests/core/test_custom_kernel.cc +++ b/paddle/phi/tests/core/test_custom_kernel.cc @@ -192,8 +192,9 @@ TEST(CustomKernel, custom_kernel_dot) { fake_dot_kernels.end()); // 4.kernel select - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( op_name, phi::KernelKey(backend, layout, phi::DataType::UINT8)); + const auto& kernel = kernel_result.kernel; // 5.prepare parameters for kernel const auto alloc = std::make_unique(