conv_mkldnn_op.cc 24.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

Y
Yu Yang 已提交
15 16
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/memory/malloc.h"
17
#include "paddle/fluid/operators/conv_op.h"
J
Jacek Czaja 已提交
18
#include "paddle/fluid/platform/mkldnn_reuse.h"
19 20 21 22

namespace paddle {
namespace operators {

23 24 25 26 27 28 29 30
// Short names for the framework, MKL-DNN and platform symbols used by the
// kernels below. Note that inside this namespace `memory` refers to
// mkldnn::memory; the Paddle allocator namespace is always spelled out as
// paddle::memory.
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;

Y
Yihua Xu 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups,  // NOLINT
                         bool is_conv3d) {
  if (groups > 1) {
    if (is_conv3d) {
      int output = weights_tz[0];
      int input = weights_tz[1];
      int dimension = weights_tz[2];
      int height = weights_tz[3];
      int width = weights_tz[4];
      weights_tz.resize(6);
      weights_tz[0] = groups;
      weights_tz[1] = output / groups;
      weights_tz[2] = input;
      weights_tz[3] = dimension;
      weights_tz[4] = height;
      weights_tz[5] = width;
    } else {
      int output = weights_tz[0];
      int input = weights_tz[1];
      int height = weights_tz[2];
      int width = weights_tz[3];
      weights_tz.resize(5);
      weights_tz[0] = groups;
      weights_tz[1] = output / groups;
      weights_tz[2] = input;
      weights_tz[3] = height;
      weights_tz[4] = width;
    }
  }
}

// Picks the MKL-DNN memory format that describes the user's filter tensor.
//
// format    : format currently recorded on the filter tensor.
// groups    : number of convolution groups.
// is_conv3d : true for 3D convolution, false for 2D.
//
// Returns `format` unchanged when groups == 1; for any other group count
// returns goidhw (3D) or goihw (2D).
inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
                                               int groups, bool is_conv3d) {
  if (groups == 1) {
    return format;
  }
  return is_conv3d ? mkldnn::memory::format::goidhw
                   : mkldnn::memory::format::goihw;
}

71
template <typename T>
72
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
73 74 75 76 77
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

K
Krzysztof Binias 已提交
78 79
    const bool is_test = ctx.Attr<bool>("is_test");

80 81
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
82 83 84 85
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
86
    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
87 88
    auto* output = ctx.Output<Tensor>("Output");

89 90 91 92 93 94
    PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
                       input->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input tensor");
    PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                       filter->format() != memory::format::format_undef,
                   "Wrong layout/format set for Filter tensor");
95
    PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
Y
Yihua Xu 已提交
96
                   "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
97 98
    PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
                   "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
99 100 101 102 103 104 105
    if (bias) {
      PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
                         bias->format() != memory::format::format_undef,
                     "Wrong layout/format set for Bias tensor");
      PADDLE_ENFORCE(bias->dims().size() == 1,
                     "Bias must only have 1 dimension, i.e. X");
    }
106 107 108 109

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
M
Michal Gallus 已提交
110
    bool fuse_relu = ctx.Attr<bool>("fuse_relu");
111
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
112 113
    int groups = ctx.Attr<int>("groups");

114
    bool is_conv3d = strides.size() == 3U;
115
    // TODO(tpatejko): add support for dilation
116
    PADDLE_ENFORCE(
117 118 119 120
        is_conv3d
            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
                  dilations[2] == 1
            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
121 122 123 124 125 126 127 128
        "dilation in convolution is not implemented yet");

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();

    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
    std::vector<int> weights_tz =
        paddle::framework::vectorize2int(filter->dims());
129
    int g = std::max(groups, 1);
Y
Yihua Xu 已提交
130
    GetWeightsTz(weights_tz, g, is_conv3d);
131 132
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());

133
    // Get unique name for storing MKLDNN primitives
J
Jacek Czaja 已提交
134
    const std::string key = platform::ConvMKLDNNHandler::GetHash(
135 136 137 138 139 140
        src_tz, weights_tz, strides, paddings, dilations, groups,
        ctx.op().Output("Output"));
    const std::string key_conv_pd = key + "@conv_pd";

    std::vector<primitive> pipeline;

141 142
    auto src_format = input->format();
    mkldnn::memory::format weights_format =
Y
Yihua Xu 已提交
143
        GetWeightsFormat(filter->format(), g, is_conv3d);
144

