未验证 提交 d5268a6e 编写于 作者: G Guoxia Wang 提交者: GitHub

fix bug of reduce_sum when src_dtype != dst_dtype and reduce_num == 1 (#36123)

上级 ad128144
......@@ -34,6 +34,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"
......@@ -705,8 +706,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
if (config.reduce_num == 1) {
auto out_dims = y->dims();
framework::TensorCopy(x, y->place(), y);
y->Resize(out_dims);
if (x.type() == y->type()) {
framework::TensorCopy(x, y->place(), y);
y->Resize(out_dims);
} else {
auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(x.place()));
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(y->type()),
CastOpFunctor<platform::CUDADeviceContext, Tx>(&x, y, *dev_ctx));
}
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册