/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"

// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
#include "paddle/pten/kernels/functions/general/reduce_impl.h"

#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif

namespace paddle {
namespace operators {

40 41
#define HANDLE_DIM(NDIM, RDIM)                                            \
  if (ndim == NDIM && rdim == RDIM) {                                     \
42
    ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, Functor>(              \
43 44
        context.template device_context<DeviceContext>(), *input, output, \
        dims, keep_dim);                                                  \
W
whs 已提交
45 46
  }

47
using Tensor = framework::Tensor;
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
using DDim = framework::DDim;

inline void GetShuffledDim(const DDim& src_dims, DDim* dst_dims,
                           const std::vector<int>& reduced_dims,
                           std::vector<int>* perm_axis) {
  // check if it's a reduced dim
  std::vector<bool> src_dims_check(src_dims.size(), false);
  size_t src_size = src_dims.size();
  size_t reduce_size = reduced_dims.size();
  for (size_t i = 0; i < reduce_size; ++i) {
    dst_dims->at(src_size - reduce_size + i) = src_dims[reduced_dims[i]];
    (*perm_axis)[src_size - reduce_size + i] = reduced_dims[i];
    src_dims_check[reduced_dims[i]] = true;
  }

  size_t offset = 0;
  for (size_t i = 0; i < src_dims_check.size(); ++i) {
    bool is_reduced = src_dims_check[i];
    if (!is_reduced) {
      (*perm_axis)[offset] = i;
      dst_dims->at(offset++) = src_dims[i];
    }
  }
}

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
static inline std::vector<int> GetReduceDim(const std::vector<int>& dims,
                                            int dim_size, bool reduce_all) {
  std::vector<int> reduce_dims;
  if (reduce_all) {
    reduce_dims.resize(dim_size);
    int reduce_size = reduce_dims.size();
    for (int i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = i;
    }
  } else {
    for (auto e : dims) {
      PADDLE_ENFORCE_LT(e, dim_size,
                        paddle::platform::errors::InvalidArgument(
                            "ReduceOp: invalid axis, when x_dims is %d, "
                            "axis[i] should less than x_dims, but got %d.",
                            dim_size, e));
      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
    }
  }
  return reduce_dims;
}
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
// Writes into `shuffled_input` a transposed copy of `input` in which the axes
// listed in `dims` are moved to the end (axis order from GetShuffledDim).
// `shuffled_input` is resized and allocated with element type OutT on the
// context's place; the transpose itself is also instantiated for OutT.
// NOTE(review): this presumably requires `input` to hold OutT elements --
// confirm against callers.
template <typename DeviceContext, typename OutT>
void GetShuffledInput(const framework::ExecutionContext& context,
                      const Tensor* input, Tensor* shuffled_input,
                      const std::vector<int>& dims) {
  const auto& dev_ctx = context.template device_context<DeviceContext>();
  const DDim& src_dims = input->dims();

  // Work out the permuted shape and the axis order that produces it.
  DDim permuted_dims(src_dims);
  std::vector<int> axis_order(src_dims.size());
  GetShuffledDim(src_dims, &permuted_dims, dims, &axis_order);

  shuffled_input->Resize(permuted_dims);
  shuffled_input->mutable_data<OutT>(context.GetPlace());

  math::TransposeNormal<DeviceContext, OutT> transpose;
  transpose(dev_ctx, *input, shuffled_input, axis_order);
}

inline void GetOriginDimFromShuffled(const DDim& src_dim,
                                     const std::vector<int>& dims,
                                     std::vector<int>* origin_dim) {
  DDim shuffled_dims(src_dim);
  size_t n = src_dim.size();
  std::vector<int> perm_axis(n);
  GetShuffledDim(src_dim, &shuffled_dims, dims, &perm_axis);
  for (size_t i = 0; i < n; ++i) {
    (*origin_dim)[perm_axis[i]] = i;
  }
}

// Reduce path that works for any input rank by collapsing the problem to 2-D.
//
// The reduced axes are first transposed to the end, then the transposed
// tensor is viewed as {unreduced, reduced} and reduced along axis 1 with the
// rank-2 ReduceFunctor. `output` is temporarily viewed as a 1-D tensor of
// length `unreduced` and gets its original dims back before returning.
// The number of unreduced elements is taken from output->numel(), so `output`
// must already carry its final shape.
template <typename DeviceContext, typename OutT, typename Functor>
void HandleLargeDim(const framework::ExecutionContext& context,
                    const Tensor* input, Tensor* output,
                    const std::vector<int>& dims, bool keep_dim) {
  // Move the reduced axes to the end.
  Tensor transposed;
  GetShuffledInput<DeviceContext, OutT>(context, input, &transposed, dims);

  // View the transposed data as a {kept, reduced} matrix.
  const int64_t kept_numel = output->numel();
  const int64_t reduced_numel = transposed.numel() / kept_numel;
  transposed.Resize({kept_numel, reduced_numel});

  const DDim saved_output_dims = output->dims();
  output->Resize({kept_numel});
  ReduceFunctor<DeviceContext, OutT, 2, 1, Functor>(
      context.template device_context<DeviceContext>(), transposed, output,
      {1}, keep_dim);
  output->Resize(saved_output_dims);
}

// Gradient counterpart of HandleLargeDim; works for any input rank.
//
// x:    forward input.
// out:  forward output; its numel() gives the number of unreduced elements.
// dout: gradient w.r.t. `out`.
// dx:   gradient w.r.t. `x` (output of this function); ends up with x's dims.
// dims: reduced axes.
//
// The gradient is first computed in the shuffled {unreduced, reduced} layout
// and then transposed back to the original axis order of `x`. The statement
// order below matters: `dx` is reused as the 2-D scratch result before being
// copied to `dx_tmp` and overwritten by the final transpose.
template <typename DeviceContext, typename T, typename Functor>
void HandleLargeDimGrad(const framework::ExecutionContext& context,
                        const framework::Tensor* x,
                        const framework::Tensor* out,
                        const framework::Tensor* dout, framework::Tensor* dx,
                        const std::vector<int>& dims) {
  const int64_t unreduced = out->numel();
  const int64_t reduced = x->numel() / unreduced;
  DDim x_dim(x->dims());
  // transpose and reshape X
  Tensor shuffled_x;
  GetShuffledInput<DeviceContext, T>(context, x, &shuffled_x, dims);
  // Remember the shuffled N-D shape before flattening to 2-D; it is the shape
  // of the gradient prior to the inverse transpose.
  DDim shuffled_dim = shuffled_x.dims();
  shuffled_x.Resize({unreduced, reduced});
  // reshape dX {unreduced, reduced}
  dx->Resize({unreduced, reduced});
  ReduceGradFunctor<DeviceContext, T, 2, Functor>(
      context.template device_context<DeviceContext>(), shuffled_x, *out, *dout,
      dx, {1});
  // transpose dX back to the axis order of X
  std::vector<int> origin_axis(x_dim.size());
  GetOriginDimFromShuffled(x_dim, dims, &origin_axis);
  Tensor dx_tmp;
  framework::TensorCopy(*dx, context.GetPlace(), &dx_tmp);
  dx_tmp.Resize(shuffled_dim);
  dx->Resize(x_dim);
  math::TransposeNormal<DeviceContext, T> trans;
  trans(context.template device_context<DeviceContext>(), dx_tmp, dx,
        origin_axis);
}

template <typename DeviceContext, typename T, typename Functor>
struct ReduceKernelFunctor {
  const Tensor* input;
  Tensor* output;
  std::vector<int> dims;
  bool keep_dim;
  bool reduce_all;
  const framework::ExecutionContext& context;
  ReduceKernelFunctor(const Tensor* input, Tensor* output,
                      const std::vector<int>& dims, bool keep_dim,
                      bool reduce_all,
                      const framework::ExecutionContext& context)
      : input(input),
        output(output),
        dims(dims),
        keep_dim(keep_dim),
        reduce_all(reduce_all),
        context(context) {}

