未验证 提交 b72d4cb4 编写于 作者: Z zyfncg 提交者: GitHub

fix selected_rows bug in C++ API (#39658)

上级 a1ad003c
......@@ -46,8 +46,6 @@ enum class KernelType {
// TODO(chenweihang): support DataLayout and DataType selected
struct KernelKeySet {
KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};
BackendSet backend_set{Backend::UNDEFINED};
DataLayout layout{DataLayout::UNDEFINED};
DataType dtype{DataType::UNDEFINED};
......@@ -97,9 +95,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
void operator()(const Tensor& x) {
key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x);
// TODO(chenweihang): selecte multi layout and dtype
if (pten::SelectedRows::classof(x.impl().get())) {
key_set.kernel_type = KernelType::SELECTED_ROWS_KENREL;
}
key_set.layout = x.layout();
key_set.dtype = x.type();
dtype_set = dtype_set | DataTypeSet(x.dtype());
......@@ -124,6 +119,24 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
}
};
struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};
// TODO(chenweihang): deal with multiple diff input Tensors
// TODO(chenweihang): add global device guard method to set backend
void operator()(const Tensor& x) {
if (pten::SelectedRows::classof(x.impl().get())) {
kernel_type = KernelType::SELECTED_ROWS_KENREL;
}
}
// skip other type args, these args don't used in kernel selection
template <typename T>
void operator()(const T& x) {
// do nothing
}
};
} // namespace detail
template <typename... Args>
......@@ -131,6 +144,11 @@ KernelKeySet ParseKernelKeyByInputArgs(const Args&... args) {
return detail::KernelKeyParser().apply(args...).key_set;
}
template <typename... Args>
KernelType ParseKernelTypeByInputArgs(const Args&... args) {
return detail::KernelTypeParser().apply(args...).kernel_type;
}
DataType ParseDataType(DataType dtype);
DataType ParseDataType(const Tensor& tensor);
DataType ParseDataType(const std::vector<Tensor>& tensors);
......
......@@ -345,7 +345,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
if len(input_names) > 0:
if self.support_selected_rows_kernel:
kernel_select_code = kernel_select_code + f"""
KernelType kernel_type;
KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)});
"""
kernel_select_code = kernel_select_code + f"""
......@@ -354,7 +354,6 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
{'kernel_type = kernel_key_set.kernel_type;' if self.support_selected_rows_kernel else ''}
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册