145
    auto user_src_md = platform::MKLDNNMemDesc(
146
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
147
    auto user_weights_md = platform::MKLDNNMemDesc(
148
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
149 150 151 152 153

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
154 155 156 157
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

158 159 160 161 162 163 164
    weights_format = mkldnn::memory::format::any;
    // Check the format for user's special output
    if (chosen_memory_format != mkldnn::memory::format::any) {
      if (is_conv3d) {
        chosen_memory_format =
            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
      }
165 166
    }

167
    auto src_md = platform::MKLDNNMemDesc(
168
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
169
    auto weights_md = platform::MKLDNNMemDesc(
170
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
171 172
    std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
                               // Currently used whenever bias is != nullptr.
173
    auto dst_md = platform::MKLDNNMemDesc(
174
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
175 176

    // create a conv primitive descriptor and save it for usage in backward
177
    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
178 179
    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
                                 : mkldnn::prop_kind::forward_training;
180 181 182 183
    if (bias) {
      bias_tz = paddle::framework::vectorize2int(bias->dims());
      auto bias_md = platform::MKLDNNMemDesc(
          bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
184 185 186
      conv_pd = ConvFwdPrimitiveDesc(
          src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
          fuse_relu, fuse_residual_conn, fwd_prop_kind);
187
    } else {
188 189 190
      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
                                     paddings, mkldnn_engine, fuse_relu,
                                     fuse_residual_conn, fwd_prop_kind);
191
    }
192
    // Save conv_pd/src_memory/weights_memory for backward pass
193
    if (!is_test) dev_ctx.SetBlob(key_conv_pd, conv_pd);
194

J
Jacek Czaja 已提交
195
    platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
196

197 198 199 200 201 202
    // create mkldnn memory from input tensors (data/weights)
    auto user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<T>(filter_data));

203 204 205 206 207
    // create reorder primitive if the input format is not the preferred one
    auto src_memory_p =
        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
    auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
        user_weights_memory_p, pipeline, is_test);
208 209

    std::shared_ptr<mkldnn::memory> dst_memory_p;
210

211
    if (fuse_residual_conn) {
212 213
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto residual_param_data = residual_param->data<T>();
214

215 216 217 218 219 220
      PADDLE_ENFORCE(
          residual_param_data != nullptr,
          "Provide data if you want MKLDNN conv+elementwise_add fusion");
      PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                        "Output and elementwise parameter need to have the "
                        "same dimension sizes");
221

222
      if (residual_param->format() != handler.GetDstFormat()) {
Y
Yu Yang 已提交
223 224 225
        auto output_data = output->mutable_data<T>(
            ctx.GetPlace(), ::paddle::memory::Allocator::kDefault,
            handler.GetDstMemorySize());
226 227 228 229 230 231 232 233 234
        auto residual_data_tz =
            paddle::framework::vectorize2int(residual_param->dims());
        auto residual_data_type =
            paddle::framework::ToMKLDNNDataType(residual_param->type());

        auto user_residual_md = platform::MKLDNNMemDesc(
            residual_data_tz, residual_data_type, residual_param->format());
        auto user_residual_memory_p = handler.AcquireResidualDataMemory(
            user_residual_md, to_void_cast<T>(residual_param_data));
235 236 237

        dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
            user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
238 239
      } else {
        output->ShareDataWith(*residual_param);
240 241 242
        auto output_data = output->mutable_data<T>(ctx.GetPlace());
        dst_memory_p =
            handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
243
      }
244
    } else {
245 246 247
      auto output_data = output->mutable_data<T>(
          ctx.GetPlace(), paddle::memory::Allocator::kDefault,
          handler.GetDstMemorySize());
248 249
      dst_memory_p =
          handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
250
    }
251 252

    // create convolution op primitive
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
    std::shared_ptr<mkldnn::convolution_forward> conv_p;
    if (bias) {
      const T* bias_data = bias->data<T>();
      auto user_bias_md = platform::MKLDNNMemDesc(
          {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
      auto user_bias_memory_p =
          handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));

      auto bias_memory_p =
          handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
      conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                          bias_memory_p, dst_memory_p);
    } else {
      conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                          dst_memory_p);
    }
269 270

    // push primitive to stream and wait until it's executed
271
    pipeline.push_back(*conv_p);
272 273 274
    stream(stream::kind::eager).submit(pipeline).wait();

    output->set_layout(DataLayout::kMKLDNN);
275
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
276
  }
277

278
 private:
279
  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