  template <typename OutT>
  void apply() const {
    output->mutable_data<OutT>(context.GetPlace());
    if (reduce_all) {
      // Flatten and reduce 1-D tensor
      auto x = EigenVector<OutT>::Flatten(*input);
      auto out = EigenScalar<OutT>::From(*output);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      auto reduce_dim = Eigen::array<int, 1>({{0}});
      Functor functor;
      functor(place, &x, &out, reduce_dim);
    } else {
      int ndim = input->dims().size();
      int rdim = dims.size();
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
      if (ndim > 6) {
        HandleLargeDim<DeviceContext, OutT, Functor>(context, input, output,
                                                     dims, keep_dim);
      } else {
        HANDLE_DIM(6, 5);
        HANDLE_DIM(6, 4);
        HANDLE_DIM(6, 3);
        HANDLE_DIM(6, 2);
        HANDLE_DIM(6, 1);
        HANDLE_DIM(5, 4);
        HANDLE_DIM(5, 3);
        HANDLE_DIM(5, 2);
        HANDLE_DIM(5, 1);
        HANDLE_DIM(4, 3);
        HANDLE_DIM(4, 2);
        HANDLE_DIM(4, 1);
        HANDLE_DIM(3, 2);
        HANDLE_DIM(3, 1);
        HANDLE_DIM(2, 1);
        HANDLE_DIM(1, 1);
      }
229 230 231
    }
  }
};
Q
QI JUN 已提交
232
template <typename DeviceContext, typename T, typename Functor>
Y
Yu Yang 已提交
233
class ReduceKernel : public framework::OpKernel<T> {
234 235 236 237 238 239 240 241
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    bool reduce_all = context.Attr<bool>("reduce_all");
    auto* output = context.Output<Tensor>("Out");
    auto dims = context.Attr<std::vector<int>>("dim");
    bool keep_dim = context.Attr<bool>("keep_dim");
    int out_dtype = context.Attr<int>("out_dtype");
    framework::proto::VarType::Type cast_out_dtype;
242
    auto* input = context.Input<Tensor>("X");
243

244 245
    if (out_dtype < 0) {
      cast_out_dtype =
246
          static_cast<framework::proto::VarType::Type>(input->type());
247 248 249
    } else {
      cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
    }
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264

