diff --git a/paddle/phi/api/lib/kernel_dispatch.h b/paddle/phi/api/lib/kernel_dispatch.h index 1ca88acab8cc1574a7614abdd6dd06b51a2a084b..015c1be57370a2fc300cbd07be8a8dd619521aea 100644 --- a/paddle/phi/api/lib/kernel_dispatch.h +++ b/paddle/phi/api/lib/kernel_dispatch.h @@ -100,7 +100,9 @@ struct KernelKeyParser : ArgsIterator { key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(tensor); // TODO(chenweihang): select multi layout and dtype - key_set.layout = tensor.layout(); + phi::DataLayout tensor_layout = tensor.layout(); + key_set.layout = + tensor_layout > key_set.layout ? tensor_layout : key_set.layout; key_set.dtype = tensor.dtype(); dtype_set = dtype_set | DataTypeSet(key_set.dtype); auto promote_result = PromoteTypes(dtype_set);