未验证 提交 e0fd3bbf 编写于 作者: C chentianyu03 提交者: GitHub

[pten] fix when out_dtype is same with x.dtype and still transform type error (#38285)

* fix when out_dtype is same with x.dtype and still transform type error

* fix spell error
上级 643a268e
......@@ -61,7 +61,7 @@ void Reduce(const CUDAContext& dev_ctx,
gpuStream_t stream = dev_ctx.stream();
if (out_dtype != pten::DataType::UNDEFINED) {
if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
out_dtype, "TensorReduceFunctorImpl", ([&] {
pten::detail::TensorReduceFunctorImpl<T, data_t, ReduceFunctor>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册