/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;

template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input tensor");

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");

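    // The attributes are stored as std::vector<int>, while MKL-DNN (oneDNN)
    // works with 64-bit dimensions, so they are widened to int64_t here.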
    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    // Only 2D pooling is supported now
    PADDLE_ENFORCE_EQ(ksize.size(), 2,
                      "ksize must contain 2 elements, i.e. 2D pooling");
    PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true,
                      "pooling_type must be 'max' or 'avg'");
    PADDLE_ENFORCE_EQ(input->dims().size(), 4,
                      "Input must be 4-dimensional, i.e. NCHW");

    auto input_dims = input->dims();
    framework::DDim data_dims =
        framework::slice_ddim(input_dims, 2, input_dims.size());

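    // With global_pooling the kernel covers the whole spatial extent, so ksize
    // is replaced by the input's spatial dims; paddings are then adjusted
    // according to padding_algorithm ("SAME", "VALID" or explicit values).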
    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
                  strides, ksize);

    auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
    auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());

    auto is_test = ctx.Attr<bool>("is_test");

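    // The handler builds the pooling forward descriptor from the shapes and
    // attributes above and caches the resulting primitive in the MKLDNN
    // device context (keyed, among other things, by the "Out" variable name),
    // so subsequent executions of this op can reuse it.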
    platform::PoolingMKLDNNHandler<T> handler(
        src_tz, dst_tz, ksize, strides, paddings, pooling_type,
        ctx.Attr<bool>("ceil_mode"), input->format(),
        paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
        ctx.GetPlace(), ctx.OutputName("Out"), ctx.Attr<bool>("exclusive"));

    auto src_memory = handler.AcquireSrcMemory(input);
    auto dst_memory = handler.AcquireDstMemory(output);

    auto pool_p = handler.AcquireForwardPrimitive();

    mkldnn::stream astream(dev_ctx.GetEngine());
    if ((is_test == false) && (pooling_type == "max")) {
      // Training
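      // In forward_training mode max pooling also produces a workspace that
      // records the position of each maximum; the backward kernel needs it to
      // route gradients.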
      auto workspace_memory = handler.AcquireWorkspaceMemory();
      pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DST, *dst_memory},
                                {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
    } else {
      // Inference
      pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DST, *dst_memory}});
    }
    astream.wait();

    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory));
  }
};

template <typename T>
class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    const Tensor* in_x = ctx.Input<Tensor>("X");
    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    PADDLE_ENFORCE_EQ(in_x->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(in_x->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for output_grad tensor");
    PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for output_grad tensor");

    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
        "is_test attribute should be set to False in training phase.");

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    auto in_x_dims = in_x->dims();
    framework::DDim data_dims =
        framework::slice_ddim(in_x_dims, 2, in_x_dims.size());

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
                  strides, ksize);

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    std::vector<mkldnn::primitive> pipeline;

    auto diff_src_tz = paddle::framework::vectorize<int64_t>(in_x_grad->dims());
    auto diff_dst_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());

    // Get a unique name from the argument name of the "Out" variable.
    // This name is used as the key when retrieving info from the device context.
    const std::string key = platform::CreateKey(
        diff_src_tz, pooling_type, ksize, strides, paddings,
        memory::data_type::f32, in_x->format(), ctx.InputName("Out"));

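    // oneDNN's pooling_backward requires the matching forward primitive
    // descriptor as a hint; the handler reconstructs it internally from the
    // same shapes and attributes.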
    platform::PoolingMKLDNNHandler<T> handler(
        diff_dst_tz, diff_src_tz, ksize, strides, paddings, pooling_type,
        ctx.Attr<bool>("ceil_mode"), in_x->format(), out_grad->format(),
        paddle::framework::ToMKLDNNDataType(out_grad->type()), dev_ctx,
        ctx.GetPlace(), ctx.InputName("Out"), ctx.Attr<bool>("exclusive"));

    auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
    auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);

    auto pool_bwd_p = handler.AcquireBackwardPrimitive();

    mkldnn::stream astream(dev_ctx.GetEngine());
    if (pooling_type == "max") {
      // Max pooling needs the workspace (max indices) saved by the forward pass
      auto workspace_memory = handler.AcquireWorkspaceMemory();
      pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
                                    {MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
                                    {MKLDNN_ARG_WORKSPACE, *workspace_memory}});
    } else {
      // Average Pooling
      pool_bwd_p->execute(astream, {{MKLDNN_ARG_DIFF_SRC, *diff_src_memory},
                                    {MKLDNN_ARG_DIFF_DST, *diff_dst_memory}});
    }
    astream.wait();

    in_x_grad->set_layout(DataLayout::kMKLDNN);
    in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
  }  // Compute()
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

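// Forward pooling is registered for FP32 as well as INT8/UINT8 (used by
// quantized models); only the FP32 kernel has a gradient counterpart.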
REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::PoolMKLDNNOpKernel<float>,
                   ops::PoolMKLDNNOpKernel<int8_t>,
                   ops::PoolMKLDNNOpKernel<uint8_t>);

REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::PoolMKLDNNGradOpKernel<float>);