elementwise_mkldnn_op.h 15.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16
#include <string>
17 18 19
#include <unordered_map>

#include "paddle/fluid/framework/data_layout_transform.h"
20 21
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
22 23 24 25 26
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

27 28 29
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
30
using phi::DataLayout;
31

32 33
// Computes the destination shape used to sum-reduce `x` (e.g. dout) down to
// the shape of `y` (a gradient tensor).
//
// The result has the rank of `x`. A dimension of `x` that is matched by a
// dimension of `y` keeps y's extent; every other dimension is 1 (reduced).
//  - Equal ranks: dimensions are compared position by position.
//  - Different ranks: y's dimensions are matched greedily, left to right; the
//    first x dimension equal to the next unmatched y dimension consumes it.
//    NOTE(review): this greedy match does not consult the op's `axis`
//    attribute and may choose the wrong position when an extent repeats in
//    x -- confirm this is acceptable for all callers.
// A 0-D `y` (empty dims) yields all ones, i.e. a full reduction.
inline std::vector<int64_t> CalculateBroadcastedDims(
    const phi::DenseTensor* x, const phi::DenseTensor* y) {
  const auto src_tz = phi::vectorize(x->dims());
  const auto dst_tz = phi::vectorize(y->dims());

  // Start from "reduce everything"; matched dims are overwritten below.
  std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);

  if (src_tz.size() == dst_tz.size()) {
    for (size_t i = 0; i < src_tz.size(); ++i) {
      dst_tz_ex[i] = (src_tz[i] == dst_tz[i]) ? dst_tz[i] : 1;
    }
  } else {
    // Checking `j` in the loop condition stops once every y dim is matched
    // and, unlike an unconditional dst_tz[j] read, is safe for an empty y.
    size_t j = 0;
    for (size_t i = 0; i < src_tz.size() && j < dst_tz.size(); ++i) {
      if (src_tz[i] == dst_tz[j]) {
        dst_tz_ex[i] = dst_tz[j++];
      }
    }
  }

  return dst_tz_ex;
}
53

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
// Copies `src_memory` into `dst_memory` with one oneDNN reorder whose output
// is multiplied by `scales` (mask 0, so `scales` is expected to hold a single
// common factor). Callers in this file use it for add/sub gradients when the
// gradient has the same dims as dout. `grad_tensor` is not used by the body.
inline void AddSubNonBroadcast(platform::ReorderMKLDNNHandler* reorder_handler,
                               phi::DenseTensor* grad_tensor,
                               const std::shared_ptr<dnnl::memory>& src_memory,
                               const std::shared_ptr<dnnl::memory>& dst_memory,
                               const std::vector<float>& scales) {
  dnnl::primitive_attr scaled_copy_attr;
  scaled_copy_attr.set_output_scales(/*mask=*/0, scales);

  const auto scaled_copy = reorder_handler->AcquireReorder(
      dst_memory, src_memory, scaled_copy_attr);

  auto& onednn_stream = platform::MKLDNNDeviceContext::tls().get_stream();
  scaled_copy->execute(onednn_stream, *src_memory, *dst_memory);
}

// Sum-reduces `src_memory` (laid out like `dout`) into `grad_tensor`. The
// reduction's destination shape comes from CalculateBroadcastedDims(dout,
// grad_tensor). When `is_sub` is set, an eltwise_linear post-op multiplies
// the reduced values by scales[0] (alpha = scales[0], beta = 0).
//
// `dst_memory` is taken by value and is replaced by the handler's destination
// memory for `grad_tensor`; the caller's pointer is not updated. On return,
// grad_tensor's memory descriptor is reshaped to grad_tensor->dims().
template <typename T>
inline void BroadcastReduction(const framework::ExecutionContext& ctx,
                               const dnnl::engine& onednn_engine,
                               phi::DenseTensor* grad_tensor,
                               const phi::DenseTensor* dout,
                               const std::shared_ptr<dnnl::memory>& src_memory,
                               std::shared_ptr<dnnl::memory> dst_memory,
                               const std::vector<float>& scales,
                               const bool is_sub) {
  dnnl::primitive_attr broadcast_reduction_attr;
  if (is_sub) {
    dnnl::post_ops linear_scale;
    linear_scale.append_eltwise(
        1.0f, dnnl::algorithm::eltwise_linear, scales[0], 0);
    broadcast_reduction_attr.set_post_ops(linear_scale);
  }

  platform::ReductionMKLDNNHandler<T> reduction_handler(
      dnnl::algorithm::reduction_sum,
      0.0f,
      0.0f,
      onednn_engine,
      ctx.GetPlace(),
      dout,
      grad_tensor,
      CalculateBroadcastedDims(dout, grad_tensor),
      broadcast_reduction_attr);

  dst_memory = reduction_handler.AcquireDstMemory(grad_tensor);
  const auto reduction_p = reduction_handler.AcquireForwardPrimitive();

  auto astream = platform::MKLDNNDeviceContext::tls().get_stream();
  const std::unordered_map<int, dnnl::memory> reduction_args = {
      {DNNL_ARG_SRC, *src_memory},
      {DNNL_ARG_DST, *dst_memory},
  };
  reduction_p->execute(astream, reduction_args);
  astream.wait();

  const auto grad_dims = phi::vectorize<int64_t>(grad_tensor->dims());
  grad_tensor->set_mem_desc(dst_memory->get_desc().reshape(grad_dims));
}

