// stack_mkldnn_op.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using framework::LoDTensor;
23 24 25 26
using dnnl::memory;
using dnnl::primitive;
using dnnl::concat;
using dnnl::stream;
27 28 29 30 31 32 33
using platform::to_void_cast;

template <typename T>
class StackMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
 public:
  // Builds a oneDNN concat primitive descriptor that implements `stack` of
  // `inputs` along the "axis" attribute.
  //
  // Assumptions (not checked here):
  //  - `inputs` is non-empty (inputs[0] is read unconditionally);
  //  - every input has the same dims as inputs[0], as the stack op requires;
  //  - `output->dims()` already holds the stacked shape (presumably set by
  //    InferShape) — it is read when stacking on the trailing axis.
  StackMKLDNNHandler(const framework::ExecutionContext& ctx,
                     const dnnl::engine mkldnn_engine,
                     const std::vector<const Tensor*>& inputs, Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
                                                           ctx.GetPlace()) {
    int stack_axis = ctx.Attr<int>("axis");

    int ndims = inputs[0]->dims().size();

    // The output has one more dimension than the inputs, so a negative axis
    // is resolved against (ndims + 1).
    if (stack_axis < 0) {
      stack_axis = ndims + 1 + stack_axis;
    }

    // In the stack op all inputs must have the same dims.
    auto input_dims = pten::vectorize<int64_t>(inputs[0]->dims());

    memory::data_type dt = framework::ToMKLDNNDataType(
        framework::TransToProtoVarType(inputs[0]->dtype()));

    std::vector<memory::desc> srcs_md;
    srcs_md.reserve(inputs.size());
    memory::desc dst_md;

    // If the stack is not done on the last (non-existing) axis, the concat
    // primitive can be optimized by not adding the extra dimension, since
    // adding it causes wrong output format deduction and, as a result,
    // suboptimal performance.
    if (stack_axis != ndims) {
      for (const Tensor* input : inputs) {
        srcs_md.emplace_back(memory::desc(input_dims, dt, input->format()));
      }

      // Concatenating along an existing axis multiplies that axis' extent by
      // the number of inputs.
      input_dims[stack_axis] *= static_cast<int64_t>(inputs.size());
      dst_md = memory::desc(input_dims, dt, MKLDNNMemoryFormat::any);
    } else {
      // Stacking on the trailing axis: view each input as having an extra
      // dimension of size 1 at `stack_axis` and concatenate along it.
      auto extended_input_dims = pten::vectorize<int64_t>(output->dims());
      extended_input_dims[stack_axis] = 1;

      for (const Tensor* input : inputs) {
        srcs_md.emplace_back(memory::desc(input_dims, dt, input->format())
                                 .reshape(extended_input_dims));
      }

      // The concat primitive chooses a suboptimal format tag because it
      // cannot distinguish between e.g. abcd and abdc when the last dim is
      // equal to 1, so a plain format is enforced for better performance.
      const MKLDNNMemoryFormat dst_fmt =
          platform::GetPlainMKLDNNFormat(extended_input_dims.size());
      dst_md = memory::desc(pten::vectorize(output->dims()), dt, dst_fmt);
    }

    this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md);
  }

  // The oneDNN concat primitive has no .desc attribute, so the default
  // AcquireForwardPrimitiveDescriptor cannot be used; the concat
  // primitive_desc is constructed directly from the destination descriptor,
  // concat axis, source descriptors and the handler's engine.
  void AcquireForwardPrimitiveDescriptor(
      const memory::desc& dst_md, const int stack_axis,
      const std::vector<memory::desc>& srcs_md) {
    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
        dst_md, stack_axis, srcs_md, this->engine_));
  }

  // Wraps the data of `input` in a oneDNN memory object described by the
  // i-th source descriptor of the concat primitive.
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const Tensor& input, int i) {
    const T* input_data = input.data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                            to_void_cast<T>(input_data));
  }
};

template <typename T>
class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  // Computes Y = stack(X) with a oneDNN concat primitive.
  //
  // Inputs are bound as DNNL_ARG_MULTIPLE_SRC + i in input order, and the
  // primitive is executed synchronously (the stream is waited on). Afterwards
  // Y is marked as kMKLDNN layout, and its format is derived from the
  // destination memory descriptor reshaped to Y's dims.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto multi_input = ctx.MultiInput<Tensor>("X");

    Tensor* output = ctx.Output<Tensor>("Y");

    StackMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);

    // Holds the source memory objects for the duration of this call.
    std::vector<std::shared_ptr<memory>> srcs;
    srcs.reserve(multi_input.size());

    auto dst_mem = handler.AcquireDstMemory(output);
    auto concat_p = handler.AcquireForwardPrimitive();

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

    std::unordered_map<int, memory> args;
    args.reserve(multi_input.size() + 1);  // all sources + destination
    for (size_t i = 0; i < multi_input.size(); ++i) {
      srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
      // Execution-argument keys are int; make the index conversion explicit.
      args.emplace(DNNL_ARG_MULTIPLE_SRC + static_cast<int>(i), *srcs.back());
    }
    args.emplace(DNNL_ARG_DST, *dst_mem);

    concat_p->execute(astream, args);
    astream.wait();

    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(
        dst_mem->get_desc().reshape(pten::vectorize(output->dims()))));
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers StackMKLDNNOpKernel<float> as the MKLDNN kernel of the `stack` op
// on CPUPlace. Only the float instantiation is registered here.
REGISTER_OP_KERNEL(stack, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::StackMKLDNNOpKernel<float>);