diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index 336c1c40832b97e504710c9721a94fbac1cd447d..f3fe32e10a52b6fcc8bbae9f8f1b9ab4a104d8b2 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -14,6 +14,7 @@ #include #include +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/trace_op.h" @@ -50,6 +51,9 @@ class TraceCUDAKernel : public framework::OpKernel { TensorReduce( diag, out, reduce_dims, static_cast(0), cub::Sum(), IdentityFunctor(), stream); + } else { + math::SetConstant functor; + functor(context.device_context(), out, static_cast(0)); } } }; diff --git a/paddle/fluid/operators/trace_op.h b/paddle/fluid/operators/trace_op.h index b7a6e559ed4ef6ee4cd43b9375b3531488db449d..ca9439cbed97ddb02e2e6eaa2fb89628e738576e 100644 --- a/paddle/fluid/operators/trace_op.h +++ b/paddle/fluid/operators/trace_op.h @@ -179,7 +179,7 @@ class TraceKernel : public framework::OpKernel { auto output_dims = out->dims(); - out->mutable_data(context.GetPlace()); + T* out_data = out->mutable_data(context.GetPlace()); const framework::Tensor diag = Diagonal(context, input, offset, dim1, dim2); @@ -191,6 +191,8 @@ class TraceKernel : public framework::OpKernel { auto reduce_dim = Eigen::array({1}); output.device(place) = x.sum(reduce_dim); out->Resize(output_dims); + } else { + std::fill(out_data, out_data + out->numel(), static_cast(0)); } } };