diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp deleted file mode 100644 index 2b7bef0a757d7c706be3815c539b036b094596cf..0000000000000000000000000000000000000000 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "ExpandConvBaseLayer.h" - -#include "paddle/utils/Logging.h" -namespace paddle { - -bool ExpandConvBaseLayer::init(const LayerMap &layerMap, - const ParameterMap ¶meterMap) { - /* Initialize the basic convolutional parent class */ - ConvBaseLayer::init(layerMap, parameterMap); - - int index = 0; - for (auto &inputConfig : config_.inputs()) { - const ConvConfig &conf = inputConfig.conv_conf(); - /* Consistent caffe mode for multiple input */ - caffeMode_ = conf.caffe_mode(); - - // create a new weight - size_t height, width; - height = filterPixels_[index] * filterChannels_[index]; - width = (!isDeconv_) ? numFilters_ : channels_[index]; - CHECK_EQ(parameters_[index]->getSize(), width * height); - Weight *w = new Weight(height, width, parameters_[index]); - weights_.emplace_back(w); - index++; - } - if (biasParameter_.get()) { - if (sharedBiases_) { - CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); - biases_ = - std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); - } else { - biases_ = - std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); - } - } - getOutputSize(); - - return true; -} - -size_t ExpandConvBaseLayer::getOutputSize() { - CHECK_NE(inputLayers_.size(), 0UL); - size_t layerSize = ConvBaseLayer::calOutputSize(); - return layerSize; -} - -void ExpandConvBaseLayer::addSharedBias() { - size_t mapW = getOutputSize() / numFilters_; - size_t mapH = getOutputValue()->getElementCnt() / mapW; - MatrixPtr out = - Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_); - - Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_); - - out->transpose(transOutValue_, false); // false means no memory allocation - transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_, - numFilters_); - - MatrixPtr bias = Matrix::create(biases_->getW()->getData(), - 1, - biases_->getW()->getElementCnt(), - false, - useGpu_); - transOutValue_->addBias(*bias, 1.0f); - - transOutValue_->reshape(mapW, mapH); - transOutValue_->transpose(out, false); // false means no memory allocation - - out->clear(); - bias->clear(); -} - -void ExpandConvBaseLayer::addUnsharedBias() { - MatrixPtr outValue = getOutputValue(); - MatrixPtr bias = Matrix::create(biases_->getW()->getData(), - 1, - biases_->getW()->getElementCnt(), - false, - useGpu_); - outValue->addBias(*bias, 1.0f); -} - -void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) { - size_t mapW = getOutputSize() / numFilters_; - size_t mapH = v->getElementCnt() / mapW; - MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_); - - Matrix::resizeOrCreate(transOutValue_, 
mapW, mapH, false, useGpu_); - - vTmp->transpose(transOutValue_, false); // false means no memory allocation - transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_, - numFilters_); - biases->collectBias(*transOutValue_, 1.0f); -} - -void ExpandConvBaseLayer::bpropBiases(MatrixPtr v) { - MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(), - 1, - biases_->getWGrad()->getElementCnt(), - false, - useGpu_); - if (sharedBiases_) { - bpropSharedBias(biases, v); - } else { - biases->collectBias(*v, 1.0f); - } - biases->clear(); -} - -} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.h b/paddle/gserver/layers/ExpandConvBaseLayer.h deleted file mode 100644 index 01c699d2344443a1887ec0b5005125f617cbe279..0000000000000000000000000000000000000000 --- a/paddle/gserver/layers/ExpandConvBaseLayer.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "ConvBaseLayer.h" -#include "paddle/math/Matrix.h" - -namespace paddle { - -/** - * @brief A subclass of ConvBaseLayer that is a superclass of both - * ExpandConvLayer and ExpandConvTransLayer - */ -class ExpandConvBaseLayer : public ConvBaseLayer { -protected: - /// The transpose of output, which is an auxiliary matrix. - MatrixPtr transOutValue_; - -public: - explicit ExpandConvBaseLayer(const LayerConfig& config) - : ConvBaseLayer(config) {} - - ~ExpandConvBaseLayer() {} - - bool init(const LayerMap& layerMap, - const ParameterMap& parameterMap) override; - - size_t getOutputSize(); - - /** - * Add shared bias. - */ - void addSharedBias(); - - /** - * Add unshared bias. - */ - void addUnsharedBias(); - - void bpropSharedBias(MatrixPtr biases, MatrixPtr v); - void bpropBiases(MatrixPtr v); -}; - -} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index 20de475fc3f6b6f3c05ac26bea8363daff0cf110..48dfcb49a4c2c46891bb5236fc1f8e644c03f327 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -36,7 +36,36 @@ inline bool isDepthwiseConv(int channels, int groups) { bool ExpandConvLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { /* Initialize the basic convolutional parent class */ - ExpandConvBaseLayer::init(layerMap, parameterMap); + ConvBaseLayer::init(layerMap, parameterMap); + + int index = 0; + for (auto &inputConfig : config_.inputs()) { + const ConvConfig &conf = inputConfig.conv_conf(); + /* Consistent caffe mode for multiple input */ + caffeMode_ = conf.caffe_mode(); + + // create a new weight + size_t height, width; + height = filterPixels_[index] * filterChannels_[index]; + width = (!isDeconv_) ? 
numFilters_ : channels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + index++; + } + + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = std::unique_ptr( + new Weight(1, numFilters_, biasParameter_, 0)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_, 0)); + } + } + + getOutputSize(); size_t numInputs = config_.inputs_size(); inputShape_.resize(numInputs); @@ -108,6 +137,12 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, return true; } +size_t ExpandConvLayer::getOutputSize() { + CHECK_NE(inputLayers_.size(), 0UL); + size_t layerSize = ConvBaseLayer::calOutputSize(); + return layerSize; +} + // i is the index of input layers #define BACKWARD_INPUT(i, inputs, outputs) \ backward_[2 * i]->calc(inputs, outputs) @@ -155,11 +190,7 @@ void ExpandConvLayer::forward(PassType passType) { /* add the bias-vector */ if (biases_.get()) { - if (sharedBiases_) { - addSharedBias(); - } else { - addUnsharedBias(); - } + output_.value->addBias(*biases_->getW(), 1.0, sharedBiases_); } /* activation */ @@ -171,7 +202,7 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) { MatrixPtr outGrad = getOutputGrad(); if (biases_ && biases_->getWGrad()) { - bpropBiases(outGrad); + biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBiases_); /* Increasing the number of gradient */ biases_->getParameterPtr()->incUpdate(callback); } diff --git a/paddle/gserver/layers/ExpandConvLayer.h b/paddle/gserver/layers/ExpandConvLayer.h index a1f943d1521547af0f82cec7da8a4efe9037cd71..a0873de19253f2496bc0c2fba550b3199dfc7486 100644 --- a/paddle/gserver/layers/ExpandConvLayer.h +++ b/paddle/gserver/layers/ExpandConvLayer.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "ExpandConvBaseLayer.h" +#include "ConvBaseLayer.h" #include "paddle/math/Matrix.h" namespace paddle { @@ -28,10 +28,9 @@ namespace paddle { * The config file api is img_conv_layer. */ -class ExpandConvLayer : public ExpandConvBaseLayer { +class ExpandConvLayer : public ConvBaseLayer { public: - explicit ExpandConvLayer(const LayerConfig& config) - : ExpandConvBaseLayer(config) {} + explicit ExpandConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {} ~ExpandConvLayer() {} @@ -41,6 +40,8 @@ public: void forward(PassType passType) override; void backward(const UpdateCallback& callback) override; + size_t getOutputSize(); + protected: std::vector inputShape_; std::vector filterShape_; diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index f8c06c5f868f8d48a9a222b92315ee0ef2cf265e..9088744beebd25ac105737fe3b012de143c66a7c 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -285,10 +285,9 @@ void MKLDNNConvLayer::resetWgtBiasValue( wgt = MKLDNNMatrix::create(weight_->getW(), pd->weights_primitive_desc()); VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); - bias = nullptr; - if (biases_ && biases_->getW()) { - bias = MKLDNNMatrix::create(biases_->getW(), pd->bias_primitive_desc()); - } + bias = (biases_ && biases_->getW()) + ? 
MKLDNNMatrix::create(biases_->getW(), pd->bias_primitive_desc()) + : nullptr; } void MKLDNNConvLayer::resetOutValue( @@ -356,6 +355,7 @@ void MKLDNNConvLayer::resetBwdWgtPD( void MKLDNNConvLayer::resetBwdDataPD( std::shared_ptr& pd) { + pd = nullptr; if (inputLayers_[0]->getOutput().grad == nullptr) { return; } @@ -476,6 +476,7 @@ void MKLDNNConvLayer::resetWgtBiasGrad( << "primitive desc of weight grad and value should be equal"; VLOG(MKLDNN_FMTS) << "weight grad format: " << wgt->getFormat(); + bias = nullptr; if (biasVal_ == nullptr) { return; } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index f70343251ad4fbb99f9614618f6d1bff1174f15e..f60e221a6ec2ff513789a24e9f59bb25aef437b5 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -17,9 +17,6 @@ limitations under the License. */ using namespace mkldnn; // NOLINT typedef memory::format format; -typedef inner_product_forward fc_fwd; -typedef inner_product_backward_weights fc_bwdWgt; -typedef inner_product_backward_data fc_bwdData; namespace paddle { @@ -93,35 +90,88 @@ void MKLDNNFcLayer::reshape( printSizeInfo(); } -void MKLDNNFcLayer::resetFwd(std::vector& pipeline, +void MKLDNNFcLayer::resetFwd(std::vector& pipeline, MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { - pipeline.clear(); - bool hasBias = biases_ && biases_->getW(); - const MatrixPtr& wgtVal = weight_->getW(); - const MatrixPtr& biasVal = hasBias ? biases_->getW() : nullptr; - const MatrixPtr& outVal = output_.value; + resetFwdBuffers(in, wgt, bias, out); + + resetFwdPD(fwdPD_, in, wgt, bias, out); + + resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); + + printValueFormatFlow(); +} + +void MKLDNNFcLayer::resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + std::shared_ptr bwdWgtPD; + std::shared_ptr bwdDataPD; + + resetBwdBuffers(in, wgt, bias, out); + + resetBwdWgtPD(bwdWgtPD, wgt, bias, out); + + resetBwdDataPD(bwdDataPD, in, out); + + resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); + + printGradFormatFlow(); +} + +void MKLDNNFcLayer::updateInputData() { + inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); +} +void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { + weight_->getParameterPtr()->incUpdate(callback); + if (biases_ && biases_->getWGrad()) { + biases_->getParameterPtr()->incUpdate(callback); + } +} + +void MKLDNNFcLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + resetInValue(in); + + resetWgtBiasValue(wgt, bias); + + resetOutValue(out); +} + +void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) { if (inputIsOnlyMKLDNN()) { - const MatrixPtr& inVal = getInputValue(0); - in = std::dynamic_pointer_cast(inVal); + const MatrixPtr& dnnIn = getInputValue(0); + in = std::dynamic_pointer_cast(dnnIn); CHECK(in) << "Input should be MKLDNNMatrix"; } else { CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; - const MatrixPtr& inVal = getInputValue(0, CPU_DEVICE); + const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); in = MKLDNNMatrix::create( - inVal, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_); + cpuIn, {bs_, ic_, ih_, iw_}, format::nchw, engine_); } in->downSpatial(); +} + +void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias) { wgt = MKLDNNMatrix::create( - wgtVal, 
memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_); + weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_); wgt->downSpatial(); - bias = hasBias ? MKLDNNMatrix::create(biasVal, {oc_}, format::x, engine_) - : nullptr; - out = MKLDNNMatrix::create(outVal, {bs_, oc_}, format::nc, engine_); + bias = (biases_ && biases_->getW()) + ? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_) + : nullptr; +} + +void MKLDNNFcLayer::resetOutValue(MKLDNNMatrixPtr& out) { + out = MKLDNNMatrix::create(output_.value, {bs_, oc_}, format::nc, engine_); // change original output value to mkldnn output value output_.value = std::dynamic_pointer_cast(out); if (!outputIsOnlyMKLDNN()) { @@ -129,46 +179,59 @@ void MKLDNNFcLayer::resetFwd(std::vector& pipeline, // just share point getOutput(CPU_DEVICE).value->setData(output_.value->getData()); } +} - // create forward handle +void MKLDNNFcLayer::resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr wgt, + MKLDNNMatrixPtr bias, + MKLDNNMatrixPtr out) { + CHECK(in); + CHECK(wgt); + CHECK(out); prop_kind pk = prop_kind::forward; - fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk, - in->getMemoryDesc(), - wgt->getMemoryDesc(), - bias->getMemoryDesc(), - out->getMemoryDesc()) - : fc_fwd::desc(pk, - in->getMemoryDesc(), - wgt->getMemoryDesc(), - out->getMemoryDesc()); - fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); - if (hasBias) { - fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *bias, *out)); + fc_fwd::desc fwdDesc = bias != nullptr ? fc_fwd::desc(pk, + in->getMemoryDesc(), + wgt->getMemoryDesc(), + bias->getMemoryDesc(), + out->getMemoryDesc()) + : fc_fwd::desc(pk, + in->getMemoryDesc(), + wgt->getMemoryDesc(), + out->getMemoryDesc()); + pd.reset(new fc_fwd::primitive_desc(fwdDesc, engine_)); +} + +void MKLDNNFcLayer::resetFwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + + if (bias) { + fwd_.reset(new fc_fwd(*pd, *in, *wgt, *bias, *out)); } else { - fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *out)); + fwd_.reset(new fc_fwd(*pd, *in, *wgt, *out)); } - printValueFormatFlow(); pipeline.push_back(*fwd_); } -void MKLDNNFcLayer::resetBwd(std::vector& pipeline, - MKLDNNMatrixPtr& in, - MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias, - MKLDNNMatrixPtr& out) { - pipeline.clear(); - if (!needResetBwd_) { - return; - } - needResetBwd_ = false; - bool hasBias = biases_ && biases_->getWGrad(); +void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + resetOutGrad(out); + + resetWgtBiasGrad(wgt, bias); - /// backward weight - CHECK(inVal_) << "Should have input value"; - const MatrixPtr& wgtGrad = weight_->getWGrad(); - const MatrixPtr& biasGrad = hasBias ? biases_->getWGrad() : nullptr; + resetInGrad(in); +} +void MKLDNNFcLayer::resetOutGrad(MKLDNNMatrixPtr& out) { // TODO(TJ): merge outgrad int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; // for MKLDNN device: @@ -178,66 +241,88 @@ void MKLDNNFcLayer::resetBwd(std::vector& pipeline, // for CPU device: // fc do not need to convert from cpu device since output is always nc format // only need create from cpu device - const MatrixPtr& outGrad = getOutput(device).grad; - out = MKLDNNMatrix::create(outGrad, outVal_->getPrimitiveDesc()); - wgt = MKLDNNMatrix::create(wgtGrad, wgtVal_->getPrimitiveDesc()); - bias = hasBias ? 
MKLDNNMatrix::create(biasGrad, biasVal_->getPrimitiveDesc()) - : nullptr; - - // create memory primitive desc - fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, - inVal_->getMemoryDesc(), - wgt->getMemoryDesc(), - out->getMemoryDesc()); - fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); - fc_bwdWgt::desc bwdWgtDesc = hasBias - ? fc_bwdWgt::desc(inVal_->getMemoryDesc(), - wgt->getMemoryDesc(), - bias->getMemoryDesc(), - out->getMemoryDesc()) - : fc_bwdWgt::desc(inVal_->getMemoryDesc(), - wgt->getMemoryDesc(), - out->getMemoryDesc()); - fc_bwdWgt::primitive_desc bwdWgtPD = - fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD); - - if (hasBias) { - bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt, *bias)); - } else { - bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt)); + CHECK(outVal_); + out = + MKLDNNMatrix::create(getOutput(device).grad, outVal_->getPrimitiveDesc()); +} + +void MKLDNNFcLayer::resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias) { + CHECK(wgtVal_); + wgt = MKLDNNMatrix::create(weight_->getWGrad(), wgtVal_->getPrimitiveDesc()); + + bias = nullptr; + if (biasVal_ == nullptr) { + return; } - pipeline.push_back(*bwdWgt_); + bias = + MKLDNNMatrix::create(biases_->getWGrad(), biasVal_->getPrimitiveDesc()); +} - /// backward data +void MKLDNNFcLayer::resetInGrad(MKLDNNMatrixPtr& in) { + in = nullptr; const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad; if (inGrad == nullptr) { return; } - if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) { - // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done - } else { - in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc()); - } - - fc_bwdData::desc bwdDataDesc = fc_bwdData::desc( - inVal_->getMemoryDesc(), wgt->getMemoryDesc(), out->getMemoryDesc()); - fc_bwdData::primitive_desc bwdDataPD = - fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); + // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done + CHECK(inVal_); + in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc()); +} - CHECK(wgtVal_) << "Should have weight memory"; - bwdData_.reset(new fc_bwdData(bwdDataPD, *out, *wgtVal_, *in)); - printGradFormatFlow(); - pipeline.push_back(*bwdData_); +void MKLDNNFcLayer::resetBwdWgtPD( + std::shared_ptr& pd, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + CHECK(inVal_); + fc_bwdWgt::desc bwdWgtDesc = bias ? 
fc_bwdWgt::desc(inVal_->getMemoryDesc(), + wgt->getMemoryDesc(), + bias->getMemoryDesc(), + out->getMemoryDesc()) + : fc_bwdWgt::desc(inVal_->getMemoryDesc(), + wgt->getMemoryDesc(), + out->getMemoryDesc()); + pd.reset(new fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); } -void MKLDNNFcLayer::updateInputData() { - inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); +void MKLDNNFcLayer::resetBwdDataPD( + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + pd = nullptr; + if (in == nullptr) { + return; + } + CHECK(wgtVal_); + fc_bwdData::desc bwdDataDesc = fc_bwdData::desc( + in->getMemoryDesc(), wgtVal_->getMemoryDesc(), out->getMemoryDesc()); + pd.reset(new fc_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_)); } -void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { - weight_->getParameterPtr()->incUpdate(callback); - if (biases_ && biases_->getWGrad()) { - biases_->getParameterPtr()->incUpdate(callback); +void MKLDNNFcLayer::resetBwdPipeline( + std::vector& pipeline, + std::shared_ptr& bwdWgtPD, + std::shared_ptr& bwdDataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + CHECK(inVal_); + if (bias) { + bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt, *bias)); + } else { + bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt)); + } + pipeline.push_back(*bwdWgt_); + + if (bwdDataPD == nullptr) { + return; } + CHECK(wgtVal_) << "Should have weight memory"; + bwdData_.reset(new fc_bwdData(*bwdDataPD, *out, *wgtVal_, *in)); + pipeline.push_back(*bwdData_); } + } // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index 3119f863496df092da13c08bf733f13c42e53780..c76878aafab7e986d2bf478eaba02f2f0aced293 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -18,6 +18,9 @@ limitations under the License. */ #include "mkldnn.hpp" namespace paddle { +typedef mkldnn::inner_product_forward fc_fwd; +typedef mkldnn::inner_product_backward_weights fc_bwdWgt; +typedef mkldnn::inner_product_backward_data fc_bwdData; /** * @brief A subclass of MKLDNNLayer fc layer. @@ -32,6 +35,9 @@ protected: // if has already init the weight bool hasInitedWgt_; + // save forward primitive_desc, which can be used backward + std::shared_ptr fwdPD_; + // fc weight and bias std::unique_ptr weight_; std::unique_ptr biases_; @@ -67,6 +73,59 @@ public: void convertWeightsFromPaddle() override; void convertWeightsToPaddle() override; + +protected: + /** + * Forward functions: reset buffers(input, output, weight and bias), + * reset primitive descriptor, + * reset pipeline. + */ + void resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + void resetInValue(MKLDNNMatrixPtr& in); + void resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias); + void resetOutValue(MKLDNNMatrixPtr& out); + void resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr wgt, + MKLDNNMatrixPtr bias, + MKLDNNMatrixPtr out); + void resetFwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + + /** + * Backward functions: reset buffers(input, output, weight and bias), + * reset primitive descriptor for backward weight, + * reset primitive descriptor for backward data, + * reset pipeline. 
+ */ + void resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + void resetOutGrad(MKLDNNMatrixPtr& out); + void resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias); + void resetInGrad(MKLDNNMatrixPtr& in); + void resetBwdWgtPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + void resetBwdDataPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out); + void resetBwdPipeline(std::vector& pipeline, + std::shared_ptr& bwdWgtPD, + std::shared_ptr& bwdDataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); }; } // namespace paddle diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index 0aa130b4a0d458ad78d5d1330164af9e73b22a44..c843115eb9a5be50d6ff873f1510844228c9d89f 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -66,11 +66,12 @@ public: /** * Create reorder primitive. * Create a mkldnn::reorder handle for converting src MKLDNNMatrix to dst. - * checkData: for whether to check the data handle of src and dst is the same. - * if true, means check it and do not want support inplace reorder; - * otherwise do not check data which means the created reorder - * maybe inplace buffer and do not guarantee the logical is correct - * since not all format or conversion support inplace. + * checkData: whether to check the data handle of src and dst. + * if true, it will check the data and do not allow them equal; + * otherwise, it will not check them, then the reorder created + * may have inplace buffer. + * Do not set false, if you can not guarantee the inplace logical + * would work with your reorder. */ static std::shared_ptr createReorder( const MKLDNNMatrixPtr& src, diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 4a6c6381b0341dd3531aa4c09024530ee67bb4f9..0c813748b2989a8f0c00a359345747242dd21dd8 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -23,10 +23,15 @@ class AccuracyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), - "Input of Inference must be initialized."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("Inference"), + "Input(Inference) of AccuracyOp should not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input of Inference must be initialized."); + "Input(Label) of AccuracyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Accuracy"), + "Output(Accuracy) of AccuracyOp should not be null."); + auto *inference = ctx.Input("Inference"); auto *label = ctx.Input("Label"); diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index b43c09d4f09c7f87cc60290bdd2a99cbe46f0d5c..e83c1efeaf897889d18a37a6bd2ca2f8f012db25 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -23,6 +23,13 @@ class AddOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of AddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of AddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of AddOp should not be null."); + PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), ctx.Input("Y")->dims(), "Two input of Add Op's dimension must be same."); diff 
--git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc index 72fd179354a4be76a37e4571da168d844f7ce384..223bb0ffe6e75ce71919eb5f4cca06bedbb00764 100644 --- a/paddle/operators/concat_op.cc +++ b/paddle/operators/concat_op.cc @@ -25,6 +25,9 @@ class ConcatOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ConcatOp should not be null."); + auto ins = ctx.MultiInput("X"); auto *out = ctx.Output("Out"); size_t axis = static_cast(ctx.Attr("axis")); diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index b2e1ca395dcbb791a6159bc9ab5cb37af2f12e8b..8262a7a5c8c13c86c5f6c123a14fa89696358c57 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -33,7 +33,8 @@ using DDim = framework::DDim; void CondOp::CreateScope(const Scope& scope) const { auto sub_scopes_var = scope.FindVar("SubScopes"); - PADDLE_ENFORCE(sub_scopes_var != nullptr, ""); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var, + "Output(SubScopes) of CondOp should not be null."); auto sub_scopes = sub_scopes_var->GetMutable>(); auto& sub_scope = scope.NewScope(); sub_scopes->push_back(&sub_scope); @@ -41,7 +42,8 @@ void CondOp::CreateScope(const Scope& scope) const { void CondOp::CreateIndexTensor(const Scope& scope) const { auto index_tensors_var = scope.FindVar("IndexTensors"); - PADDLE_ENFORCE(index_tensors_var != nullptr, ""); + PADDLE_ENFORCE_NOT_NULL(index_tensors_var, + "Output(IndexTensors) of CondOp should not be null."); auto& index_tensors = *index_tensors_var->GetMutable>(); index_tensors.push_back(LoDTensor()); @@ -49,7 +51,8 @@ void CondOp::CreateIndexTensor(const Scope& scope) const { void CondOp::InferShape(const Scope& scope) const { auto sub_scopes_var = scope.FindVar("SubScopes"); - PADDLE_ENFORCE_NOT_NULL(sub_scopes_var); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var, + "Output(SubScopes) of CondOp should not be null."); auto& sub_scopes = *sub_scopes_var->GetMutable>(); for (int i = 0; i < 2; ++i) { @@ -63,7 +66,8 @@ void CondOp::InferShape(const Scope& scope) const { // branch CreateIndexTensor(scope); - PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs can't be empty"); + PADDLE_ENFORCE(!Inputs("Xs").empty(), + "Inputs(Xs) of CondOp can't be empty."); for (auto& input : Inputs("Xs")) { // Create a new tensor in sub-scope for input-type tensor Variable* v = sub_scopes[i]->NewVar(input); @@ -108,13 +112,18 @@ void CondOp::InferShape(const Scope& scope) const { void CondOp::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto* sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var, + "Output(SubScopes) of CondOp should not be null."); auto sub_scopes = sub_scopes_var->Get>(); auto* index_tensors_var = scope.FindVar("IndexTensors"); + PADDLE_ENFORCE_NOT_NULL(index_tensors_var, + "Output(IndexTensors) of CondOp should not be null."); auto index_tensors = index_tensors_var->Get>(); std::string cond_name = Input("Cond"); Variable* cond_var = scope.FindVar(cond_name); - PADDLE_ENFORCE_NOT_NULL(cond_var); + PADDLE_ENFORCE_NOT_NULL(cond_var, + "Input(Cond) of CondOp should not be null."); const LoDTensor* cond = cond_var->GetMutable(); // Step 1: get the true/false index at runtime @@ -171,6 +180,8 @@ void CondOp::Run(const Scope& scope, } // Step 4: merge output results + PADDLE_ENFORCE(!Outputs("Outs").empty(), + "Outputs(Outs) of CondOp can't be empty."); for (int i = 0; i < 2; ++i) { // i= 
0/i for True and False branches respectively for (auto& output : Outputs("Outs")) { diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc index 253b17d8a1b88eccc58fc458ae8274d2bbd1c323..72c446493684246959656dc048e7f0e761665423 100644 --- a/paddle/operators/cos_sim_op.cc +++ b/paddle/operators/cos_sim_op.cc @@ -26,8 +26,16 @@ class CosSimOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { // notnull check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"), + "Output(XNorm) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"), + "Output(YNorm) of CosSimOp should not be null."); // shape check auto x_dims = ctx.Input("X")->dims(); diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc index e37c582adbe5b9e728f683d97cc51063ce80c3a2..ee6e975b443691bf71cec904565ced20406f3fba 100644 --- a/paddle/operators/elementwise_mul_op.cc +++ b/paddle/operators/elementwise_mul_op.cc @@ -25,8 +25,14 @@ class ElementWiseMulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ElementWiseMulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of ElementWiseMulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of ElementWiseMulOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); auto y_dim = ctx.Input("Y")->dims(); PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h index e9ed6791799240039f9af42c1a4339be7126ee65..6d58da580b81b9e0a8ae170eec1a73638b190df8 100644 --- a/paddle/operators/elementwise_mul_op.h +++ b/paddle/operators/elementwise_mul_op.h @@ -13,10 +13,8 @@ limitations under the License. 
*/ #pragma once -#include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index 0c9734892aac216709d380ec66acadf792761b14..ba7857cc65f6860a6156674c6addc2bfdce21a99 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -23,6 +23,13 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("Src"), + "Input(Src) of FillZerosLikeOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Dst"), + "Output(Dst) of FillZerosLikeOp should not be null."); + ctx.Output("Dst")->Resize( ctx.Input("Src")->dims()); } diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 8883d6d5fed2d8900364cf713a1e8d8b290ef83e..d445b61c1657356f2cdcf1e98d756607de2bd042 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -24,6 +24,13 @@ class GatherOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of GatherOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), + "Input(Index) of GatherOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of GatherOp should not be null."); + int batch_size = ctx.Input("Index")->dims()[0]; PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); framework::DDim output_dims(ctx.Input("X")->dims()); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 25b0776a37488e876b0cc88aa3f2aa68e33fb270..c0e161bbc0c5486eb10408e43e6388f1b287abf8 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,8 +43,12 @@ class GaussianRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output("Out"); + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of GaussianRandomOp should not be null."); + + auto* tensor = ctx.Output("Out"); auto dims = Attr>("dims"); std::vector temp; temp.reserve(dims.size()); diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc index 7d9d4fa519d1c690feacbadc5175aeab49082282..b67ca5f6f8d516224e18a5eed497f2bfc680259c 100644 --- a/paddle/operators/identity_op.cc +++ b/paddle/operators/identity_op.cc @@ -42,6 +42,11 @@ class IdentityOp : public NetOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { + PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, + "Input(X) of IdentityOp should not be null."); + PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, + "Output(Out) of IdentityOp should not be null."); + AppendOp(framework::OpRegistry::CreateOp( "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}}, {{"scale", static_cast(1)}})); diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 
b3d15f1ec99813d242c86c99faa7385795eef3b1..07f6dfabca5879e3de6004e59d2e87f7fa68d66c 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -22,10 +22,17 @@ class LookupTableOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &context) const override { - auto table_t = context.Input("W"); - auto ids_t = context.Input("Ids"); - auto output_t = context.Output("Out"); + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("W"), + "Input(W) of LookupTableOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), + "Input(Ids) of LookupTableOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of LookupTableOp should not be null."); + + auto table_t = ctx.Input("W"); + auto ids_t = ctx.Input("Ids"); + auto output_t = ctx.Output("Out"); output_t->Resize({ids_t->dims()[0], table_t->dims()[1]}); } diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 3e523d31b682d70825d50c2a57b6e98cbf29dcd3..7d7eeb59a23435036dc33c1e4fe6dd1c4a1a2f62 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -24,7 +24,9 @@ class MeanOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of MeanOp must be initialized."); + "Input(X) of MeanOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of MeanOp should not be null."); ctx.Output("Out")->Resize({1}); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 8a583f24edf40ce805ca216ab014bd169f9236df..ecf8a6f7795314e2475bb9546b55b8f354b96366 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -27,6 +27,13 @@ class MinusOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of MinusOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of MinusOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of MinusOp should not be null."); + auto *left_tensor = ctx.Input("X"); auto *right_tensor = ctx.Input("Y"); @@ -77,8 +84,6 @@ class MinusGradOp : public NetOp { } // namespace operators } // namespace paddle -USE_OP(scale); -USE_NO_KERNEL_OP(identity); namespace ops = paddle::operators; REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, ops::MinusGradOp); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 015e13de9a09bcd1931ccf91413b6a3f484f82bb..b6d320b415e02549e85cb36ab517b0b5433887d5 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -26,6 +26,13 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of MulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of MulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of MulOp should not be null."); + auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); int x_num_col_dims = Attr("x_num_col_dims"); diff --git 
a/paddle/operators/onehot_cross_entropy_op.cc b/paddle/operators/onehot_cross_entropy_op.cc index a9baada1cd4cd3af793bf1b0af7b029417e62b08..f38be3549f3c5d2443f61739fc32cdca74197649 100644 --- a/paddle/operators/onehot_cross_entropy_op.cc +++ b/paddle/operators/onehot_cross_entropy_op.cc @@ -23,6 +23,16 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("X"), + "Input(X) of OnehotCrossEntropyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("label"), + "Input(label) of OnehotCrossEntropyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Y"), + "Output(Y) of OnehotCrossEntropyOp should not be null."); + auto *X = ctx.Input("X"); auto *label = ctx.Input("label"); diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 6cf7bd6f35b7592c41983efd75c1628043070687..a0b1c6b631d97a40d774f7d2ff9550fda9c32db4 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -25,6 +25,11 @@ class PadOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of PadOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of PadOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); auto paddings = Attr>("paddings"); PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index d2817020921dfd3c044922de6a0f2ae0307936bd..0d05e344148c68f5625dd819ec59c5991892e4ce 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -28,7 +28,11 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ReshapeOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ReshapeOp should not be null."); + auto shape = ctx.Attr>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); for (auto dim : shape) { diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index c6101685a3205a7a7459347ea5b0cc8487656550..2a3fd3be941d91aaa6b014df91d3025f07767577 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -25,6 +25,13 @@ class RowwiseAddOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), + "Input(b) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of RowwiseAddOp should not be null."); + auto x_dims = ctx.Input("X")->dims(); auto b_dims = ctx.Input("b")->dims(); PADDLE_ENFORCE_GT( diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 35e6b70ba94a645da2b99093b2354acfb0ef2771..d1f42e8662537d35e17429f9d436fdc0e5a1dc11 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -27,6 +27,11 @@ class ScaleOp : public framework::OperatorWithKernel { protected: void InferShape(const 
framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ScaleOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ScaleOp should not be null."); + auto *in = ctx.Input("X"); auto *out = ctx.Output("Out"); out->Resize(in->dims()); diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index 0f7510983e2e34485110719c92d26cbc78cd850c..8820262732327306f4f807702751708bd1e2aa36 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -24,6 +24,15 @@ class ScatterOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ref"), + "Input(Ref) of ScatterOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), + "Input(Index) of ScatterOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Updates"), + "Input(Updates) of ScatterOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ScatterOp should not be null."); + PADDLE_ENFORCE_EQ(ctx.Input("Index")->dims().size(), 1, "Update Index should be 1-D."); PADDLE_ENFORCE_EQ(ctx.Input("Ref")->dims().size(), diff --git a/paddle/operators/sequence_avg_pool_op.cc b/paddle/operators/sequence_avg_pool_op.cc index c15a5833deba2e198f6cb724bda7e3306c56e461..9815b8f3a8d813959949bbfedc79f404721a8216 100644 --- a/paddle/operators/sequence_avg_pool_op.cc +++ b/paddle/operators/sequence_avg_pool_op.cc @@ -23,9 +23,12 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of SequenceAvgPoolOp" - "must be initialized."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("X"), "Input(X) of SequenceAvgPoolOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of SequenceAvgPoolOp should not be null."); + auto* x = ctx.Input("X"); auto dims = x->dims(); auto lod = x->lod(); @@ -60,7 +63,9 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Gradient of Out should not be null"); + "Gradient of Out should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "The input X should not be null."); auto og_dims = ctx.Input(framework::GradVarName("Out"))->dims(); auto x_dims = ctx.Input("X")->dims(); diff --git a/paddle/operators/sequence_avg_pool_op.h b/paddle/operators/sequence_avg_pool_op.h index 6e343b87e2938399409498407ac46b2416dc2231..ebe0956344eb71d0fb2836f1b4a989ac546d9f78 100644 --- a/paddle/operators/sequence_avg_pool_op.h +++ b/paddle/operators/sequence_avg_pool_op.h @@ -21,6 +21,9 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; +template +using EigenVector = framework::EigenVector; template using EigenMatrix = framework::EigenMatrix; @@ -43,8 +46,8 @@ class SequenceAvgPoolKernel : public framework::OpKernel { static_cast(lod[0][i + 1])); Tensor out_t = out->Slice(i, i + 1); int64_t h = static_cast(lod[0][i + 1] - lod[0][i]); - auto in_e = EigenMatrix::From(in_t, {h, w}); - auto out_e = EigenMatrix::From(out_t, {h, w}); + auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w})); + auto out_e = EigenVector::Flatten(out_t); 
out_e.device(place) = in_e.mean(Eigen::array({{0}})); } } @@ -54,9 +57,9 @@ template class SequenceAvgPoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Output("X"); - auto* in_g = context.Output(framework::GradVarName("X")); + auto* in = context.Input("X"); auto* out_g = context.Input(framework::GradVarName("Out")); + auto* in_g = context.Output(framework::GradVarName("X")); auto dims = in->dims(); auto lod = in->lod(); @@ -71,7 +74,7 @@ class SequenceAvgPoolGradKernel : public framework::OpKernel { int64_t h = static_cast(lod[0][i + 1] - lod[0][i]); auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); - Eigen::DSizes bcast(h, w); + Eigen::DSizes bcast(h, 1); in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); } } diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 7997bf690710b3675f4014790db8cc7fc06946d3..1232e64c7f0132b9ea19b3d7e1ebe9531e1e25a5 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -23,6 +23,13 @@ class SGDOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("param"), + "Input(param) of SGDOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("grad"), + "Input(grad) of SGDOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("param_out"), + "Output(param_out) of SGDOp should not be null."); + PADDLE_ENFORCE_EQ(ctx.Input("param")->dims(), ctx.Input("grad")->dims(), "Two input of SGD Op's dimension must be same."); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index de6a1ba77382306923ee238e5cf309c791f9f216..992b19965e0ca9ce7dba1b8b3c5b7780af06eb45 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -23,6 +23,11 @@ class SigmoidOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of SigmoidOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) of SigmoidOp should not be null."); + ctx.Output("Y")->Resize( ctx.Input("X")->dims()); } diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 239d3d141e1076a0a6a943f340311b17aa6f542a..c67eb028c882ed82ca4e6a4dd70cdea9f69cdc24 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -23,6 +23,11 @@ class SoftmaxOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of SoftmaxOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) of SoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, "The input of softmax op must be a matrix."); ctx.Output("Y")->Resize( diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc index ebe5bd352e99e298fb86355730feed77b236d2bd..39f4305877de20d451bc35fe698a0eabf9758d57 100644 --- a/paddle/operators/squared_l2_distance_op.cc +++ b/paddle/operators/squared_l2_distance_op.cc @@ -23,12 +23,18 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { - 
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of SquaredL2DistanceOp " - "must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Target of SquaredL2DistanceOp " - "must be initialized."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("X"), + "Input(X) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("Y"), + "Input(Y) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("sub_result"), + "Output(sub_result) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of SquaredL2DistanceOp should not be null."); auto* x = ctx.Input("X"); auto x_dims = x->dims(); diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 7170e7256c206d338ef1f6f94d5d1889ca92a3de..41e05c27f9029b2664685d3979fadcfd2bf6dbce 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -22,6 +22,11 @@ class SumOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) of SumOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of SumOp should not be null."); + auto ins = ctx.MultiInput("X"); auto *out = ctx.Output("Out"); int N = ins.size(); diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc index ff0e77a344ede7709a805d7dca4397eb49fa300c..169b815feffd86f9ff04c129ccc997230ce03a8c 100644 --- a/paddle/operators/top_k_op.cc +++ b/paddle/operators/top_k_op.cc @@ -24,7 +24,12 @@ class TopkOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of TopkOP must be initialized."); + "Input(X) of TopkOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of TopkOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Indices"), + "Output(Indices) of TopkOp should not be null."); + auto *input = ctx.Input("X"); const int k = static_cast(ctx.Attr("k")); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index ed7973693619e7765643dda824100afd82616470..184bcbc29c0d26a214345506f126f9cc0d406b07 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -48,6 +48,10 @@ class UniformRandomOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of UniformRandomOp should not be null."); + PADDLE_ENFORCE(Attr("min") < Attr("max"), "uniform_random's min must less then max"); auto* tensor = ctx.Output("Out"); diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 9936fd76baf3e64aed01b8ae1d54e50b39793925..a0533efacdcc0386c0c3ab4691dc74a43435b4e4 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -47,17 +47,24 @@ def set_input(scope, op, inputs, place): if in_name in inputs: if in_dup: sub_in = inputs[in_name] - for sub_in_name, sub_in_array in sub_in: + for sub_in_name, sub_in_val in sub_in: var = scope.find_var(sub_in_name) tensor = var.get_tensor() + sub_in_array = sub_in_val[0] \ + if isinstance(sub_in_val, tuple) else sub_in_val tensor.set_dims(sub_in_array.shape) tensor.set(sub_in_array, 
place) + if isinstance(sub_in_val, tuple): + tensor.set_lod(sub_in_val[1]) else: var = scope.find_var(in_name) tensor = var.get_tensor() - arr = inputs[in_name] - tensor.set_dims(arr.shape) - tensor.set(arr, place) + in_val = inputs[in_name] + in_array = in_val[0] if isinstance(in_val, tuple) else in_val + tensor.set_dims(in_array.shape) + tensor.set(in_array, place) + if isinstance(in_val, tuple): + tensor.set_lod(in_val[1]) def set_output_grad(scope, op, outputs, place): diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_add_two_op.py rename to python/paddle/v2/framework/tests/test_add_op.py diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index 1f9e4db783c9907a22db72c8a6ff06c7ca0735da..1888ee28f92c66496ce756d8a4a33d3e9ba57d7b 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -4,7 +4,7 @@ from paddle.v2.framework.op import Operator import numpy -class GaussianRandomTest(unittest.TestCase): +class TestGaussianRandomOp(unittest.TestCase): def test_cpu(self): self.gaussian_random_test(place=core.CPUPlace()) diff --git a/python/paddle/v2/framework/tests/test_identity_op.py b/python/paddle/v2/framework/tests/test_identity_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2e95e7c786e3ff99a04b28218ec5b5decf531360 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_identity_op.py @@ -0,0 +1,20 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestIdentityOp(OpTest): + def setUp(self): + self.op_type = "identity" + self.inputs = {'X': np.random.random((10, 10)).astype("float32")} + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_lookup_table.py rename to python/paddle/v2/framework/tests/test_lookup_table_op.py diff --git a/python/paddle/v2/framework/tests/test_minus_op.py b/python/paddle/v2/framework/tests/test_minus_op.py index dea797a1fea34265d0a32e097f413f421abf2521..c56d7cb548706880dd482bad750f2989c0e9a710 100644 --- a/python/paddle/v2/framework/tests/test_minus_op.py +++ b/python/paddle/v2/framework/tests/test_minus_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class MinusOpTest(OpTest): +class TestMinusOp(OpTest): def setUp(self): self.op_type = "minus" self.inputs = { diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py similarity index 95% rename from python/paddle/v2/framework/tests/test_cross_entropy_op.py rename to python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py index 253e7b8a24465da63a7eacd7983eb831251e6230..fd3cbdb80374865ccf113768856096bf49dce643 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py @@ -3,7 +3,7 @@ import numpy from op_test import OpTest -class TestCrossEntropy(OpTest): +class TestOnehotCrossEntropyOp(OpTest): def setUp(self): 
self.op_type = "onehot_cross_entropy" batch_size = 30 diff --git a/python/paddle/v2/framework/tests/test_scale_and_identity_op.py b/python/paddle/v2/framework/tests/test_scale_op.py similarity index 56% rename from python/paddle/v2/framework/tests/test_scale_and_identity_op.py rename to python/paddle/v2/framework/tests/test_scale_op.py index 05d76d428299c8176d1a6adf6da15a203fa7502a..2ea1e185470280730ae8c8c0ea9568bbeb43eaf5 100644 --- a/python/paddle/v2/framework/tests/test_scale_and_identity_op.py +++ b/python/paddle/v2/framework/tests/test_scale_op.py @@ -3,20 +3,7 @@ import numpy as np from op_test import OpTest -class IdentityTest(OpTest): - def setUp(self): - self.op_type = "identity" - self.inputs = {'X': np.random.random((10, 10)).astype("float32")} - self.outputs = {'Out': self.inputs['X']} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class ScaleTest(OpTest): +class TestScaleOp(OpTest): def setUp(self): self.op_type = "scale" self.inputs = {'X': np.random.random((10, 10)).astype("float32")} diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..cf864936af6361da1f16df3cfb759b468214b970 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -0,0 +1,51 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSeqAvgPool1D(OpTest): + def setUp(self): + self.op_type = 'sequence_avg_pool' + # one level, batch size is 4 + x = np.random.uniform(0.1, 1, [11, 23]).astype('float32') + lod = [[0, 4, 5, 8, 11]] + + out = np.zeros((4, 23)).astype('float32') + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + out[i] = sub_x.mean(axis=0) + + self.inputs = {'X': (x, lod)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSeqAvgPool2D(OpTest): + def setUp(self): + self.op_type = 'sequence_avg_pool' + # one level, batch size is 4 + x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32') + lod = [[0, 4, 5, 8, 13]] + + out = np.zeros((4, 3, 17)).astype('float32') + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) + + self.inputs = {'X': (x, lod)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py index 557cf15ace63e336462c7dcdbbc10f30aeedc6f4..64e54d1500c1bc134cc1efe33d41a16dbc08f2d4 100644 --- a/python/paddle/v2/framework/tests/test_sgd_op.py +++ b/python/paddle/v2/framework/tests/test_sgd_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class TestSGD(OpTest): +class TestSGDOp(OpTest): def setUp(self): self.op_type = "sgd" w = np.random.random((102, 105)).astype("float32") diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 2316e49eff7bb1cdb53acb3889a6ef05060b59f3..d65d887db4af58c40e4e78fdbfd8e8ee668b7ee3 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class 
TestSigmoid(OpTest): +class TestSigmoidOp(OpTest): def setUp(self): self.op_type = "sigmoid" self.inputs = { diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/framework/tests/test_top_k_op.py index cab799256d791889c295aa7f9048080f5caaf2dc..694f37d612d4c46e673dc894b05a0a446190732c 100644 --- a/python/paddle/v2/framework/tests/test_top_k_op.py +++ b/python/paddle/v2/framework/tests/test_top_k_op.py @@ -21,6 +21,9 @@ class TestTopkOp(OpTest): self.outputs = {'Out': output, 'Indices': indices} + def test_check_output(self): + self.check_output() + class TestTopkOp3d(OpTest): def setUp(self): @@ -42,6 +45,9 @@ class TestTopkOp3d(OpTest): self.outputs = {'Out': output, 'Indices': indices} + def test_check_output(self): + self.check_output() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index 76a5e36e56ab08230bdc2597d209fcf5d1d2acb0..9e8898fb5920defdfaa361bf45def7666a88beea 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -4,7 +4,7 @@ import paddle.v2.framework.core as core import numpy -class UniformRandomTest(unittest.TestCase): +class TestUniformRandomOp(unittest.TestCase): def test_uniform_random_cpu(self): self.uniform_random_test(place=core.CPUPlace())
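
Note on the ExpandConvLayer change above: the removed addSharedBias()/addUnsharedBias()/bpropBiases() helpers are folded into single output_.value->addBias(*biases_->getW(), 1.0, sharedBiases_) and collectBias(..., sharedBiases_) calls, which also drops the auxiliary transOutValue_ buffer. The snippet below is a minimal standalone numpy sketch, not part of the patch and not Paddle API; the names batch, num_filters and map_size are illustrative. It checks that the old transpose-reshape route and a plain per-filter broadcast add the same shared bias, assuming the conv output is laid out row-major as (batch, numFilters * mapSize), as in the removed ExpandConvBaseLayer::addSharedBias.

import numpy as np

batch, num_filters, map_size = 2, 3, 4
out = np.random.rand(batch, num_filters * map_size).astype("float32")
bias = np.random.rand(num_filters).astype("float32")

# Old route (removed addSharedBias): view as (batch * numFilters, mapSize),
# transpose, reshape to (-1, numFilters), add the bias column-wise, undo the views.
view = out.reshape(batch * num_filters, map_size)
trans = view.T.reshape(-1, num_filters) + bias
old = trans.reshape(map_size, batch * num_filters).T.reshape(out.shape)

# New route: a single broadcast of the bias over the filter axis, which is what
# the shared-bias case of addBias(*bias, 1.0, sharedBiases_) is expected to compute.
new = (out.reshape(batch, num_filters, map_size) + bias[None, :, None]).reshape(out.shape)

assert np.allclose(old, new)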