softmax_mkldnn_op.cc 7.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cmath>

#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
17

W
wanghuancoder 已提交
18 19 20 21 22 23 24 25 26
// Forward declarations of the framework/platform classes that this file
// refers to through the using-declarations in paddle::operators below.
namespace paddle {
namespace framework {
class Tensor;
}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

27 28 29 30 31 32 33
namespace paddle {
namespace operators {

// Paddle-side names used unqualified throughout this file.
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using paddle::platform::to_void_cast;

// oneDNN-side names used unqualified throughout this file.
// Beware: `memory` below is dnnl::memory; paddle also has a namespace that
// is called "memory".
using dnnl::memory;
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::softmax_backward;
using dnnl::softmax_forward;
using dnnl::stream;

42
template <typename T>
43 44 45
class SoftmaxMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                      mkldnn::softmax_backward> {
J
Jacek Czaja 已提交
46
 public:
47 48 49 50
  SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
                       const mkldnn::engine mkldnn_engine,
                       platform::Place cpu_place, const Tensor* input,
                       Tensor* output, const int axis,
51
                       const std::string uniq_name, bool is_inplaced)
52 53
      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                 mkldnn::softmax_backward>(
54
            dev_ctx, mkldnn_engine, cpu_place,
55
            // Softmax may be inplace then uniq_name is no longer unique
56 57 58 59 60 61
            is_inplaced ? platform::CreateKey(
                              dev_ctx, framework::vectorize(input->dims()),
                              axis, uniq_name)
                        : platform::CreateKey(
                              dev_ctx, framework::vectorize(input->dims()),
                              uniq_name)) {
62 63 64 65 66 67 68 69 70 71 72 73 74
    if (!this->isCached()) {
      PADDLE_ENFORCE_EQ(
          input->dims(), output->dims(),
          platform::errors::InvalidArgument(
              "The shape of input and output tensor must be identical."));

      auto softmax_tz = framework::vectorize(input->dims());
      auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
                             input->format());

      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
                                              axis);
    }
75
  }
J
Jacek Czaja 已提交
76

A
Adam 已提交
77
  SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
78
                       const MKLDNNMemoryFormat fmt,
79
                       const MKLDNNMemoryFormat diff_fmt, const int& axis,
80
                       const platform::MKLDNNDeviceContext& dev_ctx,
81
                       platform::Place cpu_place, const std::string& uniq_name)
82 83 84
      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                 mkldnn::softmax_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
85
            platform::CreateKey(dev_ctx, dims, uniq_name)) {
86 87 88 89 90 91
    auto data_softmax_md =
        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
    auto diff_softmax_md =
        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);

    this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
92
                                             axis);
93
  }
J
Jacek Czaja 已提交
94
};
95 96 97 98 99 100

// Forward softmax kernel executed through oneDNN (MKL-DNN).
template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
 public:
  // Reads input "X" and the attributes "axis" and "is_test", runs the oneDNN
  // softmax forward primitive and writes the result into output "Out".
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");
    const bool is_inplaced = input->IsSharedBufferWith(*output);

    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());

    SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
                                    input, output, axis, ctx.OutputName("Out"),
                                    is_inplaced);

    auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
    // For inplace execution src and dst are the same memory object
    auto softmax_dst_memory_p =
        is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);

    auto softmax_p = handler.AcquireForwardPrimitive();

    // Reuse the engine reference obtained above instead of asking the device
    // context for it a second time.
    mkldnn::stream astream(mkldnn_engine);
    softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
                                 {DNNL_ARG_DST, *softmax_dst_memory_p}});
    astream.wait();

    const bool is_test = ctx.Attr<bool>("is_test");
    if (!is_test) {
      // Outside of test mode every output element is raised to at least
      // exp(-64). The threshold is loop-invariant, so compute it once.
      const T min_value = static_cast<T>(std::exp(-64.0));
      T* output_data = output->mutable_data<T>(ctx.GetPlace());
      std::for_each(output_data, output_data + output->numel(),
                    [min_value](T& val) { val = std::max(val, min_value); });
    }

    output->set_layout(framework::DataLayout::kMKLDNN);
    // Softmax output format is the same as input one
    output->set_format(input->format());
  }
};

J
Jacek Czaja 已提交
139 140 141 142
// Backward softmax kernel executed through oneDNN (MKL-DNN).
template <typename T>
class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
 public:
  // Reads input "Out" and the gradient of "Out", runs the oneDNN softmax
  // backward primitive and writes the result into the gradient of "X".
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL SoftmaxGrad must use CPUPlace"));
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const Tensor* output = ctx.Input<Tensor>("Out");
    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    PADDLE_ENFORCE_EQ(
        dout->dims(), dx->dims(),
        platform::errors::InvalidArgument(
            "The shape of softmax_grad's input and output must be identical."));
    // A single shape (taken from dOut) is used below to describe both the Out
    // and the dOut memory, so Out has to agree with it as well.
    PADDLE_ENFORCE_EQ(
        output->dims(), dout->dims(),
        platform::errors::InvalidArgument(
            "The shape of softmax_grad's Out and Out@GRAD must be identical."));

    auto dims = dout->dims();  // input and output share the same shape
    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());

    auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);

    SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(),
                                    dout->format(), axis, dev_ctx,
                                    ctx.GetPlace(), ctx.InputName("Out"));

    auto dst_memory_p = handler.AcquireDstMemory(output);
    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);

    auto softmax_bwd_p = handler.AcquireBackwardPrimitive();

    mkldnn::stream astream(dev_ctx.GetEngine());
    softmax_bwd_p->execute(astream,
                           {{MKLDNN_ARG_DST, *dst_memory_p},
                            {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
                            {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
    astream.wait();

    dx->set_layout(framework::DataLayout::kMKLDNN);
    // The computed gradient is tagged with the format of the incoming one.
    dx->set_format(dout->format());
  }
};
183 184 185 186 187 188
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward softmax: MKLDNN kernels on CPUPlace for float and bfloat16.
REGISTER_OP_KERNEL(softmax,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::SoftmaxMKLDNNKernel<float>,
                   ops::SoftmaxMKLDNNKernel<paddle::platform::bfloat16>);

// Backward softmax: MKLDNN kernel on CPUPlace for float.
REGISTER_OP_KERNEL(softmax_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::SoftmaxMKLDNNGradKernel<float>);