/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MkldnnLayer.h"

// using namespace mkldnn;  // NOLINT
using mem = mkldnn::memory;  // NOLINT
using format = mem::format;
using fc_fwd = mkldnn::inner_product_forward;
using fc_bwdWgt = mkldnn::inner_product_backward_weights;
using fc_bwdData = mkldnn::inner_product_backward_data;

namespace paddle {

/**
 * Initialize the layer. Aborts (via CHECK) unless FLAGS_use_mkldnn is set,
 * then defers to Layer::init.
 */
bool MkldnnLayer::init(const LayerMap& layerMap,
                       const ParameterMap& parameterMap) {
  // The two literals are streamed back to back, so the first one needs a
  // trailing space to keep the sentences apart in the error message.
  CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn. "
                          << "Please set WITH_MKLDNN=ON";
  // TODO(TJ): deviceId
  return Layer::init(layerMap, parameterMap);
}

/**
 * Build the MKL-DNN inner-product forward primitive and store it as the only
 * entry of pipelineFwd_.
 *
 * @param bs        batch size
 * @param ic,ih,iw  input channels / height / width
 * @param botData   input buffer
 * @param oc        output channels
 * @param topData   output buffer, described as (bs, oc) in nc format
 * @param wgtData   weight buffer, described as oihw (spatial) or oi
 * @param biasData  bias buffer, or nullptr for a bias-free inner product
 */
void MkldnnLayer::resetForwardFC(int bs,
                                 int ic,
                                 int ih,
                                 int iw,
                                 real* botData,
                                 int oc,
                                 real* topData,
                                 real* wgtData,
                                 real* biasData) {
  // A 1x1 input is described as a plain 2-D (bs, ic) tensor; any other size
  // keeps its spatial dims and uses 4-D nchw / oihw descriptors.
  const bool hasSpatial = !(ih == 1 && iw == 1);
  const bool hasBias = biasData != nullptr;
  engine_ = CpuEngine::Instance().getEngine();

  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
                               : createMD({bs, ic}, format::nc);
  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
                               : createMD({oc, ic}, format::oi);
  // biasMD is only consumed when hasBias; the empty descriptor just keeps the
  // variable initialized for the bias-free path.
  mem::desc biasMD = hasBias ? createMD({oc}, format::x)
                             : createMD({}, format::format_undef);
  mem::desc topMD = createMD({bs, oc}, format::nc);

  mkldnn::prop_kind pk = mkldnn::prop_kind::forward;
  fc_fwd::desc fwdDesc = hasBias
                             ? fc_fwd::desc(pk, botMD, wgtMD, biasMD, topMD)
                             : fc_fwd::desc(pk, botMD, wgtMD, topMD);
  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);

  // Wrap the caller-provided buffers as MKL-DNN memory objects.
  // NOTE(review): bot/wgt/top/bias are locals released when this function
  // returns, while fwd_ (and pipelineFwd_) still refer to them. Confirm the
  // MKL-DNN version in use keeps these memory primitives alive; otherwise
  // they should be stored as members next to fwd_.
  mem bot = mem(mem::primitive_desc(botMD, engine_), botData);
  mem wgt = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
  mem top = mem(mem::primitive_desc(topMD, engine_), topData);
  if (hasBias) {
    mem bias = mem(mem::primitive_desc(biasMD, engine_), biasData);
    fwd_.reset(new fc_fwd(fwdPD, bot, wgt, bias, top));
  } else {
    fwd_.reset(new fc_fwd(fwdPD, bot, wgt, top));
  }
  pipelineFwd_.clear();
  pipelineFwd_.push_back(*fwd_);
}

/**
 * Run the fully-connected forward pass: rebuild the forward primitive for the
 * given shapes/buffers, then submit pipelineFwd_ on stream_.
 */
void MkldnnLayer::mkldnnForwardFC(int bs,
                                  int ic,
                                  int ih,
                                  int iw,
                                  real* botData,
                                  int oc,
                                  real* topData,
                                  real* wgtData,
                                  real* biasData) {
  // TODO: only reset when the input size (or a buffer) changes; for now the
  // primitive is rebuilt on every call.
  resetForwardFC(bs, ic, ih, iw, botData, oc, topData, wgtData, biasData);
  // TODO: update botData without rebuilding, then just submit.
  stream_->submit(pipelineFwd_);
}

