未验证 提交 b72d4cb4 编写于 作者: Z zyfncg 提交者: GitHub

fix selected_rows bug in C++ API (#39658)

上级 a1ad003c
...@@ -46,8 +46,6 @@ enum class KernelType { ...@@ -46,8 +46,6 @@ enum class KernelType {
// TODO(chenweihang): support DataLayout and DataType selected // TODO(chenweihang): support DataLayout and DataType selected
struct KernelKeySet { struct KernelKeySet {
KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};
BackendSet backend_set{Backend::UNDEFINED}; BackendSet backend_set{Backend::UNDEFINED};
DataLayout layout{DataLayout::UNDEFINED}; DataLayout layout{DataLayout::UNDEFINED};
DataType dtype{DataType::UNDEFINED}; DataType dtype{DataType::UNDEFINED};
...@@ -97,9 +95,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -97,9 +95,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
void operator()(const Tensor& x) { void operator()(const Tensor& x) {
key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x); key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x);
// TODO(chenweihang): selecte multi layout and dtype // TODO(chenweihang): selecte multi layout and dtype
if (pten::SelectedRows::classof(x.impl().get())) {
key_set.kernel_type = KernelType::SELECTED_ROWS_KENREL;
}
key_set.layout = x.layout(); key_set.layout = x.layout();
key_set.dtype = x.type(); key_set.dtype = x.type();
dtype_set = dtype_set | DataTypeSet(x.dtype()); dtype_set = dtype_set | DataTypeSet(x.dtype());
...@@ -124,6 +119,24 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -124,6 +119,24 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
} }
}; };
// Argument visitor that decides which KernelType the inputs call for.
// The call operators below are presumably invoked once per argument by
// ArgsIterator::apply (ArgsIterator is defined elsewhere -- confirm there).
struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
  // Starts as the dense-tensor kernel; only a SelectedRows input changes it.
  KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};

  // TODO(chenweihang): handle multiple input Tensors of different kinds
  // TODO(chenweihang): add global device guard method to set backend
  void operator()(const Tensor& tensor) {
    const bool is_selected_rows =
        pten::SelectedRows::classof(tensor.impl().get());
    if (!is_selected_rows) {
      // Leave kernel_type untouched so an earlier match is never reset.
      return;
    }
    kernel_type = KernelType::SELECTED_ROWS_KENREL;
  }

  // Arguments of any other type are not used for kernel-type selection,
  // so this overload intentionally does nothing.
  template <typename T>
  void operator()(const T& /*unused*/) {}
};
} // namespace detail } // namespace detail
template <typename... Args> template <typename... Args>
...@@ -131,6 +144,11 @@ KernelKeySet ParseKernelKeyByInputArgs(const Args&... args) { ...@@ -131,6 +144,11 @@ KernelKeySet ParseKernelKeyByInputArgs(const Args&... args) {
return detail::KernelKeyParser().apply(args...).key_set; return detail::KernelKeyParser().apply(args...).key_set;
} }
// Runs a fresh detail::KernelTypeParser over `args` via apply() and returns
// the kernel_type it ends up with. With the parser's default this is
// DENSE_TENSOR_KENREL unless a Tensor argument holds a SelectedRows impl.
template <typename... Args>
KernelType ParseKernelTypeByInputArgs(const Args&... args) {
  // Copy the enum out within the same full-expression, while the temporary
  // parser is still alive.
  const KernelType parsed_type =
      detail::KernelTypeParser().apply(args...).kernel_type;
  return parsed_type;
}
DataType ParseDataType(DataType dtype); DataType ParseDataType(DataType dtype);
DataType ParseDataType(const Tensor& tensor); DataType ParseDataType(const Tensor& tensor);
DataType ParseDataType(const std::vector<Tensor>& tensors); DataType ParseDataType(const std::vector<Tensor>& tensors);
......
...@@ -345,7 +345,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare ...@@ -345,7 +345,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
if len(input_names) > 0: if len(input_names) > 0:
if self.support_selected_rows_kernel: if self.support_selected_rows_kernel:
kernel_select_code = kernel_select_code + f""" kernel_select_code = kernel_select_code + f"""
KernelType kernel_type; KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)});
""" """
kernel_select_code = kernel_select_code + f""" kernel_select_code = kernel_select_code + f"""
...@@ -354,7 +354,6 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare ...@@ -354,7 +354,6 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
|| kernel_data_type == DataType::UNDEFINED ) {{ || kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args}); auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
{'kernel_type = kernel_key_set.kernel_type;' if self.support_selected_rows_kernel else ''}
if (kernel_backend == Backend::UNDEFINED) {{ if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend(); kernel_backend = kernel_key.backend();
}} }}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册