    auto& dev_ctx = context.device_context<DeviceContext>();
    output->mutable_data(
        dev_ctx.GetPlace(),
        static_cast<framework::proto::VarType::Type>(cast_out_dtype));

    auto pt_x = paddle::experimental::MakePtenDenseTensor(*input);
    auto pt_out = paddle::experimental::MakePtenDenseTensor(*output);

    std::vector<int64_t> tmp_dims(dims.begin(), dims.end());

    // call new kernel
    pten::general::Reduce<DeviceContext, T, Functor>(
        dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim,
        pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
265 266 267 268
  }
};
// Reduce kernel whose input and output are both accessed as OutT (the
// Eigen-based path, without the pten forwarding used by ReduceKernel).
//
// Reads the attributes "reduce_all", "dim" and "keep_dim". When `dim` covers
// every axis of "X" the reduction is treated as reduce_all; axes are
// normalized first so that negative axes (axis + rank) are recognised too.
// This matters because the HANDLE_DIM table below has no entry for reducing
// all axes, so a missed full reduction would compute nothing.
template <typename DeviceContext, typename OutT, typename Functor>
class BoolReduceKernel : public framework::OpKernel<OutT> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    bool reduce_all = context.Attr<bool>("reduce_all");
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");
    output->mutable_data<OutT>(context.GetPlace());

    auto dims = context.Attr<std::vector<int>>("dim");
    bool keep_dim = context.Attr<bool>("keep_dim");

    // If `dims` names every axis, this is a full reduction.
    const int input_dim_size = input->dims().size();
    std::set<int> dims_set;
    for (const int axis : dims) {
      dims_set.insert(axis < 0 ? axis + input_dim_size : axis);
    }
    bool full_dim = true;
    for (int i = 0; i < input_dim_size; ++i) {
      if (dims_set.find(i) == dims_set.end()) {
        full_dim = false;
        break;
      }
    }
    reduce_all = (reduce_all || full_dim);

