// softmax_mkldnn_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cmath>

#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;

using dnnl::memory;  // Note: paddle has also "memory" namespace
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::softmax_backward;
using dnnl::softmax_forward;
using dnnl::stream;

using platform::to_void_cast;

template <typename T>
34 35 36
class SoftmaxMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                      mkldnn::softmax_backward> {
J
Jacek Czaja 已提交
37
 public:
38 39 40 41
  SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
                       const mkldnn::engine mkldnn_engine,
                       platform::Place cpu_place, const Tensor* input,
                       Tensor* output, const int axis,
42
                       const std::string uniq_name, bool is_inplaced)
43 44
      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                 mkldnn::softmax_backward>(
45
            dev_ctx, mkldnn_engine, cpu_place,
46
            // Softmax may be inplace then uniq_name is no longer unique
47 48 49 50 51 52
            is_inplaced ? platform::CreateKey(
                              dev_ctx, framework::vectorize(input->dims()),
                              axis, uniq_name)
                        : platform::CreateKey(
                              dev_ctx, framework::vectorize(input->dims()),
                              uniq_name)) {
53
    if (!this->isCached()) {
54 55 56 57 58 59 60 61 62
      PADDLE_ENFORCE_EQ(
          input->dims(), output->dims(),
          platform::errors::InvalidArgument(
              "The shape of input and output tensor must be identical."));

      auto softmax_tz = framework::vectorize(input->dims());
      auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
                             input->format());

63 64
      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
                                              axis);
65
    }
66
  }
J
Jacek Czaja 已提交
67

68 69 70 71 72
  SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
                       const MKLDNNDeviceContext& dev_ctx,
                       platform::Place cpu_place, const Tensor* out,
                       const Tensor* out_grad, Tensor* in_x_grad,
                       const std::string& unique_name)
73 74 75
      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                 mkldnn::softmax_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
            platform::CreateKey(dev_ctx, framework::vectorize(out->dims()),
                                unique_name)) {
    if (!this->isBwdCached()) {
      PADDLE_ENFORCE_EQ(
          out_grad->dims(), in_x_grad->dims(),
          platform::errors::InvalidArgument("The shape of softmax_grad's input "
                                            "and output must be identical."));

      auto dims = out_grad->dims();  // input and output share the same shape
      const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
      auto softmax_tz = framework::vectorize<int64_t>(dims);

      auto data_softmax_md = MKLDNNMemDesc(
          softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
      auto diff_softmax_md = MKLDNNMemDesc(
          softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());

93 94 95 96
      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
                                              data_softmax_md, axis);
      this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
                                               axis);
97
    }
98
  }
J
Jacek Czaja 已提交
99
};
100 101 102 103 104 105

template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
 public:
  // Computes softmax of input "X" into "Out" with oneDNN's softmax_forward
  // primitive, along the "axis" attribute (canonicalized here).
  //
  // When X and Out share one buffer (in-place execution), the same memory
  // object is bound as both source and destination. Outside of test mode
  // ("is_test" == false) the result is additionally clamped from below.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");
    bool is_inplaced = input->IsSharedBufferWith(*output);

    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());

    SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
                                    input, output, axis, ctx.OutputName("Out"),
                                    is_inplaced);

    auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
    // For in-place execution src and dst are the same memory object.
    auto softmax_dst_memory_p =
        is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);

    auto softmax_p = handler.AcquireForwardPrimitive();

    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
    softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
                                 {DNNL_ARG_DST, *softmax_dst_memory_p}});
    astream.wait();

    const bool is_test = ctx.Attr<bool>("is_test");
    if (!is_test) {
      // Raise every output value to at least exp(-64).
      // NOTE(review): presumably this keeps a downstream log() (e.g. in a
      // cross-entropy loss) away from log(0) -- confirm with the op owners.
      // The bound is computed and converted to T once, not once per element.
      const T lower_bound = static_cast<T>(std::exp(-64.0));
      T* output_data = output->mutable_data<T>(ctx.GetPlace());
      std::for_each(output_data, output_data + output->numel(),
                    [lower_bound](T& val) { val = std::max(val, lower_bound); });
    }

    output->set_layout(framework::DataLayout::kMKLDNN);
    // Softmax output format is the same as input one
    output->set_format(input->format());
  }
};
template <typename T>
class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
 public:
  // Computes X@GRAD for softmax with oneDNN's softmax_backward primitive,
  // from the forward result "Out" and its gradient "Out@GRAD".
  // Only CPUPlace is supported; any other place is rejected up front.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL SoftmaxGrad must use CPUPlace"));

    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const Tensor* output = ctx.Input<Tensor>("Out");
    auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
    auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));

    // Keyed on the name of "Out" so it can share cached descriptors keyed
    // under the same name.
    SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output,
                                    out_grad, in_x_grad, ctx.InputName("Out"));

    auto dst_memory_p = handler.AcquireDstMemory(output);
    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(out_grad);
    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(in_x_grad);

    auto softmax_bwd_p = handler.AcquireBackwardPrimitive();

    // DNNL_ARG_* ids, matching the forward kernel (MKLDNN_ARG_* are the
    // legacy aliases of the same values).
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    softmax_bwd_p->execute(astream,
                           {{DNNL_ARG_DST, *dst_memory_p},
                            {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                            {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
    astream.wait();

    // The gradient's format is taken from the memory the primitive wrote.
    in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// oneDNN (MKLDNN) kernels, registered for CPUPlace only.
// Forward softmax: float and bfloat16.
REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::SoftmaxMKLDNNKernel<float>,
                   ops::SoftmaxMKLDNNKernel<paddle::platform::bfloat16>);

// Softmax gradient: float only.
REGISTER_OP_KERNEL(softmax_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::SoftmaxMKLDNNGradKernel<float>);