280
                                       bool fuse_residual_conn) const {
M
Michal Gallus 已提交
281 282
    mkldnn::primitive_attr conv_attr;
    mkldnn::post_ops post_operations;
283
    // Fusion with Elementwise layer relies on adding a sum post-operation with
284 285 286 287 288
    // the scale parameter. It is assumed that when fuse_residual_connection is
    // true, the output tensor contains the data coming from residual
    // connection. The result of this post_op is:
    // Output = scale * Output + Conv_Out.
    if (fuse_residual_conn) {
289 290 291 292 293 294 295 296 297 298 299
      post_operations.append_sum(1.0f);
    }
    // Fusion with ReLU layer is executed through the PostOps feature. Create a
    // PostOps object and configure it to execute an eltwise relu operation.
    if (fuse_relu) {
      constexpr float scale = 1.0f;
      constexpr float negative_slope = 0.0f;
      constexpr float placeholder = 0.0f;
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                     negative_slope, placeholder);
    }
M
Michal Gallus 已提交
300 301 302 303
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
  }

304 305 306 307
  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                       const memory::desc& dst, const std::vector<int>& strides,
                       const std::vector<int>& paddings,
308
                       const mkldnn::engine& engine, const bool fuse_relu,
309 310
                       const bool fuse_residual_conn,
                       mkldnn::prop_kind fwd_prop_kind) const {
311 312
    memory::dims stride_dims = strides;
    memory::dims padding_dims = paddings;
313

314
    auto conv_desc = mkldnn::convolution_forward::desc(
315 316
        fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
317

318 319
    mkldnn::primitive_attr conv_attr =
        CreatePostOps(fuse_relu, fuse_residual_conn);
M
Michal Gallus 已提交
320 321 322

    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
        conv_desc, conv_attr, engine);
323

324 325
    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
        p_conv_pd);
326
  }
327 328 329 330 331 332

  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                       const memory::desc& bias, const memory::desc& dst,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
333
                       const mkldnn::engine& engine, const bool fuse_relu,
334 335
                       const bool fuse_residual_conn,
                       mkldnn::prop_kind fwd_prop_kind) const {
336 337
    memory::dims stride_dims = strides;
    memory::dims padding_dims = paddings;
338 339

    auto conv_desc = mkldnn::convolution_forward::desc(
340 341
        fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
342

343 344
    mkldnn::primitive_attr conv_attr =
        CreatePostOps(fuse_relu, fuse_residual_conn);
M
Michal Gallus 已提交
345 346 347

    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
        conv_desc, conv_attr, engine);
348 349 350 351

    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
        p_conv_pd);
  }
352 353 354
};

template <typename T>
355
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
356 357 358 359 360
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

361 362
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
363 364 365 366 367 368 369 370 371 372
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("Input");
    const Tensor* filter = ctx.Input<Tensor>("Filter");
    const Tensor* output = ctx.Input<Tensor>("Output");
    const Tensor* output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));

373 374 375 376 377 378 379 380 381 382 383 384 385
    PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
                       input->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input tensor");
    PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                       filter->format() != memory::format::format_undef,
                   "Wrong layout/format set for Filter tensor");
    PADDLE_ENFORCE(output->layout() == DataLayout::kMKLDNN &&
                       output->format() != memory::format::format_undef,
                   "Wrong layout/format set for Output tensor");
    PADDLE_ENFORCE(output_grad->layout() == DataLayout::kMKLDNN &&
                       output_grad->format() != memory::format::format_undef,
                   "Wrong layout/format set for output_grad tensor");

386 387 388 389
    PADDLE_ENFORCE(
        !ctx.Attr<bool>("is_test"),
        "is_test attribute should be set to False in training phase.");

390 391 392 393
    if (!input_grad && !filter_grad) return;

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
394 395
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
396

397
    bool is_conv3d = strides.size() == 3U;
398 399 400 401 402 403 404 405 406
    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data = nullptr;
    T* filter_grad_data = nullptr;

    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
    std::vector<int> weights_tz =
        paddle::framework::vectorize2int(filter->dims());
407
    int g = std::max(groups, 1);
Y
Yihua Xu 已提交
408
    GetWeightsTz(weights_tz, g, is_conv3d);
409 410
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());

411 412
    auto src_format = input->format();
    mkldnn::memory::format weights_format =
Y
Yihua Xu 已提交
413
        GetWeightsFormat(filter->format(), g, is_conv3d);
414

415
    // Get an unique name from "argument" name of "Output" variable
J
Jacek Czaja 已提交
416
    // as well as attributes of primitive to be created
417
    // This name will be used as key when saving info into device context
J
Jacek Czaja 已提交
418 419 420
    const std::string key = platform::ConvMKLDNNHandler::GetHash(
        src_tz, weights_tz, strides, paddings, dilations, groups,
        ctx.op().Input("Output"));
421 422

    const std::string key_conv_pd = key + "@conv_pd";
423
    std::vector<primitive> pipeline;
424

