From ab0655630bb2db251dc28260fb335416a85406f5 Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Tue, 2 Jun 2020 11:02:17 +0800 Subject: [PATCH] [CHERRY-PICK 1.8]fix bug that diag API can't use on Windows(#24825) * cherry-pick #24762 --- paddle/fluid/operators/trace_op.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/trace_op.h b/paddle/fluid/operators/trace_op.h index 726efb82dd8..51d807bfb3d 100644 --- a/paddle/fluid/operators/trace_op.h +++ b/paddle/fluid/operators/trace_op.h @@ -24,10 +24,10 @@ namespace paddle { namespace operators { template -struct DiagFunctor { - DiagFunctor(const T* input, const int64_t* diag_stride, - const int64_t* ret_strides, int64_t pos, int64_t dim_size, - T* diag) +struct DiagonalFunctor { + DiagonalFunctor(const T* input, const int64_t* diag_stride, + const int64_t* ret_strides, int64_t pos, int64_t dim_size, + T* diag) : input_(input), diag_stride_(diag_stride), ret_strides_(ret_strides), @@ -157,8 +157,8 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, auto& dev_ctx = context.template device_context(); platform::ForRange for_range(dev_ctx, diag.numel()); - DiagFunctor functor(input_data, diag_arr, ret_arr, pos, dim_size, - diag_data); + DiagonalFunctor functor(input_data, diag_arr, ret_arr, pos, dim_size, + diag_data); for_range(functor); return diag; } else { -- GitLab