    if (reduce_all) {
      // Flatten and reduce 1-D tensor
      auto x = EigenVector<OutT>::Flatten(*input);
      auto out = EigenScalar<OutT>::From(*output);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      auto reduce_dim = Eigen::array<int, 1>({{0}});
      Functor functor;
      functor(place, &x, &out, reduce_dim);
    } else {
      // `ndim` and `rdim` are read by the HANDLE_DIM macro.
      int ndim = input->dims().size();
      int rdim = dims.size();
      if (ndim > 6) {
        HandleLargeDim<DeviceContext, OutT, Functor>(context, input, output,
                                                     dims, keep_dim);
      } else {
        HANDLE_DIM(6, 5);
        HANDLE_DIM(6, 4);
        HANDLE_DIM(6, 3);
        HANDLE_DIM(6, 2);
        HANDLE_DIM(6, 1);
        HANDLE_DIM(5, 4);
        HANDLE_DIM(5, 3);
        HANDLE_DIM(5, 2);
        HANDLE_DIM(5, 1);
        HANDLE_DIM(4, 3);
        HANDLE_DIM(4, 2);
        HANDLE_DIM(4, 1);
        HANDLE_DIM(3, 2);
        HANDLE_DIM(3, 1);
        HANDLE_DIM(2, 1);
        HANDLE_DIM(1, 1);
      }
    }
  }
};

template <typename DeviceContext, typename T, typename Functor,
          bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
Y
Yu Yang 已提交
331
class ReduceGradKernel : public framework::OpKernel<T> {
G
guosheng 已提交
332
 public:
333 334
  void ComputeFromInput(const Tensor* input2,
                        const framework::ExecutionContext& context) const {
335
    bool reduce_all = context.Attr<bool>("reduce_all");
336 337 338
    auto dims = context.Attr<std::vector<int>>("dim");
    auto* input0 = context.Input<Tensor>("X");
    auto* input1 = context.Input<Tensor>("Out");
339

340 341 342
    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
    output->mutable_data<T>(context.GetPlace());

343 344 345 346 347 348 349 350 351 352 353
    // The dims has full dim, set the reduce_all is True
    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
    std::set<int> dims_set(dims.begin(), dims.end());
    bool full_dim = true;
    for (auto i = 0; i < input_dim_size; i++) {
      if (dims_set.find(i) == dims_set.end()) {
        full_dim = false;
        break;
      }
    }
    reduce_all = (reduce_all || full_dim);
354 355 356 357 358 359 360 361 362 363 364
    // NOTE: EigenTensor::From() uses tensor->data()
    // if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
    // kNoNeedBufferY should set true
    // and use fake var that has same dims.
    if (kNoNeedBufferX) {
      input0 = output;
    }
    if (kNoNeedBufferY) {
      input1 = input2;
    }

L
lvmengsi 已提交
365 366 367 368
    // NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
    // not be set as Input in grad Maker, use Out_grad to replace here
    if (!input1) input1 = input2;

369 370
    if (reduce_all) {
      auto x = EigenVector<T>::Flatten(*input0);
371 372
      auto x_reduce = EigenVector<T>::Flatten(*input1);
      auto x_reduce_grad = EigenVector<T>::Flatten(*input2);
373 374 375 376 377 378
      auto x_grad = EigenVector<T>::Flatten(*output);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      auto broadcast_dim =
          Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
      Functor functor;
379
      functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
380 381
              broadcast_dim[0]);
    } else {
382
      int rank = input0->dims().size();
383 384
      switch (rank) {
        case 1:
385 386 387
          ReduceGradFunctor<DeviceContext, T, 1, Functor>(
              context.template device_context<DeviceContext>(), *input0,
              *input1, *input2, output, dims);
388 389
          break;
        case 2:
390 391 392
          ReduceGradFunctor<DeviceContext, T, 2, Functor>(
              context.template device_context<DeviceContext>(), *input0,
              *input1, *input2, output, dims);
393 394
          break;
        case 3:
395 396 397
          ReduceGradFunctor<DeviceContext, T, 3, Functor>(
              context.template device_context<DeviceContext>(), *input0,
              *input1, *input2, output, dims);
398 399
          break;
        case 4:
400 401 402
          ReduceGradFunctor<DeviceContext, T, 4, Functor>(
              context.template device_context<DeviceContext>(), *input0,
              *input1, *input2, output, dims);
403 404
          break;
        case 5:
405 406 407
          ReduceGradFunctor<DeviceContext, T, 5, Functor>(
              context.template device_context<DeviceContext>(), *input0,
              *input1, *input2, output, dims);
408 409
          break;
        case 6:
410 411 412
          ReduceGradFunctor<DeviceContext, T, 6, Functor>(
              context.template device_context<DeviceContext>(), *input0,
              *input1, *input2, output, dims);
413
          break;
414 415 416 417
        default:
          HandleLargeDimGrad<DeviceContext, T, Functor>(context, input0, input1,
                                                        input2, output, dims);
          break;
418
      }
G
guosheng 已提交
419 420
    }
  }