425 426
    // Create user memory descriptors
    auto user_src_md = platform::MKLDNNMemDesc(
427
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
428
    auto user_weights_md = platform::MKLDNNMemDesc(
429
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
430 431
    auto user_diff_dst_md = platform::MKLDNNMemDesc(
        {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
432 433 434 435 436

    /* create memory descriptor for conv backward without specified format
     * ('any') which lets a primitive (conv backward in this case) choose
     * the memory format preferred for best performance
     */
437 438 439 440
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

441 442 443 444 445 446 447
    weights_format = mkldnn::memory::format::any;
    // Check the format for user's special output
    if (chosen_memory_format != mkldnn::memory::format::any) {
      if (is_conv3d) {
        chosen_memory_format =
            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
      }
448 449
    }

450
    auto src_md = platform::MKLDNNMemDesc(
451
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
452
    auto diff_src_md = platform::MKLDNNMemDesc(
453
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
454
    auto weights_md = platform::MKLDNNMemDesc(
455
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
456
    auto diff_weights_md = platform::MKLDNNMemDesc(
457
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
458
    auto diff_dst_md = platform::MKLDNNMemDesc(
459
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
460

461
    // Retrieve conv_pd from device context
462 463 464
    auto conv_pd =
        std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
            dev_ctx.GetBlob(key_conv_pd));
465 466 467
    PADDLE_ENFORCE(conv_pd != nullptr,
                   "Fail to find conv_pd in device context");

468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
    // create backward convolution weights primitive descriptor
    auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
        mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md,
        strides, paddings, paddings, mkldnn::padding_kind::zero);
    auto conv_bwd_weights_pd =
        std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
            conv_bwd_weights_desc, mkldnn_engine, *conv_pd);

    // create backward convolution data primitive descriptor
    auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
        mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
        strides, paddings, paddings, mkldnn::padding_kind::zero);
    auto conv_bwd_data_pd =
        std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
            conv_bwd_data_desc, mkldnn_engine, *conv_pd);

J
Jacek Czaja 已提交
484 485 486
    platform::ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd,
                                        conv_bwd_weights_pd, dev_ctx,
                                        mkldnn_engine, key);
487 488 489 490 491 492 493 494 495

    // create mkldnn memory from input tensors (data/weights)
    auto user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<T>(filter_data));
    auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
        user_diff_dst_md, to_void_cast<T>(output_grad_data));

496 497
    // create backward conv primitive for weights
    if (filter_grad) {
498 499
      auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
          user_src_memory_p, pipeline);
500

501 502 503 504
      auto diff_dst_memory_4filter_p =
          handler.AcquireDiffDstMemoryFromWeightsPrimitive(
              user_diff_dst_memory_p, pipeline);

505
      const size_t size = handler.GetDiffWeightsMemorySize();
Y
Yu Yang 已提交
506 507
      filter_grad_data = filter_grad->mutable_data<T>(
          ctx.GetPlace(), paddle::memory::Allocator::kDefault, size);
508

509 510 511 512 513 514 515 516 517
      auto diff_weights_memory_p =
          handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
              reinterpret_cast<void*>(filter_grad_data));

      auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights(
          src_memory_p, diff_dst_memory_4filter_p, diff_weights_memory_p);

      // push primitive to stream and wait until it's executed
      pipeline.push_back(*conv_bwd_weights_p);
518 519

      filter_grad->set_layout(DataLayout::kMKLDNN);
520
      filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
521 522 523
    }

    if (input_grad) {
524 525 526 527 528 529 530
      auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
          user_weights_memory_p, pipeline);

      auto diff_dst_memory_4data_p =
          handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
                                                        pipeline);

531
      const size_t size = handler.GetDiffSourceMemorySize();
Y
Yu Yang 已提交
532 533
      input_grad_data = input_grad->mutable_data<T>(
          ctx.GetPlace(), paddle::memory::Allocator::kDefault, size);
534

535 536 537 538 539 540 541
      auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
          reinterpret_cast<void*>(input_grad_data));

      auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData(
          diff_dst_memory_4data_p, weights_memory_p, diff_src_memory_p);

      pipeline.push_back(*conv_bwd_data_p);
542 543

      input_grad->set_layout(DataLayout::kMKLDNN);
544
      input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
545
    }
546
    stream(stream::kind::eager).submit(pipeline).wait();
547 548 549 550 551 552 553 554
  }  // Compute()
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

X
Xin Pan 已提交
555 556 557 558 559 560 561 562 563
// Register the float MKLDNN kernel for the conv2d operator on CPUPlace, under
// the FP32 custom type tag (ops::kConvMKLDNNFP32).
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNOpKernel<float>);

// Register the matching float gradient kernel for conv2d_grad.
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);
564 565 566 567 568 569 570 571 572 573

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);