MkldnnLayer.cpp 3.7 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MkldnnLayer.h"

// Short aliases for the mkldnn types used in this file. They are used instead
// of a blanket `using namespace mkldnn;` directive.
using mem = mkldnn::memory;  // NOLINT
using format = mem::format;
using fc_fwd = mkldnn::inner_product_forward;
using fc_bwdWgt = mkldnn::inner_product_backward_weights;
using fc_bwdData = mkldnn::inner_product_backward_data;

namespace paddle {

bool MkldnnLayer::init(const LayerMap& layerMap,
                       const ParameterMap& parameterMap) {
  CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
                          << "Please set WITH_MKLDNN=ON";
  // TODO(TJ): deivecId
  return Layer::init(layerMap, parameterMap);
}

void MkldnnLayer::resetForwardFC(int bs,
                                 int ic,
                                 int ih,
                                 int iw,
                                 real* botData,
                                 int oc,
                                 real* topData,
                                 real* wgtData,
                                 real* biasData) {
  bool hasSpatial = ih == 1 && iw == 1 ? false : true;
  engine_ = CpuEngine::Instance().getEngine();

  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
                               : createMD({bs, ic}, format::nc);
  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
                               : createMD({oc, ic}, format::oi);
  mem::desc biasMD = biasData != NULL ? createMD({oc}, format::x)
                                      : createMD({}, format::format_undef);
  mem::desc topMD = createMD({bs, oc}, format::nc);

  mkldnn::prop_kind pk = mkldnn::prop_kind::forward;
  fc_fwd::desc fwdDesc = biasData != NULL
                             ? fc_fwd::desc(pk, botMD, wgtMD, biasMD, topMD)
                             : fc_fwd::desc(pk, botMD, wgtMD, topMD);
  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);

  mem bot = mem(mem::primitive_desc(botMD, engine_), botData);
  mem wgt = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
  mem top = mem(mem::primitive_desc(topMD, engine_), topData);

  if (biasData != NULL) {
    mem bias = mem(mem::primitive_desc(biasMD, engine_), biasData);
    fwd_.reset(new fc_fwd(fwdPD, bot, wgt, bias, top));
  } else {
    fwd_.reset(new fc_fwd(fwdPD, bot, wgt, top));
  }
  pipelineFwd_.clear();
  pipelineFwd_.push_back(*fwd_);
}

void MkldnnLayer::mkldnnForwardFC(int bs,
                                  int ic,
                                  int ih,
                                  int iw,
                                  real* botData,
                                  int oc,
                                  real* topData,
                                  real* wgtData,
                                  real* biasData) {
  // if input size changed, reset it
  resetForwardFC(bs, ic, ih, iw, botData, oc, topData, wgtData, biasData);

  // just forward
  // update botdata
  stream_->submit(pipelineFwd_);
}

mem::desc MkldnnLayer::createMD(mem::dims dims,
                                mem::format fmt,
                                mem::data_type type) {
  // TODO(TJ): isFmtSuppoted(fmt)
  return mem::desc(dims, type, fmt);
}

}  // namespace paddle