421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440

  void Compute(const framework::ExecutionContext& context) const override {
    int in_dtype = context.Attr<int>("in_dtype");
    if (in_dtype >= 0) {
      Tensor tmp_tensor;
      auto* pre_input = context.Input<Tensor>(framework::GradVarName("Out"));
      auto in_kernel_type =
          framework::OpKernelType(pre_input->type(), context.GetPlace());
      auto out_kernel_type = framework::OpKernelType(
          static_cast<framework::proto::VarType::Type>(in_dtype),
          context.GetPlace());
      framework::TransDataType(in_kernel_type, out_kernel_type, *pre_input,
                               &tmp_tensor);
      ComputeFromInput(&tmp_tensor, context);

    } else {
      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
      ComputeFromInput(input2, context);
    }
  }
441
};
G
guosheng 已提交
442

443 444 445
class ReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
G
guosheng 已提交
446

447
  void InferShape(framework::InferShapeContext* ctx) const override {
448 449
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp");
450 451 452
    auto x_dims = ctx->GetInputDim("X");
    auto x_rank = x_dims.size();
    auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
453 454 455 456 457 458
    PADDLE_ENFORCE_GT(dims.size(), 0,
                      platform::errors::InvalidArgument(
                          "The input dim dimensions of ReduceOp "
                          "should be greater than 0. But received the dim "
                          "dimesions of Reduce = %d.",
                          dims.size()));
459

460
    for (size_t i = 0; i < dims.size(); ++i) {
461
      PADDLE_ENFORCE_LT(dims[i], x_rank,
462 463 464 465 466
                        platform::errors::InvalidArgument(
                            "The reduce dim index %d should be in the "
                            "range [-dimension(X), dimension(X)] "
                            "which dimesion = %d. But received dim index = %d.",
                            i, x_rank, dims[i]));
467 468 469 470 471 472
      PADDLE_ENFORCE_GE(dims[i], -x_rank,
                        platform::errors::InvalidArgument(
                            "The reduce dim index %d should be in the "
                            "range [-dimension(X), dimension(X)] "
                            "which dimesion = %d. But received dim index = %d.",
                            i, x_rank, dims[i]));
473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498
      if (dims[i] < 0) dims[i] = x_rank + dims[i];
    }
    sort(dims.begin(), dims.end());
    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
    bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
    if (reduce_all) {
      if (keep_dim)
        ctx->SetOutputDim(
            "Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
      else
        ctx->SetOutputDim("Out", {1});
    } else {
      auto dims_vector = vectorize(x_dims);
      if (keep_dim) {
        for (size_t i = 0; i < dims.size(); ++i) {
          dims_vector[dims[i]] = 1;
        }
      } else {
        const int kDelFlag = -2;
        for (size_t i = 0; i < dims.size(); ++i) {
          dims_vector[dims[i]] = kDelFlag;
        }
        dims_vector.erase(
            remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
            dims_vector.end());
      }
499 500 501
      if (!keep_dim && dims_vector.size() == 0) {
        dims_vector.push_back(1);
      }
502 503
      auto out_dims = framework::make_ddim(dims_vector);
      ctx->SetOutputDim("Out", out_dims);
504
      if (dims.size() > 0 && dims[0] != 0) {
505 506 507 508 509
        // Only pass LoD when not reducing on the first dim.
        ctx->ShareLoD("X", /*->*/ "Out");
      }
    }
  }
510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // choose cudnn kernel if the runtime supported.
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

    if (ctx.Input<paddle::framework::LoDTensor>("X")->dims().size() > 5)
      return framework::OpKernelType(input_data_type, ctx.GetPlace());

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif

    if (input_data_type == framework::proto::VarType::FP16) {
F
furnace 已提交
528 529 530
      PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) ||
                            platform::is_npu_place(ctx.GetPlace()),
                        true,
531
                        platform::errors::InvalidArgument(
F
furnace 已提交
532
                            "float16 can only be used on GPU or NPU place"));
