diff --git a/paddle/pten/api/lib/kernel_dispatch.h b/paddle/pten/api/lib/kernel_dispatch.h index de753669af667978f801922344ff4ab0645ab119..17ef22cb6dea9c629fc3845eb9f770cc85272539 100644 --- a/paddle/pten/api/lib/kernel_dispatch.h +++ b/paddle/pten/api/lib/kernel_dispatch.h @@ -46,8 +46,6 @@ enum class KernelType { // TODO(chenweihang): support DataLayout and DataType selected struct KernelKeySet { - KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL}; - BackendSet backend_set{Backend::UNDEFINED}; DataLayout layout{DataLayout::UNDEFINED}; DataType dtype{DataType::UNDEFINED}; @@ -97,9 +95,6 @@ struct KernelKeyParser : ArgsIterator { void operator()(const Tensor& x) { key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x); // TODO(chenweihang): selecte multi layout and dtype - if (pten::SelectedRows::classof(x.impl().get())) { - key_set.kernel_type = KernelType::SELECTED_ROWS_KENREL; - } key_set.layout = x.layout(); key_set.dtype = x.type(); dtype_set = dtype_set | DataTypeSet(x.dtype()); @@ -124,6 +119,24 @@ struct KernelKeyParser : ArgsIterator { } }; +struct KernelTypeParser : ArgsIterator { + KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL}; + + // TODO(chenweihang): deal with multiple diff input Tensors + // TODO(chenweihang): add global device guard method to set backend + void operator()(const Tensor& x) { + if (pten::SelectedRows::classof(x.impl().get())) { + kernel_type = KernelType::SELECTED_ROWS_KENREL; + } + } + + // skip other type args, these args don't used in kernel selection + template + void operator()(const T& x) { + // do nothing + } +}; + } // namespace detail template @@ -131,6 +144,11 @@ KernelKeySet ParseKernelKeyByInputArgs(const Args&... args) { return detail::KernelKeyParser().apply(args...).key_set; } +template +KernelType ParseKernelTypeByInputArgs(const Args&... args) { + return detail::KernelTypeParser().apply(args...).kernel_type; +} + DataType ParseDataType(DataType dtype); DataType ParseDataType(const Tensor& tensor); DataType ParseDataType(const std::vector& tensors); diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 73c3ba4e4b4fe9d56ac6e6c7638777fc5df89164..7515981490728774de215c6c8698f96884cd946a 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -345,7 +345,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare if len(input_names) > 0: if self.support_selected_rows_kernel: kernel_select_code = kernel_select_code + f""" - KernelType kernel_type; + KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)}); """ kernel_select_code = kernel_select_code + f""" @@ -354,7 +354,6 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare || kernel_data_type == DataType::UNDEFINED ) {{ auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args}); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); - {'kernel_type = kernel_key_set.kernel_type;' if self.support_selected_rows_kernel else ''} if (kernel_backend == Backend::UNDEFINED) {{ kernel_backend = kernel_key.backend(); }}