111 112
// Forward oneDNN kernel for elementwise binary ops: Out = X <BINARY_OP> Y,
// using the Scale_x / Scale_y / Scale_out attributes and any activation fused
// via post-ops.
template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
 private:
  // Collects the post-ops fused into the binary primitive; here only what
  // platform::AppendActivation adds based on the op's attributes.
  dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
    dnnl::post_ops post_operations;
    platform::AppendActivation(ctx, post_operations);
    return post_operations;
  }

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* y = ctx.Input<phi::DenseTensor>("Y");
    auto* z = ctx.Output<phi::DenseTensor>("Out");

    float scale_x = ctx.Attr<float>("Scale_x");
    float scale_y = ctx.Attr<float>("Scale_y");
    float scale_o = ctx.Attr<float>("Scale_out");
    int axis = ctx.Attr<int>("axis");

    // The handler is built from the tensors in their original order, i.e.
    // before the optional swap below.
    platform::BinaryMKLDNNHandler<T> handler(BINARY_OP,
                                             axis,
                                             mkldnn_engine,
                                             ctx.GetPlace(),
                                             x,
                                             y,
                                             z,
                                             scale_x,
                                             scale_y,
                                             scale_o,
                                             true,
                                             get_post_ops(ctx));

    // oneDNN's binary is optimized for broadcasting y into x, so in other case
    // we have to swap tensors to achieve optimal performance
    if (x->numel() < y->numel()) {
      std::swap(x, y);
    }

    const auto src_x_memory = handler.AcquireSrcMemory(x);
    const auto src_y_memory = handler.AcquireSecondSrcMemory(y);

    // (jczaja) For in-place execution src and dst should be the same memory
    // object, so x should share its buffer with z. However, the UT mechanics
    // test in-place execution for this op without checking whether x has to
    // be broadcast to match the shape of y. That is wrong: when x is
    // broadcast, z (Out) takes the shape of y, which is bigger than x. If x
    // is smaller than z and they share a buffer (of x's shape), that buffer is
    // not big enough to hold the result. Hence reuse requires equal numel.
    const bool reuse_x_memory =
        x->numel() == z->numel() && x->IsSharedBufferWith(*z);

    std::shared_ptr<dnnl::memory> dst_memory;
    if (reuse_x_memory) {
      dst_memory = src_x_memory;
      // NOTE(chenfeiyu): when the output reuses memory from another tensor
      // rather than allocating its own, its data type still needs care.
      // Paddle's operator only infers the output's shape, not its data type.
      // mutable_data<T> normally handles allocation and data type, and if the
      // memory is already allocated it only sets the data type -- so it is
      // called here to get the right data type.
      z->mutable_data<T>(ctx.GetPlace());
    } else {
      dst_memory = handler.AcquireDstMemory(z);
    }

    const auto binary_prim = handler.AcquireForwardPrimitive();
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

    const std::unordered_map<int, dnnl::memory> args = {
        {DNNL_ARG_SRC_0, *src_x_memory},
        {DNNL_ARG_SRC_1, *src_y_memory},
        {DNNL_ARG_DST, *dst_memory}};

    binary_prim->execute(astream, args);
    astream.wait();

    if (!handler.use_broadcasting_hack) {
      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
          ctx, z, dst_memory->get_desc());
    } else {
      // Split the descriptor's leading dim d0 into [x->dims()[0],
      // d0 / x->dims()[0]] before attaching it to z. Presumably this undoes
      // a dim folding done by the handler's broadcasting hack -- confirm in
      // BinaryMKLDNNHandler.
      auto dims = dst_memory->get_desc().dims();
      dims.insert(dims.begin(), x->dims()[0]);
      dims[1] /= dims[0];
      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
          ctx, z, dst_memory->get_desc().reshape(dims));
    }
  }
};
206

