From b9e4aaa5237ad00825aa740b983faf3ef2e30378 Mon Sep 17 00:00:00 2001 From: XiangGao Date: Fri, 2 Jul 2021 15:26:30 +0800 Subject: [PATCH] fix trace offset out of shape (#33922) --- paddle/fluid/operators/trace_op.cu | 4 ++++ paddle/fluid/operators/trace_op.h | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index 336c1c4083..f3fe32e10a 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -14,6 +14,7 @@ #include #include +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/trace_op.h" @@ -50,6 +51,9 @@ class TraceCUDAKernel : public framework::OpKernel { TensorReduce( diag, out, reduce_dims, static_cast(0), cub::Sum(), IdentityFunctor(), stream); + } else { + math::SetConstant functor; + functor(context.device_context(), out, static_cast(0)); } } }; diff --git a/paddle/fluid/operators/trace_op.h b/paddle/fluid/operators/trace_op.h index b7a6e559ed..ca9439cbed 100644 --- a/paddle/fluid/operators/trace_op.h +++ b/paddle/fluid/operators/trace_op.h @@ -179,7 +179,7 @@ class TraceKernel : public framework::OpKernel { auto output_dims = out->dims(); - out->mutable_data(context.GetPlace()); + T* out_data = out->mutable_data(context.GetPlace()); const framework::Tensor diag = Diagonal(context, input, offset, dim1, dim2); @@ -191,6 +191,8 @@ class TraceKernel : public framework::OpKernel { auto reduce_dim = Eigen::array({1}); output.device(place) = x.sum(reduce_dim); out->Resize(output_dims); + } else { + std::fill(out_data, out_data + out->numel(), static_cast(0)); } } }; -- GitLab