softmax_mkldnn_op.cc 5.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "mkldnn.hpp"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;

using mkldnn::memory;  // Note: paddle has also "memory" namespace
using mkldnn::primitive;
using mkldnn::softmax_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
 public:
  // Runs softmax over axis 1 (C of an NC layout) of the 2-D input "X" with an
  // MKL-DNN softmax_forward primitive and writes the result to "Out".
  //
  // The primitive and its src/dst memory primitives are cached as blobs in the
  // MKLDNNDeviceContext under keys derived from the input dims. On a cache hit
  // only the memory data handles are re-pointed at this call's buffers.
  //
  // When the "is_test" attribute is false, every output value is clamped from
  // below to exp(-64) after the primitive has run.
  //
  // PADDLE_ENFORCE checks: CPU place, 2-D input, output dims equal input dims,
  // and (on a cache hit) presence of the cached memory primitives.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    auto mkldnn_engine = dev_ctx.GetEngine();
    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");
    PADDLE_ENFORCE(input->dims().size() == 2UL,
                   "The input of softmax op must be a 2D matrix.");
    const T* input_data = input->data<T>();
    // Allocate memory for the output.
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    const std::vector<int> src_tz =
        paddle::framework::vectorize2int(input->dims());
    const std::vector<int> dst_tz =
        paddle::framework::vectorize2int(output->dims());
    // Comparing whole vectors also compares ranks, so dst_tz is never indexed
    // out of range when "Out" is not 2-D.
    PADDLE_ENFORCE(src_tz == dst_tz,
                   "Softmax input and output dimensions should match");
    // MKL-DNN supports softmax over a selected axis. For this 2-D tensor the
    // normalization is done over the last axis (axis 1). The same dims serve
    // both the input and the output memory descriptors.
    memory::dims softmax_tz = {src_tz[0], src_tz[1]};

    // Keys for storing/retrieving this operator's primitives in dev_ctx.
    // TODO(jczaja): Each MKLDNN operator may have a different hashing function.
    auto gethash = [](const memory::dims& operand_dims) {
      return std::to_string(operand_dims[0]) + "-" +
             std::to_string(operand_dims[1]);
    };
    const std::string key = gethash(softmax_tz);
    const std::string key_softmax_p = key + "@softmax_p";
    const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p";
    const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p";

    // The blob is stored as a softmax_forward, so cast back to exactly that
    // type. Casting the void* to a base-class pointer would rely on a
    // zero base-subobject offset.
    std::shared_ptr<softmax_forward> softmax_p =
        std::static_pointer_cast<softmax_forward>(
            dev_ctx.GetBlob(key_softmax_p));
    if (softmax_p == nullptr) {
      // Currently only NC data format is supported.
      auto softmax_md =
          MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
      // Normalization is done over the innermost dimension, i.e. C of NC.
      auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
                                                softmax_md, 1 /*dim: C*/);
      // Memory primitives wrap the current input/output buffers directly.
      auto softmax_src_memory_p = std::make_shared<memory>(
          memory::primitive_desc{softmax_md, mkldnn_engine},
          static_cast<void*>(const_cast<T*>(input_data)));
      dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
      auto softmax_dst_memory_p = std::make_shared<memory>(
          memory::primitive_desc{softmax_md, mkldnn_engine},
          static_cast<void*>(output_data));
      dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);

      auto softmax_forward_pd =
          std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
                                                            mkldnn_engine);
      softmax_p = std::make_shared<softmax_forward>(
          *softmax_forward_pd, *softmax_src_memory_p, *softmax_dst_memory_p);
      dev_ctx.SetBlob(key_softmax_p, softmax_p);
    } else {
      // Primitives already exist: re-point their memory at this call's buffers.
      auto src_memory_p = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_softmax_src_mem_p));
      PADDLE_ENFORCE(src_memory_p != nullptr,
                     "Fail to find softmax src mem_p in device context");
      auto dst_memory_p = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_softmax_dst_mem_p));
      PADDLE_ENFORCE(dst_memory_p != nullptr,
                     "Fail to find softmax dst mem_p in device context");
      src_memory_p->set_data_handle(
          static_cast<void*>(const_cast<T*>(input_data)));
      dst_memory_p->set_data_handle(static_cast<void*>(output_data));
    }

    std::vector<primitive> pipeline{*softmax_p};
    stream(stream::kind::eager).submit(pipeline).wait();

    const bool is_test = ctx.Attr<bool>("is_test");
    if (!is_test) {
      // Outside test mode, clamp every output from below to exp(-64).
      // NOTE(review): presumably this keeps a downstream log() (e.g. in a
      // cross-entropy loss) finite; confirm with the op's consumers.
      const T threshold = static_cast<T>(std::exp(-64.0));
      // 64-bit element count: the int product dst_tz[0] * dst_tz[1] could
      // overflow for very large tensors.
      const int64_t numel = static_cast<int64_t>(dst_tz[0]) * dst_tz[1];
      for (int64_t i = 0; i < numel; ++i) {
        // std::max(a, b) == (a < b ? b : a): same result as the original
        // ternary, including leaving NaN values untouched.
        output_data[i] = std::max(output_data[i], threshold);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register SoftmaxMKLDNNKernel<float> for the "softmax" op with the MKLDNN
// library tag on CPUPlace. Only the float instantiation is registered, which
// matches the memory::f32 data type hard-coded in the kernel's memory
// descriptor.
REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::SoftmaxMKLDNNKernel<float>);