207 208 209 210 211
// Backward oneDNN kernel for elementwise add/sub/mul/div.
//  - add/sub: dX = dout, dY = +/-dout (a reorder with an output scale),
//    followed by a sum-reduction when the gradient's dims differ from dout's.
//  - mul:     dX = dout * y, dY = dout * x.
//  - div:     dX = dout / y, dY = -dout * out / y (binary_mul with a
//             binary_div post-op).
template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* y = ctx.Input<phi::DenseTensor>("Y");
    auto* out = ctx.Input<phi::DenseTensor>("Out");

    auto* dx = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<phi::DenseTensor>(framework::GradVarName("Y"));
    auto* dout = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));

    // oneDNN's binary is optimized for broadcasting y into x, so in other case
    // we have to swap tensors to achieve optimal performance. After a swap the
    // "dx" slot holds the original Y's gradient and the "dy" slot the
    // original X's.
    bool swap_x_y = false;
    if (x->numel() < y->numel()) {
      std::swap(x, y);
      std::swap(dx, dy);
      swap_x_y = true;
    }

    // Output scales for the add/sub gradients. For out = X - Y the gradient
    // w.r.t. X is +dout and w.r.t. Y is -dout, so the negative sign goes to
    // whichever slot holds the original Y's gradient: dx after a swap, dy
    // otherwise. A single shared scale cannot be right for both slots. For
    // add both scales are +1; mul/div never read them (they only reach
    // AddSubNonBroadcast and the is_sub post-op of BroadcastReduction).
    constexpr bool is_sub = BINARY_OP == dnnl::algorithm::binary_sub;
    const std::vector<float> dx_scales{(is_sub && swap_x_y) ? -1.0f : 1.0f};
    const std::vector<float> dy_scales{(is_sub && !swap_x_y) ? -1.0f : 1.0f};

    int axis = ctx.Attr<int>("axis");

    auto tz = phi::vectorize<int64_t>(dout->dims());
    auto proto_type_dout = framework::TransToProtoVarType(dout->dtype());

    platform::ReorderMKLDNNHandler reorder_handler(
        tz,
        proto_type_dout,
        framework::ToMKLDNNDataType(proto_type_dout),
        onednn_engine);

    auto reorder_src_memory = reorder_handler.AcquireSrcMemory(
        dout->mem_desc(), platform::to_void_cast(dout->data<T>()));

    std::shared_ptr<dnnl::memory> dst_memory;
    // Source of the broadcast reduction: raw dout unless the dy mul/div path
    // below replaces it with its binary result.
    std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory;

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    if (dx) {
      if (BINARY_OP == dnnl::algorithm::binary_add ||
          BINARY_OP == dnnl::algorithm::binary_sub) {
        // elementwise_add & elementwise_sub: dx = dx_scales[0] * dout. When
        // the dims differ, dx is produced by the reduction further below.
        if (dout->dims() == dx->dims()) {
          dst_memory = reorder_handler.AcquireDstMemory(
              dx, dout->mem_desc(), ctx.GetPlace());
          AddSubNonBroadcast(
              &reorder_handler, dx, reorder_src_memory, dst_memory, dx_scales);
        }
      } else {  // elementwise_mul & elementwise_div
        // dx = dout <BINARY_OP> y.
        // NOTE(review): for binary_div with swap_x_y the dx slot holds the
        // original Y's gradient, which dout / y does not equal -- confirm
        // that swapped division never reaches this kernel.
        platform::BinaryMKLDNNHandler<T> binary_handler(BINARY_OP,
                                                        axis,
                                                        onednn_engine,
                                                        ctx.GetPlace(),
                                                        dout,
                                                        y,
                                                        dx,
                                                        1.0f,
                                                        1.0f,
                                                        1.0f,
                                                        false);

        const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
        const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
        dst_memory = binary_handler.AcquireDstMemory(dx);

        const auto binary_prim = binary_handler.AcquireForwardPrimitive();

        const std::unordered_map<int, dnnl::memory> args = {
            {DNNL_ARG_SRC_0, *src_dout_memory},
            {DNNL_ARG_SRC_1, *src_y_memory},
            {DNNL_ARG_DST, *dst_memory}};

        binary_prim->execute(astream, args);
      }
      astream.wait();

      if (dout->dims() != dx->dims()) {
        BroadcastReduction<T>(ctx,
                              onednn_engine,
                              dx,
                              dout,
                              broadcast_src_memory,
                              dst_memory,
                              dx_scales,
                              is_sub);
      } else {
        dx->set_mem_desc(dst_memory->get_desc());
      }
    }

    if (dy) {
      if (BINARY_OP == dnnl::algorithm::binary_add ||
          BINARY_OP == dnnl::algorithm::binary_sub) {
        // elementwise_add & elementwise_sub: dy = dy_scales[0] * dout. When
        // the dims differ, dy is produced by the reduction further below.
        if (dout->dims() == dy->dims()) {
          dst_memory = reorder_handler.AcquireDstMemory(
              dy, dout->mem_desc(), ctx.GetPlace());
          AddSubNonBroadcast(
              &reorder_handler, dy, reorder_src_memory, dst_memory, dy_scales);
        }
      } else {  // elementwise_mul & elementwise_div
        // mul: dy = dout * x.
        platform::BinaryMKLDNNHandler<T> binary_handler(
            dnnl::algorithm::binary_mul,
            axis,
            onednn_engine,
            ctx.GetPlace(),
            dout,
            x,
            nullptr,
            1.0f,
            1.0f,
            1.0f,
            false);

        std::shared_ptr<dnnl::memory> src_1_memory =
            binary_handler.AcquireSecondSrcMemory(x);
        std::shared_ptr<dnnl::memory> post_op_memory;

        if (BINARY_OP == dnnl::algorithm::binary_div) {
          // div: dy = (-1 * dout) * out, then divided by y through a
          // binary_div post-op. The -1.0f sits in the position the forward
          // kernel uses for scale_x, i.e. it scales the first input (dout).
          // NOTE(review): with swap_x_y the dy slot holds the original X's
          // gradient (dout / Y), which this formula does not produce --
          // confirm that swapped division never reaches this kernel.
          platform::BinaryMKLDNNHandler<T> post_op_binary_handler(
              dnnl::algorithm::binary_div,
              axis,
              onednn_engine,
              ctx.GetPlace(),
              y,
              y,
              nullptr,
              1.0f,
              1.0f,
              1.0f,
              false);

          post_op_memory = post_op_binary_handler.AcquireSrcMemory(y);

          dnnl::post_ops po;
          po.append_binary(dnnl::algorithm::binary_div,
                           post_op_memory->get_desc());

          binary_handler =
              platform::BinaryMKLDNNHandler<T>(dnnl::algorithm::binary_mul,
                                               axis,
                                               onednn_engine,
                                               ctx.GetPlace(),
                                               dout,
                                               out,
                                               nullptr,
                                               -1.0f,
                                               1.0f,
                                               1.0f,
                                               false,
                                               po);

          src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
        }

        const auto src_0_memory = binary_handler.AcquireSrcMemory(dout);

        // Write straight into dy when no reduction follows; otherwise into a
        // handler-owned buffer that the reduction below consumes.
        const auto dst_dy_memory = (dout->dims() == dy->dims())
                                       ? binary_handler.AcquireDstMemory(dy)
                                       : binary_handler.AcquireDstMemory();

        const auto binary_prim = binary_handler.AcquireForwardPrimitive();
        std::unordered_map<int, dnnl::memory> args = {
            {DNNL_ARG_SRC_0, *src_0_memory},
            {DNNL_ARG_SRC_1, *src_1_memory},
            {DNNL_ARG_DST, *dst_dy_memory}};

        if (BINARY_OP == dnnl::algorithm::binary_div) {
          args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
                       *post_op_memory});
        }

        binary_prim->execute(astream, args);
        // A subsequent reduction must sum the binary result, not raw dout.
        broadcast_src_memory = dst_dy_memory;
        dst_memory = dst_dy_memory;
      }
      astream.wait();

      if (dout->dims() != dy->dims()) {
        BroadcastReduction<T>(ctx,
                              onednn_engine,
                              dy,
                              dout,
                              broadcast_src_memory,
                              dst_memory,
                              dy_scales,
                              is_sub);
      } else {
        dy->set_mem_desc(dst_memory->get_desc());
      }
    }
  }
};
414 415
}  // namespace operators
}  // namespace paddle