/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <climits>
#include <cmath>

#include "paddle/fluid/framework/op_registry.h"
// Provides NpuOpRunner and FillNpuTensorWithConstant (header path assumed
// per the Paddle 2.x NPU layout).
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
namespace operators {

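// NPU forward kernel for p_norm: computes (sum_i |x_i|^porder)^(1/porder)
// reduced along `axis`; porder == 0 and porder == +/-INFINITY are handled
// as special cases by the LpNorm operator.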
template <typename DeviceContext, typename T>
class PnormNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<phi::DenseTensor>("X");
    auto* out_norm = ctx.Output<phi::DenseTensor>("Out");
    out_norm->mutable_data<T>(ctx.GetPlace());

    float porder = ctx.Attr<float>("porder");
    int axis = ctx.Attr<int>("axis");
    bool keepdim = ctx.Attr<bool>("keepdim");

    auto xdim = in_x->dims();
    if (axis < 0) axis = xdim.size() + axis;

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    int p = 0;
    // The NPU LpNorm operator only takes an integer p (with INT_MAX/INT_MIN
    // standing in for +/-inf), so a non-integral porder falls back to the
    // combined Power -> ReduceSumD -> Power path below.
    bool combine_op =
        !(porder == 0 || porder == INFINITY || porder == -INFINITY);
    if (porder == INFINITY) {
      p = INT_MAX;
    } else if (porder == -INFINITY) {
      p = INT_MIN;
    } else {
      p = static_cast<int>(porder);
      float t = 0;
      float diff = std::abs(std::modf(porder, &t));
      if (diff < 1e-5) {
        combine_op = false;
      }
    }

    if (!combine_op) {
      const auto& runner = NpuOpRunner("LpNorm",
                                       {*in_x},
                                       {*out_norm},
                                       {{"p", p},
                                        {"axes", std::vector<int32_t>({axis})},
                                        {"keep_dims", keepdim}});
      runner.Run(stream);
    } else {
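      // Non-integral porder: raise to porder, reduce-sum along axis, then
      // take the 1/porder power. Note x is raised directly (no Abs first),
      // so this path assumes non-negative input for fractional p.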
      phi::DenseTensor tmp_x;
      tmp_x.mutable_data<T>(xdim, ctx.GetPlace());

      const auto& power_runner1 =
          NpuOpRunner("Power",
                      {*in_x},
                      {tmp_x},
                      {{"power", porder}, {"scale", 1.0f}, {"shift", 0.0f}});
      power_runner1.Run(stream);

      const auto& reduce_runner = NpuOpRunner(
          "ReduceSumD",
          {tmp_x},
          {*out_norm},
          {{"axes", std::vector<int32_t>({axis})}, {"keep_dims", keepdim}});
      reduce_runner.Run(stream);

      const auto& power_runner2 = NpuOpRunner(
          "Power",
          {*out_norm},
          {*out_norm},
          {{"power", 1 / porder}, {"scale", 1.0f}, {"shift", 0.0f}});
      power_runner2.Run(stream);
    }
  }
};

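// NPU backward kernel for p_norm. For finite nonzero p the gradient is
// dx_i = sign(x_i) * |x_i|^(p-1) / ||x||_p^(p-1) * dy; p == 0 and
// p == +/-INFINITY get dedicated branches below.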
template <typename DeviceContext, typename T>
class PnormGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* y = ctx.Input<phi::DenseTensor>("Out");
    auto* dy = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));

    auto place = ctx.GetPlace();
    dx->mutable_data<T>(place);

    auto xdim = x->dims();
    float porder = ctx.Attr<float>("porder");
    bool keepdim = ctx.Attr<bool>("keepdim");

    int axis = ctx.Attr<int>("axis");
    axis = axis < 0 ? xdim.size() + axis : axis;

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

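    // Reshape Out and dOut to the reduced-but-kept shape so they broadcast
    // against x in the element-wise ops below.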
    phi::DenseTensor y_share(y->type());
    phi::DenseTensor dy_share(dy->type());
    y_share.ShareDataWith(*y);
    dy_share.ShareDataWith(*dy);
    auto ydim = xdim;
    if (!keepdim) {
      ydim[axis] = 1;
    } else {
      ydim = y->dims();
    }
    y_share.Resize(ydim);
    dy_share.Resize(ydim);

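    // Three cases: p == 0 has zero gradient everywhere; p == +/-INFINITY
    // routes dy only to the entries attaining the extreme |x|; otherwise
    // use the closed-form gradient.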
    if (porder == 0) {
      FillNpuTensorWithConstant(dx, static_cast<T>(0));
      dx->Resize(xdim);
    } else if (porder == INFINITY || porder == -INFINITY) {
      phi::DenseTensor x_abs;
      x_abs.mutable_data<T>(xdim, place);
      const auto& r_abs = NpuOpRunner("Abs", {*x}, {x_abs}, {});
      r_abs.Run(stream);

      phi::DenseTensor t_cond;
      t_cond.mutable_data<bool>(xdim, place);
      const auto& r_equal =
          NpuOpRunner("Equal", {x_abs, y_share}, {t_cond}, {});
      r_equal.Run(stream);

      phi::DenseTensor t_zero;
      t_zero.mutable_data<T>({1}, place);
      FillNpuTensorWithConstant(&t_zero, static_cast<T>(0));

      phi::DenseTensor x_sign;
      x_sign.mutable_data<T>(xdim, place);
      const auto& r_sign = NpuOpRunner("Sign", {*x}, {x_sign}, {});
      r_sign.Run(stream);

      const auto& r_mul = NpuOpRunner("Mul", {x_sign, dy_share}, {*dx}, {});
      r_mul.Run(stream);

      const auto& r_sel =
          NpuOpRunner("SelectV2", {t_cond, *dx, t_zero}, {*dx}, {});
      r_sel.Run(stream);
    } else {
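      // General p: compute |x|^(p-1) / y^(p-1). Split on porder >= 1 so the
      // Power exponents stay positive; DivNoNan keeps 0/0 results at zero.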
      phi::DenseTensor x_abs;
      x_abs.mutable_data<T>(xdim, place);
      const auto& r_abs = NpuOpRunner("Abs", {*x}, {x_abs}, {});
      r_abs.Run(stream);

      phi::DenseTensor x_sign;
      x_sign.mutable_data<T>(xdim, place);
      const auto& r_sign = NpuOpRunner("Sign", {*x}, {x_sign}, {});
      r_sign.Run(stream);

      phi::DenseTensor y_pow;
      y_pow.mutable_data<T>(ydim, place);
      if (porder >= 1) {
        const auto& r_pow1 = NpuOpRunner(
            "Power",
            {x_abs},
            {x_abs},
            {{"power", (porder - 1)}, {"scale", 1.0f}, {"shift", 0.0f}});
        r_pow1.Run(stream);

        const auto& r_pow2 = NpuOpRunner(
            "Power",
            {y_share},
            {y_pow},
            {{"power", (porder - 1)}, {"scale", 1.0f}, {"shift", 0.0f}});
        r_pow2.Run(stream);

        const auto& r_div = NpuOpRunner("DivNoNan", {x_abs, y_pow}, {*dx}, {});
        r_div.Run(stream);
      } else {
        const auto& r_pow1 = NpuOpRunner(
            "Power",
            {x_abs},
            {x_abs},
            {{"power", (1 - porder)}, {"scale", 1.0f}, {"shift", 0.0f}});
        r_pow1.Run(stream);

        const auto& r_pow2 = NpuOpRunner(
            "Power",
            {y_share},
            {y_pow},
            {{"power", (1 - porder)}, {"scale", 1.0f}, {"shift", 0.0f}});
        r_pow2.Run(stream);

        const auto& r_div = NpuOpRunner("DivNoNan", {y_pow, x_abs}, {*dx}, {});
        r_div.Run(stream);
      }

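      // Restore the sign of x and apply the upstream gradient.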
      const auto& r_mul1 = NpuOpRunner("Mul", {*dx, x_sign}, {*dx}, {});
      r_mul1.Run(stream);

      const auto& r_mul2 = NpuOpRunner("Mul", {*dx, dy_share}, {*dx}, {});
      r_mul2.Run(stream);
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
    p_norm,
    ops::PnormNPUKernel<plat::NPUDeviceContext, float>,
    ops::PnormNPUKernel<plat::NPUDeviceContext, plat::float16>);

REGISTER_OP_NPU_KERNEL(
    p_norm_grad,
    ops::PnormGradNPUKernel<plat::NPUDeviceContext, float>,
    ops::PnormGradNPUKernel<plat::NPUDeviceContext, plat::float16>);