533 534 535
    }
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
536 537
};

G
Guo Sheng 已提交
538 539 540 541 542 543 544 545 546 547 548 549 550
// ReduceOp variant whose kernel place is taken from input "X" instead of the
// execution context. The kernel type itself comes from the base
// OperatorWithKernel implementation, not from ReduceOp's override.
class ReduceOpUseInputPlace : public ReduceOp {
 public:
  using ReduceOp::ReduceOp;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto kernel_type = OperatorWithKernel::GetExpectedKernelType(ctx);
    // Run wherever the input tensor already lives.
    kernel_type.place_ = ctx.Input<framework::LoDTensor>("X")->place();
    return kernel_type;
  }
};

551 552 553
class ReduceGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
W
whs 已提交
554

555
  void InferShape(framework::InferShapeContext* ctx) const override {
556 557 558
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "ReduceOp");
559 560 561
    auto x_dims = ctx->GetInputDim("X");
    auto x_rank = x_dims.size();
    auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
W
whs 已提交
562
    for (size_t i = 0; i < dims.size(); ++i) {
563
      PADDLE_ENFORCE_LT(dims[i], x_rank,
564 565 566 567 568
                        platform::errors::InvalidArgument(
                            "The reduce dim index %d should be in the "
                            "range [-dimension(X), dimension(X)], "
                            "which dimesion = %d. But received dim index = %d.",
                            i, x_rank, dims[i]));
W
whs 已提交
569
      if (dims[i] < 0) dims[i] = x_rank + dims[i];
570 571 572 573 574 575
    }
    sort(dims.begin(), dims.end());
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
      ctx->ShareLoD("X", /*->*/ x_grad_name);
W
whs 已提交
576
    }
577
  }
578 579 580 581

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
J
jakpiase 已提交
582 583 584 585 586
    int in_dtype = ctx.Attr<int>("in_dtype");
    auto input_data_type =
        (in_dtype >= 0) ? static_cast<framework::proto::VarType::Type>(in_dtype)
                        : OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("Out"));
587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603
#ifdef PADDLE_WITH_MKLDNN
    auto CanMKLDNNReduceGradBeUsed = [&]() {
      auto dx_dims = ctx.Input<Tensor>("X")->dims();

      if (dx_dims.size() > 5) return false;  // max 5D tensor is supported

      return true;
    };
    if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
        CanMKLDNNReduceGradBeUsed()) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif

    return framework::OpKernelType(input_data_type, ctx.GetPlace());
604
  }
605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628
};

class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final {
    AddInput("X",
             "(Tensor) The input tensor. Tensors with rank at most 6 are "
             "supported.");
    AddOutput("Out", "(Tensor) The result tensor.");
    AddAttr<std::vector<int>>(
        "dim",
        "(list<int>, default {0}) The dimensions to reduce. "
        "Must be in the range [-rank(input), rank(input)). "
        "If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. "
        "Note that reducing on the first dim will make the LoD info lost.")
        .SetDefault({0});
    AddAttr<bool>("keep_dim",
                  "(bool, default false) "
                  "If true, retain the reduced dimension with length 1.")
        .SetDefault(false);
    AddAttr<bool>("reduce_all",
                  "(bool, default false) "
                  "If true, output a scalar reduced along all dimensions.")
        .SetDefault(false);
629 630 631 632 633 634 635 636 637 638
    AddAttr<int>("in_dtype",
                 "(int, default -1)"
                 "The dtype of input, default value is -1, the user could not "
                 "set this value.")
        .SetDefault(-1);
    AddAttr<int>(
        "out_dtype",
        "(int, default -1)"
        "The dtype of output, default value is -1, the dtype is same as intput")
        .SetDefault(-1);
639 640
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
641 642
        .SetDefault(false)
        .AsExtra();
643 644
    AddComment(string::Sprintf(R"DOC(
%s Operator.
W
whs 已提交
645

646 647 648
This operator computes the %s of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar.
W
whs 已提交
649

650 651
)DOC",
                               GetOpType(), GetName()));
