// norm_grad_kernel.cc (3.1 KB), committed by hong
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pten/kernels/norm_grad_kernel.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"

#include "paddle/pten/kernels/funcs/eigen/common.h"

#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"

#include "paddle/pten/kernels/funcs/common_shape.h"
namespace pten {

template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
                    const DenseTensor& out_grad,
                    const DenseTensor& x,
                    const DenseTensor& norm,
                    int axis,
                    float epsilon,
                    bool is_test,
                    DenseTensor* x_grad) {
  auto* in_x = &x;
  auto* in_dy = &out_grad;
  auto* in_norm = &norm;
  auto* out_dx = x_grad;

  ctx.template Alloc<T>(out_dx);

  auto xdim = in_x->dims();
  if (axis < 0) axis = xdim.size() + axis;
  int pre, n, post;
  funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);

  auto* place = ctx.eigen_device();

  auto x_e = paddle::framework::EigenVector<T>::Flatten(*in_x);
  auto dy_e = paddle::framework::EigenVector<T>::Flatten(*in_dy);
  auto norm_e = paddle::framework::EigenVector<T>::Flatten(*in_norm);
  auto dx_e = paddle::framework::EigenVector<T>::Flatten(*out_dx);

  Eigen::DSizes<int, 3> shape(pre, n, post);
  Eigen::DSizes<int, 3> rshape(pre, 1, post);
  auto x_r = x_e.reshape(shape);
  auto dy = dy_e.reshape(shape);
  auto norm_r = norm_e.reshape(rshape);
  auto dx = dx_e.reshape(shape);

  DenseTensor rsum;
  rsum.Resize({pre, post});
  ctx.template Alloc<T>(&rsum);
  auto sum = paddle::framework::EigenTensor<T, 2>::From(rsum);

  Eigen::DSizes<int, 1> rdim(1);
  Eigen::DSizes<int, 3> bcast(1, n, 1);

  // dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
  //    = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
  //    = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
  // 1. sum = sum(x*dy)
  sum.device(*place) = (x_r * dy).sum(rdim);
  // 2. dx = x * sum
  dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x_r;
  // 3. dx / (sum(x*x) + e)
  // where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
  dx.device(*place) = dx / norm_r.pow(2).broadcast(bcast);
  // 4. [dy - dx] / sqrt(sum(x*x))
  dx.device(*place) = (dy - dx) / norm_r.broadcast(bcast);
}

}  // namespace pten

// Registers pten::NormGradKernel under the kernel name `norm_grad` with the
// CPU and ALL_LAYOUT keys, listing float and double as the element types.
// The trailing `{}` block following the macro invocation is empty here.
// NOTE(review): the exact meaning of each macro argument is defined in
// paddle/pten/core/kernel_registry.h -- confirm there.
PT_REGISTER_KERNEL(
    norm_grad, CPU, ALL_LAYOUT, pten::NormGradKernel, float, double) {}