/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;

template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Pool must use CPUPlace"));
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input tensor");

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    // Only 2D pooling is supported now
    PADDLE_ENFORCE_EQ(ksize.size(), 2,
                      "ksize must have 2 elements, i.e. 2D pooling");
    PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true,
                      "pooling_type must be 'max' or 'avg'");
    PADDLE_ENFORCE_EQ(input->dims().size(), 4,
                      "Input must be 4-dimensional, i.e. NCHW");

    auto input_dims = input->dims();
    framework::DDim data_dims =
        framework::slice_ddim(input_dims, 2, input_dims.size());

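    // For global pooling, the kernel is resized to cover the whole spatial
    // extent of the input.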
    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

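    // Recompute the paddings for the chosen padding_algorithm based on the
    // input spatial dims, strides and kernel size.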
    UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
                  strides, ksize);

    auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
    auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());

    auto is_test = ctx.Attr<bool>("is_test");

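    // The handler builds the MKL-DNN pooling forward primitive descriptor for
    // this configuration and reuses cached objects from the device context,
    // keyed in part by the "Out" variable name passed below.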
    platform::PoolingMKLDNNHandler<T> handler(
        src_tz, dst_tz, ksize, strides, paddings, pooling_type,
        ctx.Attr<bool>("ceil_mode"), input->format(),
        paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
        ctx.GetPlace(), ctx.OutputName("Out"), ctx.Attr<bool>("exclusive"));

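    // Wrap the input and output tensors as MKL-DNN memory objects for the
    // primitive.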
    auto src_memory = handler.AcquireSrcMemory(input);
    auto dst_memory = handler.AcquireDstMemory(output);

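    // Acquire the forward pooling primitive built by the handler.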
    auto pool_p = handler.AcquireForwardPrimitive();

    mkldnn::stream astream(dev_ctx.GetEngine());
    if ((is_test == false) && (pooling_type == "max")) {
      // Training
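      // Max pooling keeps a workspace (positions of the maxima) so that the
      // backward pass can route gradients to the right input elements.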
      auto workspace_memory = handler.AcquireWorkspaceMemory();
      pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DST, *dst_memory},
                                {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
    } else {
      // Inference
      pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DST, *dst_memory}});
    }
    astream.wait();

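    // Record the MKL-DNN layout and the memory format chosen by the primitive
    // on the output tensor so that following operators can interpret it.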
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory));
  }
};

template <typename T>
class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL PoolGrad must use CPUPlace"));
    const Tensor* in_x = ctx.Input<Tensor>("X");
    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    PADDLE_ENFORCE_EQ(in_x->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(in_x->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input output_grad tensor");
    PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input output_grad tensor");

    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
        "is_test attribute should be set to False in training phase.");

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    auto in_x_dims = in_x->dims();
    framework::DDim data_dims =
        framework::slice_ddim(in_x_dims, 2, in_x_dims.size());

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
                  strides, ksize);

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    std::vector<mkldnn::primitive> pipeline;

    auto diff_src_tz = paddle::framework::vectorize<int64_t>(in_x_grad->dims());
    auto diff_dst_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());

    // Get a unique name from the "argument" name of the "Out" variable
    // This name is used as the key when retrieving info from the device context
    const std::string key = platform::CreateKey(
        diff_src_tz, pooling_type, ksize, strides, paddings,
        memory::data_type::f32, in_x->format(), ctx.InputName("Out"));

    platform::PoolingMKLDNNHandler<T> handler(
        diff_dst_tz, diff_src_tz, ksize, strides, paddings, pooling_type,
        ctx.Attr<bool>("ceil_mode"), in_x->format(), out_grad->format(),
        paddle::framework::ToMKLDNNDataType(out_grad->type()), dev_ctx,
        ctx.GetPlace(), ctx.InputName("Out"), ctx.Attr<bool>("exclusive"));

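    // Wrap the gradient tensors as MKL-DNN memory objects for the backward
    // primitive.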
    auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
    auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);

    auto pool_bwd_p = handler.AcquireBackwardPrimitive();

    mkldnn::stream astream(dev_ctx.GetEngine());
    if (pooling_type == "max") {
      // Max pooling needs the workspace (max indices) saved in the forward pass
      auto workspace_memory = handler.AcquireWorkspaceMemory();
      pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
                                    {MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
                                    {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
    } else {
      // Average Pooling
      pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
                                    {MKLDNN_ARG_DIFF_DST, *diff_dst_memory}});
    }
    astream.wait();

    in_x_grad->set_layout(DataLayout::kMKLDNN);
    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
  }  // Compute()
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::PoolMKLDNNOpKernel<float>,
                   ops::PoolMKLDNNOpKernel<int8_t>,
                   ops::PoolMKLDNNOpKernel<uint8_t>);

REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::PoolMKLDNNGradOpKernel<float>);