未验证 提交 4140d7ec 编写于 作者: H HongyuJia 提交者: GitHub

[Fix KernelKeyParser] Unify the logic of `operator()` in `KernelKeyParser` (#46560)

* add datatype check for ParseKernelKeyByInputArgs

* polish error message

* Actually, einsum has vector<Tensor> inpute with DataType::COMPLEX64, see test_einsum_v2.py

* headerfile remove enforce.h
上级 3e0a1765
......@@ -117,12 +117,10 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
}
void operator()(const std::vector<Tensor>& x) {
const phi::TensorBase& tensor = *x.at(0).impl();
key_set.backend_set =
key_set.backend_set | detail::GetTensorBackendSet(tensor);
// TODO(chenweihang): select multi layout and dtype
key_set.layout = tensor.layout();
key_set.dtype = tensor.dtype();
if (!x.empty()) {
const phi::TensorBase& tensor = *x.at(0).impl();
AssignKernelKeySet(tensor);
}
}
void operator()(const paddle::optional<Tensor>& x) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册