未验证 提交 4140d7ec 编写于 作者: H HongyuJia 提交者: GitHub

[Fix KernelKeyParser] Unify the logic of `operator()` in `KernelKeyParser` (#46560)

* add datatype check for ParseKernelKeyByInputArgs

* polish error message

* Actually, einsum has vector<Tensor> inpute with DataType::COMPLEX64, see test_einsum_v2.py

* headerfile remove enforce.h
上级 3e0a1765
...@@ -117,12 +117,10 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -117,12 +117,10 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
} }
void operator()(const std::vector<Tensor>& x) { void operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
const phi::TensorBase& tensor = *x.at(0).impl(); const phi::TensorBase& tensor = *x.at(0).impl();
key_set.backend_set = AssignKernelKeySet(tensor);
key_set.backend_set | detail::GetTensorBackendSet(tensor); }
// TODO(chenweihang): select multi layout and dtype
key_set.layout = tensor.layout();
key_set.dtype = tensor.dtype();
} }
void operator()(const paddle::optional<Tensor>& x) { void operator()(const paddle::optional<Tensor>& x) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册