"07.machine_translation/index.en.html" does not exist at "9c9c7437ab0fec0197b700fcbdea173731fbeb82"
dist_grad_kernel.cc 3.4 KB
Newer Older
L
Leo Chen committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/dist_grad_kernel.h"

#include <tuple>
#include <utility>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/p_norm_grad_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/scale_kernel.h"

namespace phi {

std::pair<std::vector<int64_t>, std::vector<int64_t>> GetReduceDims(
    const DDim& src_dim, const DDim& dst_dim) {
  std::vector<int64_t> reduce_dims, new_dims;
  auto pre_dims = src_dim.size() - dst_dim.size();
  for (auto i = 0; i < pre_dims; ++i) {
    reduce_dims.push_back(i);
  }

  for (auto i = pre_dims; i < src_dim.size(); ++i) {
    if (dst_dim[i - pre_dims] == 1 && src_dim[i] != 1) {
      reduce_dims.push_back(i);
    } else {
      new_dims.push_back(dst_dim[i - pre_dims]);
    }
  }
  return {reduce_dims, new_dims};
}

template <typename T, typename Context>
void DistGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    const DenseTensor& out,
                    const DenseTensor& out_grad,
                    float p,
                    DenseTensor* x_grad,
                    DenseTensor* y_grad) {
L
Leo Chen 已提交
55 56 57 58
  if ((!x_grad) && (!y_grad)) {
    return;
  }

L
Leo Chen 已提交
59 60 61 62 63 64 65
  auto t = Subtract<T, Context>(dev_ctx, x, y);
  DenseTensor x_grad_tmp;
  x_grad_tmp.Resize(t.dims());
  DenseTensor y_grad_tmp;
  y_grad_tmp.Resize(t.dims());
  PNormGradKernel<T, Context>(
      dev_ctx, t, out, out_grad, p, -1, 1e-12, false, true, &x_grad_tmp);
L
Leo Chen 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78

  if (x_grad) {
    // do reduce, the implemetation of cpu SumKernel has bug, it changes
    // the dims of output iternally, so we Resize x/y_grad twice.
    auto res_x = GetReduceDims(x_grad_tmp.dims(), x.dims());
    if (!std::get<0>(res_x).empty()) {
      x_grad->Resize(phi::make_ddim(std::get<1>(res_x)));
      SumKernel<T, Context>(
          dev_ctx, x_grad_tmp, std::get<0>(res_x), x.dtype(), false, x_grad);
      x_grad->Resize(x.dims());
    } else {
      x_grad->ShareBufferWith(x_grad_tmp);
    }
L
Leo Chen 已提交
79
  }
L
Leo Chen 已提交
80 81 82 83 84 85 86 87 88 89 90 91

  if (y_grad) {
    ScaleKernel<T, Context>(dev_ctx, x_grad_tmp, -1.0, 0.0, false, &y_grad_tmp);
    auto res_y = GetReduceDims(y_grad_tmp.dims(), y.dims());
    if (!std::get<0>(res_y).empty()) {
      y_grad->Resize(phi::make_ddim(std::get<1>(res_y)));
      SumKernel<T, Context>(
          dev_ctx, y_grad_tmp, std::get<0>(res_y), y.dtype(), false, y_grad);
      y_grad->Resize(y.dims());
    } else {
      y_grad->ShareBufferWith(y_grad_tmp);
    }
L
Leo Chen 已提交
92 93 94 95 96 97 98 99 100 101 102 103
  }
}

}  // namespace phi

// Register dist_grad for float and double on CPU, and additionally on GPU
// when Paddle is built with CUDA or HIP support.
PD_REGISTER_KERNEL(dist_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::DistGradKernel,
                   float,
                   double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(dist_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DistGradKernel,
                   float,
                   double) {}
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP