broadcast_tensors_grad_kernel.cu 3.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/phi/kernels/broadcast_tensors_grad_kernel.h"

17
#include <vector>
18

19 20 21 22 23
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
24
#include "paddle/phi/kernels/funcs/reduce_function.h"
25 26 27 28 29 30
#include "paddle/phi/kernels/primitive/functor_primitives.h"

namespace phi {

template <typename T, typename Context>
void BroadcastTensorsGradKernel(const Context& ctx,
31
                                const std::vector<const DenseTensor*>& dout,
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
                                std::vector<DenseTensor*> dx) {
  // Find reduce dimensions
  const auto& in_tensors = dout;
  auto& out_tensors = dx;

  size_t num_ins = in_tensors.size();

  PADDLE_ENFORCE_GT(
      num_ins,
      1,
      errors::InvalidArgument(
          "Expected at least 2 input tensors, but only received d%.",
          in_tensors.size()));

  PADDLE_ENFORCE_EQ(
      num_ins,
      out_tensors.size(),
      errors::InvalidArgument(
          "BroadcastTensorsOp expects equal number of inputs and outputs,"
          "but received: %d inputs v.s %d outputs",
          num_ins,
          out_tensors.size()));

  // For each In-Out tensor pair,
  // Prepare and apply broadcast dims array
  for (size_t i = 0; i < num_ins; i++) {
58
    auto* input_tensor = in_tensors[i];
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    auto* output_tensor = out_tensors[i];

    const DDim& input_dims = input_tensor->dims();
    const DDim& output_dims = output_tensor->dims();

    int in_rank = input_dims.size();
    int out_rank = output_dims.size();

    // Collect reduce_dims
    // Example:
    // dX  = [1,1,1,1]
    // dOut = [1,1,1,4]
    //
    // reduce_dims  = [3] // reduce along the broadcasted axis
    std::vector<int> reduce_dims_vec;
    for (int j = 0; j < in_rank; j++) {
      int out_axis = out_rank - j - 1;
      int in_axis = in_rank - j - 1;

      if (out_axis < 0 || output_dims[out_axis] != input_dims[in_axis]) {
        reduce_dims_vec.push_back(in_axis);
      }
    }

    bool just_copy = (reduce_dims_vec.size() == 0);
    ctx.template Alloc<T>(output_tensor);
    if (just_copy) {
      // Turns out to be a No-Op, simply copy tensors
      paddle::framework::TensorCopy(
          *input_tensor, ctx.GetPlace(), ctx, output_tensor);
    } else {
      // reduce_sum implementation on CUDA
91
      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
92 93 94 95
          ctx,
          *input_tensor,
          output_tensor,
          kps::IdentityFunctor<T>(),
96
          reduce_dims_vec);
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
    }
  }
}

}  // namespace phi

// Register phi::BroadcastTensorsGradKernel as the "broadcast_tensors_grad"
// kernel for the GPU backend with ALL_LAYOUT. The trailing type list names the
// element types T this registration covers: int, int64_t, float, double and
// float16.
PD_REGISTER_KERNEL(broadcast_tensors_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::BroadcastTensorsGradKernel,
                   int,
                   int64_t,
                   float,
                   double,
                   phi::dtype::float16) {}