G
guosheng 已提交
652
  }
653 654 655 656

 protected:
  virtual std::string GetName() const = 0;
  virtual std::string GetOpType() const = 0;
G
guosheng 已提交
657 658
};

659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685
#if defined(__HIPCC__) || defined(__NVCC__)
// GPU reduce kernel, compiled only under HIPCC/NVCC.
//
// Normalizes the "dim" attribute with GetReduceDim and runs the reduction on
// the CUDA device context's stream. When "out_dtype" is non-negative the
// output element type is chosen at runtime through VisitDataTypeSmall;
// otherwise the output uses the input type T.
template <typename T, template <typename, typename> class ReduceOp>
class ReduceCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("X");
    Tensor* output = context.Output<Tensor>("Out");
    const bool reduce_all = context.Attr<bool>("reduce_all");
    const int out_dtype = context.Attr<int>("out_dtype");
    const std::vector<int> axes_attr = context.Attr<std::vector<int>>("dim");

    std::vector<int> reduce_dims =
        GetReduceDim(axes_attr, input->dims().size(), reduce_all);
    gpuStream_t stream = context.cuda_device_context().stream();

    if (out_dtype < 0) {
      // No explicit output type: reduce T -> T.
      TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
                                              stream);
      return;
    }
    framework::VisitDataTypeSmall(
        static_cast<framework::proto::VarType::Type>(out_dtype),
        TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
  }
};
#endif

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers a reduce operator together with its gradient operator.
//
// Expands to:
//   1. a maker class `__<op_name>Maker__` deriving from ops::ReduceOpMaker,
//      whose GetName() returns "<op_name>" and whose GetOpType() returns
//      "Reduce <op_name>";
//   2. the registration of `op_name` with ops::ReduceOp, that maker and the
//      default grad-op makers for OpDesc and imperative OpBase;
//   3. the registration of `<op_name>_grad` with ops::ReduceGradOp.
// The expansion does not end with a semicolon; the caller supplies it.
#define REGISTER_REDUCE_OP(op_name)                                           \
  class __##op_name##Maker__ : public ops::ReduceOpMaker {                    \
   protected:                                                                 \
    virtual std::string GetName() const { return #op_name; }                  \
    virtual std::string GetOpType() const { return "Reduce " #op_name; }      \
  };                                                                          \
  REGISTER_OPERATOR(                                                          \
      op_name, ops::ReduceOp, __##op_name##Maker__,                           \
      paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>, \
      paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase,       \
                                            true>);                           \
  REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp)

// Registers a reduce operator that has no gradient operator.
//
// Expands to a maker class `__<op_name>Maker__` (same as REGISTER_REDUCE_OP)
// and the registration of `op_name` with empty grad-op makers.
// The variadic argument is token-pasted onto `ops::ReduceOp` to select the
// operator class: with no extra argument it is ops::ReduceOp, with
// `UseInputPlace` it is ops::ReduceOpUseInputPlace.
// Unlike REGISTER_REDUCE_OP, the expansion already ends with a semicolon.
#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name, ...)                    \
  class __##op_name##Maker__ : public ops::ReduceOpMaker {               \
   protected:                                                            \
    virtual std::string GetName() const { return #op_name; }             \
    virtual std::string GetOpType() const { return "Reduce " #op_name; } \
  };                                                                     \
  REGISTER_OPERATOR(                                                     \
      op_name, ops::ReduceOp##__VA_ARGS__, __##op_name##Maker__,         \
      paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,    \
      paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);