/**
 * Build the MKL-DNN inner-product backward primitives into pipelineBwd_:
 * first the weight (and optional bias) gradient, then — only when botDiff is
 * non-null — the input gradient.
 *
 * @param botDiff   input-gradient buffer, or nullptr to skip backward-data
 * @param botData   forward input values
 * @param topDiff   output gradient, described as (bs, oc) in nc format
 * @param wgtDiff   weight-gradient buffer
 * @param wgtData   weight values (used only for backward-data)
 * @param biasDiff  bias-gradient buffer, or nullptr when there is no bias
 */
void MkldnnLayer::resetBackwardFC(int bs,
                                  int ic,
                                  int ih,
                                  int iw,
                                  real* botDiff,
                                  real* botData,
                                  int oc,
                                  real* topDiff,
                                  real* wgtDiff,
                                  real* wgtData,
                                  real* biasDiff) {
  const bool hasSpatial = !(ih == 1 && iw == 1);
  const bool hasBias = biasDiff != nullptr;
  engine_ = CpuEngine::Instance().getEngine();

  // ---- backward weights ----
  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
                               : createMD({bs, ic}, format::nc);
  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
                               : createMD({oc, ic}, format::oi);
  mem::desc topMD = createMD({bs, oc}, format::nc);
  mem::desc biasMD = hasBias ? createMD({oc}, format::x)
                             : createMD({}, format::format_undef);

  // The backward primitive descriptors below are created with a forward
  // primitive descriptor as their hint.
  fc_fwd::desc fwdDesc =
      fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
  fc_bwdWgt::desc bwdWgtDesc = hasBias
                                   ? fc_bwdWgt::desc(botMD, wgtMD, biasMD, topMD)
                                   : fc_bwdWgt::desc(botMD, wgtMD, topMD);
  fc_bwdWgt::primitive_desc bwdWgtPD =
      fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);

  // NOTE(review): the memory objects below are locals released on return
  // while bwdWgt_/bwdData_ (and pipelineBwd_) still refer to them; confirm
  // this is safe for the MKL-DNN version in use.
  mem botVal = mem(mem::primitive_desc(botMD, engine_), botData);
  mem wgtGrad = mem(mem::primitive_desc(wgtMD, engine_), wgtDiff);
  mem topGrad = mem(mem::primitive_desc(topMD, engine_), topDiff);

  if (hasBias) {
    mem biasGrad = mem(mem::primitive_desc(biasMD, engine_), biasDiff);
    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad, biasGrad));
  } else {
    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad));
  }
  pipelineBwd_.clear();
  pipelineBwd_.push_back(*bwdWgt_);

  // ---- backward data ----
  // Nothing to do when the caller does not want the input gradient.
  if (botDiff == nullptr) {
    return;
  }
  fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
  fc_bwdData::primitive_desc bwdDataPD =
      fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
  mem botGrad = mem(mem::primitive_desc(botMD, engine_), botDiff);
  mem wgtVal = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
  bwdData_.reset(new fc_bwdData(bwdDataPD, topGrad, wgtVal, botGrad));
  pipelineBwd_.push_back(*bwdData_);
}

/**
 * Run the fully-connected backward pass: rebuild the backward primitives for
 * the given shapes/buffers, then submit pipelineBwd_ on stream_.
 */
void MkldnnLayer::mkldnnBackwardFC(int bs,
                                   int ic,
                                   int ih,
                                   int iw,
                                   real* botDiff,
                                   real* botData,
                                   int oc,
                                   real* topDiff,
                                   real* wgtDiff,
                                   real* wgtData,
                                   real* biasDiff) {
  // TODO: only reset when the input size (or a buffer) changes; for now the
  // primitives are rebuilt on every call.
  resetBackwardFC(bs, ic, ih, iw, botDiff, botData, oc, topDiff, wgtDiff,
                  wgtData, biasDiff);
  // TODO: update buffers without rebuilding, then just run backward.
  stream_->submit(pipelineBwd_);
}

/**
 * Create an MKL-DNN memory descriptor with the given dims, format and data
 * type.
 */
mem::desc MkldnnLayer::createMD(mem::dims dims,
                                mem::format fmt,
                                mem::data_type type) {
  // TODO(TJ): isFmtSupported(fmt)
  return mem::desc(dims, type, fmt);
}

}  // namespace paddle