提交 2859532d 编写于 作者: Y Yu Yang

Using override keyword in Layer

上级 54ba8eaa
...@@ -155,7 +155,8 @@ protected: ...@@ -155,7 +155,8 @@ protected:
public: public:
explicit BootBiasLayer(const LayerConfig& config) : Layer(config) {} explicit BootBiasLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
if (!Layer::init(layerMap, parameterMap)) return false; if (!Layer::init(layerMap, parameterMap)) return false;
if (biasParameter_) { if (biasParameter_) {
...@@ -174,7 +175,7 @@ public: ...@@ -174,7 +175,7 @@ public:
} }
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
if (biases_) { if (biases_) {
MatrixPtr outV = getOutputValue(); MatrixPtr outV = getOutputValue();
outV->addBias(*(biases_->getW()), 1); outV->addBias(*(biases_->getW()), 1);
...@@ -182,7 +183,7 @@ public: ...@@ -182,7 +183,7 @@ public:
} }
} }
virtual void backward(const UpdateCallback& callback) { void backward(const UpdateCallback& callback) override {
if (biases_) { if (biases_) {
backwardActivation(); backwardActivation();
biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
......
...@@ -44,19 +44,20 @@ public: ...@@ -44,19 +44,20 @@ public:
/** /**
* Intialization of AddtoLayer. * Intialization of AddtoLayer.
*/ */
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* Forward propagation. * Forward propagation.
* @note There is no weight matrix for each input, * @note There is no weight matrix for each input,
* because it just a simple add operation. * because it just a simple add operation.
*/ */
void forward(PassType passType); void forward(PassType passType) override;
/** /**
* Backward propagation. * Backward propagation.
*/ */
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,8 @@ public: ...@@ -35,7 +35,8 @@ public:
~AgentLayer() {} ~AgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
// if *numSamples* set, // if *numSamples* set,
// real layer output will only use first *numSamples* rows // real layer output will only use first *numSamples* rows
...@@ -44,8 +45,8 @@ public: ...@@ -44,8 +45,8 @@ public:
numSamples_ = numSamples; numSamples_ = numSamples;
} }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) {} void backward(const UpdateCallback& callback = nullptr) override {}
}; };
/** /**
...@@ -56,8 +57,8 @@ public: ...@@ -56,8 +57,8 @@ public:
explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {} explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {}
~SequenceAgentLayer() {} ~SequenceAgentLayer() {}
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) {} void backward(const UpdateCallback& callback = nullptr) override {}
}; };
/** /**
...@@ -78,7 +79,8 @@ public: ...@@ -78,7 +79,8 @@ public:
virtual ~GatherAgentLayer() {} virtual ~GatherAgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
// call before addRealLayer // call before addRealLayer
void copyIdAndSequenceInfo(const Argument& input, void copyIdAndSequenceInfo(const Argument& input,
...@@ -88,8 +90,8 @@ public: ...@@ -88,8 +90,8 @@ public:
// add one real layer, can call many times // add one real layer, can call many times
void addRealLayer(LayerPtr layer) { realLayers_.push_back(layer); } void addRealLayer(LayerPtr layer) { realLayers_.push_back(layer); }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
/** /**
...@@ -133,7 +135,8 @@ public: ...@@ -133,7 +135,8 @@ public:
virtual ~ScatterAgentLayer() {} virtual ~ScatterAgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* @brief set real layer in generation * @brief set real layer in generation
...@@ -182,8 +185,8 @@ public: ...@@ -182,8 +185,8 @@ public:
numSequences_ = numSequences; numSequences_ = numSequences;
} }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
/** /**
......
...@@ -38,12 +38,11 @@ public: ...@@ -38,12 +38,11 @@ public:
explicit AverageLayer(const LayerConfig& config) explicit AverageLayer(const LayerConfig& config)
: SequencePoolLayer(config) {} : SequencePoolLayer(config) {}
~AverageLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
protected: protected:
MatrixPtr outMtx_; MatrixPtr outMtx_;
......
...@@ -52,7 +52,8 @@ public: ...@@ -52,7 +52,8 @@ public:
*/ */
static Layer* create(const LayerConfig& config); static Layer* create(const LayerConfig& config);
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* @brief Calculate feature map size. Some input uses frameHeight and * @brief Calculate feature map size. Some input uses frameHeight and
......
...@@ -33,9 +33,10 @@ public: ...@@ -33,9 +33,10 @@ public:
~BatchNormalizationLayer() {} ~BatchNormalizationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
/// Epsilon value used in the batch normalization formula. /// Epsilon value used in the batch normalization formula.
...@@ -58,7 +59,7 @@ protected: ...@@ -58,7 +59,7 @@ protected:
/// to batch, channels* imagePixels. /// to batch, channels* imagePixels.
void shrinkMat(const MatrixPtr& in, MatrixPtr& out); void shrinkMat(const MatrixPtr& in, MatrixPtr& out);
void onPassEnd() { firstTest_ = true; } void onPassEnd() override { firstTest_ = true; }
MatrixPtr tmpMat_, tmpGrad_; MatrixPtr tmpMat_, tmpGrad_;
MatrixPtr expandedIn_, expandedOut_; MatrixPtr expandedIn_, expandedOut_;
......
...@@ -38,9 +38,10 @@ public: ...@@ -38,9 +38,10 @@ public:
virtual ~BilinearInterpLayer() {} virtual ~BilinearInterpLayer() {}
size_t getSize(); size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -58,10 +58,11 @@ public: ...@@ -58,10 +58,11 @@ public:
~BlockExpandLayer() {} ~BlockExpandLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -32,9 +32,10 @@ namespace paddle { ...@@ -32,9 +32,10 @@ namespace paddle {
class CRFDecodingLayer : public CRFLayer { class CRFDecodingLayer : public CRFLayer {
public: public:
explicit CRFDecodingLayer(const LayerConfig& config) : CRFLayer(config) {} explicit CRFDecodingLayer(const LayerConfig& config) : CRFLayer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
std::unique_ptr<LinearChainCRF> crf_; std::unique_ptr<LinearChainCRF> crf_;
......
...@@ -29,9 +29,10 @@ namespace paddle { ...@@ -29,9 +29,10 @@ namespace paddle {
class CRFLayer : public Layer { class CRFLayer : public Layer {
public: public:
explicit CRFLayer(const LayerConfig& config) : Layer(config) {} explicit CRFLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
size_t numClasses_; size_t numClasses_;
......
...@@ -22,10 +22,11 @@ namespace paddle { ...@@ -22,10 +22,11 @@ namespace paddle {
class CTCLayer : public Layer { class CTCLayer : public Layer {
public: public:
explicit CTCLayer(const LayerConfig& config) : Layer(config) {} explicit CTCLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void forwardImp(const Argument& softmaxSeqs, const Argument& labelSeqs); void forwardImp(const Argument& softmaxSeqs, const Argument& labelSeqs);
virtual void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
void backwardImp(const UpdateCallback& callback, void backwardImp(const UpdateCallback& callback,
const Argument& softmaxSeqs, const Argument& softmaxSeqs,
const Argument& labelSeqs); const Argument& labelSeqs);
......
...@@ -28,10 +28,11 @@ public: ...@@ -28,10 +28,11 @@ public:
~ConcatenateLayer() {} ~ConcatenateLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(concat, ConcatenateLayer); REGISTER_LAYER(concat, ConcatenateLayer);
...@@ -101,10 +102,11 @@ public: ...@@ -101,10 +102,11 @@ public:
~ConcatenateLayer2() {} ~ConcatenateLayer2() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
std::vector<std::unique_ptr<Projection>> projections_; std::vector<std::unique_ptr<Projection>> projections_;
......
...@@ -80,7 +80,8 @@ protected: ...@@ -80,7 +80,8 @@ protected:
public: public:
explicit ConvBaseLayer(const LayerConfig& config) : Layer(config) {} explicit ConvBaseLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* imgSizeH_ and imgSizeW_ will be set according to the previous input layers * imgSizeH_ and imgSizeW_ will be set according to the previous input layers
......
...@@ -47,10 +47,11 @@ public: ...@@ -47,10 +47,11 @@ public:
~ConvShiftLayer() {} ~ConvShiftLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(conv_shift, ConvShiftLayer); REGISTER_LAYER(conv_shift, ConvShiftLayer);
......
...@@ -49,10 +49,11 @@ public: ...@@ -49,10 +49,11 @@ public:
~ConvexCombinationLayer() {} ~ConvexCombinationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(convex_comb, ConvexCombinationLayer); REGISTER_LAYER(convex_comb, ConvexCombinationLayer);
......
...@@ -38,10 +38,11 @@ public: ...@@ -38,10 +38,11 @@ public:
~CosSimLayer() {} ~CosSimLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -46,10 +46,11 @@ public: ...@@ -46,10 +46,11 @@ public:
~CosSimVecMatLayer() {} ~CosSimVecMatLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(cos_vm, CosSimVecMatLayer); REGISTER_LAYER(cos_vm, CosSimVecMatLayer);
......
...@@ -367,8 +367,6 @@ void LambdaCost::backward(const UpdateCallback& callback) { ...@@ -367,8 +367,6 @@ void LambdaCost::backward(const UpdateCallback& callback) {
getInputGrad(0)->add(*marginGrad_); getInputGrad(0)->add(*marginGrad_);
} }
void LambdaCost::onPassEnd() {}
void LambdaCost::calcGrad(const real* outputScore, void LambdaCost::calcGrad(const real* outputScore,
const real* score, const real* score,
real* gradData, real* gradData,
...@@ -611,14 +609,15 @@ class SumCostLayer : public Layer { ...@@ -611,14 +609,15 @@ class SumCostLayer : public Layer {
public: public:
explicit SumCostLayer(const LayerConfig& config) : Layer(config) {} explicit SumCostLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
if (!ret) return ret; if (!ret) return ret;
CHECK_EQ(inputLayers_.size(), 1UL); CHECK_EQ(inputLayers_.size(), 1UL);
return true; return true;
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
const MatrixPtr& input = getInputValue(0); const MatrixPtr& input = getInputValue(0);
...@@ -629,7 +628,7 @@ public: ...@@ -629,7 +628,7 @@ public:
output_.value->sumRows(*input, /* scaleSum= */ 1, /* scaleDest= */ 0); output_.value->sumRows(*input, /* scaleSum= */ 1, /* scaleDest= */ 0);
} }
virtual void backward(const UpdateCallback& callback = nullptr) { void backward(const UpdateCallback& callback = nullptr) override {
getInputGrad(0)->add((real)1); getInputGrad(0)->add((real)1);
} }
}; };
......
...@@ -32,15 +32,16 @@ class CostLayer : public Layer { ...@@ -32,15 +32,16 @@ class CostLayer : public Layer {
public: public:
explicit CostLayer(const LayerConfig& config) : Layer(config) {} explicit CostLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; } LayerPtr getOutputLayer() { return inputLayers_[0]; }
LayerPtr getLabelLayer() { return inputLayers_[1]; } LayerPtr getLabelLayer() { return inputLayers_[1]; }
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
virtual void forwardImp(Matrix& outputValue, virtual void forwardImp(Matrix& outputValue,
Argument& label, Argument& label,
...@@ -68,11 +69,14 @@ public: ...@@ -68,11 +69,14 @@ public:
explicit MultiClassCrossEntropy(const LayerConfig& config) explicit MultiClassCrossEntropy(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
}; };
/** /**
...@@ -95,11 +99,14 @@ public: ...@@ -95,11 +99,14 @@ public:
explicit MultiClassCrossEntropyWithSelfNorm(const LayerConfig& config) explicit MultiClassCrossEntropyWithSelfNorm(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
protected: protected:
MatrixPtr sftMaxSum_; MatrixPtr sftMaxSum_;
...@@ -117,11 +124,14 @@ public: ...@@ -117,11 +124,14 @@ public:
explicit SoftBinaryClassCrossEntropy(const LayerConfig& config) explicit SoftBinaryClassCrossEntropy(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
protected: protected:
MatrixPtr targetPerDim_; MatrixPtr targetPerDim_;
...@@ -139,11 +149,14 @@ public: ...@@ -139,11 +149,14 @@ public:
explicit SumOfSquaresCostLayer(const LayerConfig& config) explicit SumOfSquaresCostLayer(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
}; };
/** /**
...@@ -162,17 +175,18 @@ class RankingCost : public Layer { ...@@ -162,17 +175,18 @@ class RankingCost : public Layer {
public: public:
explicit RankingCost(const LayerConfig& config) : Layer(config) {} explicit RankingCost(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer(size_t i) { return inputLayers_[i]; } LayerPtr getOutputLayer(size_t i) { return inputLayers_[i]; }
LayerPtr getLabelLayer() { return inputLayers_[2]; } LayerPtr getLabelLayer() { return inputLayers_[2]; }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
void onPassEnd(); void onPassEnd() override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost) { void forwardImp(Matrix& output, Argument& label, Matrix& cost) {
(void)output; (void)output;
...@@ -214,17 +228,16 @@ class LambdaCost : public Layer { ...@@ -214,17 +228,16 @@ class LambdaCost : public Layer {
public: public:
explicit LambdaCost(const LayerConfig& config) : Layer(config) {} explicit LambdaCost(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; } LayerPtr getOutputLayer() { return inputLayers_[0]; }
LayerPtr getScoreLayer() { return inputLayers_[1]; } LayerPtr getScoreLayer() { return inputLayers_[1]; }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
void onPassEnd();
real calcNDCG(const real* outputScore, const real* score, int size); real calcNDCG(const real* outputScore, const real* score, int size);
void calcGrad(const real* outputScore, void calcGrad(const real* outputScore,
...@@ -256,11 +269,14 @@ public: ...@@ -256,11 +269,14 @@ public:
explicit MultiBinaryLabelCrossEntropy(const LayerConfig& config) explicit MultiBinaryLabelCrossEntropy(const LayerConfig& config)
: CostLayer(config) {} : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
}; };
/** /**
...@@ -282,13 +298,16 @@ class HuberTwoClass : public CostLayer { ...@@ -282,13 +298,16 @@ class HuberTwoClass : public CostLayer {
public: public:
explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {} explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost); void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void forwardImpIn(Matrix& output, Argument& label, Matrix& cost); void forwardImpIn(Matrix& output, Argument& label, Matrix& cost);
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad); void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad);
}; };
......
...@@ -35,14 +35,15 @@ public: ...@@ -35,14 +35,15 @@ public:
~CudnnBatchNormLayer(); ~CudnnBatchNormLayer();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* reshape tensor of ioDesc_. * reshape tensor of ioDesc_.
*/ */
void reshape(int batchSize); void reshape(int batchSize);
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
/** /**
......
...@@ -45,9 +45,10 @@ public: ...@@ -45,9 +45,10 @@ public:
~CudnnConvLayer(); ~CudnnConvLayer();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
void addBiases(); void addBiases();
void bpropBiases(); void bpropBiases();
}; };
......
...@@ -45,7 +45,8 @@ public: ...@@ -45,7 +45,8 @@ public:
hl_pooling_mode_t* mode = nullptr); hl_pooling_mode_t* mode = nullptr);
explicit CudnnPoolLayer(const LayerConfig& config); explicit CudnnPoolLayer(const LayerConfig& config);
~CudnnPoolLayer(); ~CudnnPoolLayer();
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/** /**
* Reshape input and output tensor descriptor. * Reshape input and output tensor descriptor.
...@@ -53,8 +54,8 @@ public: ...@@ -53,8 +54,8 @@ public:
* So reshaping is needed. * So reshaping is needed.
*/ */
void reshape(int batchSize); void reshape(int batchSize);
virtual void forward(PassType passType); virtual void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); virtual void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -33,13 +33,13 @@ public: ...@@ -33,13 +33,13 @@ public:
/** /**
* Prefetch sparse matrix/ids only. * Prefetch sparse matrix/ids only.
*/ */
void prefetch() { output_ = data_; } void prefetch() override { output_ = data_; }
/** /**
* Forward propagation. Copy data_ (value, in, grad, ids, cpuSequenceDims, * Forward propagation. Copy data_ (value, in, grad, ids, cpuSequenceDims,
* sequenceStartPositions, subSequenceStartPositions, strs) to output_. * sequenceStartPositions, subSequenceStartPositions, strs) to output_.
*/ */
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
copyDataToOutput(output_); copyDataToOutput(output_);
if (FLAGS_show_layer_stat) { if (FLAGS_show_layer_stat) {
...@@ -50,9 +50,9 @@ public: ...@@ -50,9 +50,9 @@ public:
/** /**
* Data layer's backward propagation do nothing. * Data layer's backward propagation do nothing.
*/ */
virtual void backward(const UpdateCallback& callback) { (void)callback; } void backward(const UpdateCallback& callback) override { (void)callback; }
virtual void copyOutputToOtherDevice() { void copyOutputToOtherDevice() override {
for (size_t i = 0; i != outputOtherDevice_.size(); i++) { for (size_t i = 0; i != outputOtherDevice_.size(); i++) {
copyDataToOutput(outputOtherDevice_[i]); copyDataToOutput(outputOtherDevice_[i]);
} }
......
...@@ -44,10 +44,11 @@ public: ...@@ -44,10 +44,11 @@ public:
~DataNormLayer() {} ~DataNormLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
int mode_; int mode_;
......
...@@ -27,14 +27,14 @@ class EosIdCheckLayer : public Layer { ...@@ -27,14 +27,14 @@ class EosIdCheckLayer : public Layer {
public: public:
explicit EosIdCheckLayer(const LayerConfig& config) : Layer(config) {} explicit EosIdCheckLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size()); CHECK_EQ(1UL, inputLayers_.size());
return ret; return ret;
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
const Argument& input = getInput(0); const Argument& input = getInput(0);
...@@ -42,7 +42,7 @@ public: ...@@ -42,7 +42,7 @@ public:
output_.ids->isEqualTo(*input.ids, config_.eos_id()); output_.ids->isEqualTo(*input.ids, config_.eos_id());
} }
virtual void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
REGISTER_LAYER(eos_id, EosIdCheckLayer); REGISTER_LAYER(eos_id, EosIdCheckLayer);
......
...@@ -48,7 +48,8 @@ public: ...@@ -48,7 +48,8 @@ public:
~ExpandConvBaseLayer() {} ~ExpandConvBaseLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
size_t getOutputSize(); size_t getOutputSize();
/** /**
......
...@@ -35,10 +35,11 @@ public: ...@@ -35,10 +35,11 @@ public:
~ExpandConvLayer() {} ~ExpandConvLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -34,10 +34,11 @@ public: ...@@ -34,10 +34,11 @@ public:
~ExpandConvTransLayer() {} ~ExpandConvTransLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -53,10 +53,11 @@ public: ...@@ -53,10 +53,11 @@ public:
~ExpandLayer() {} ~ExpandLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -46,10 +46,11 @@ public: ...@@ -46,10 +46,11 @@ public:
~FeatureMapExpandLayer() {} ~FeatureMapExpandLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(featmap_expand, FeatureMapExpandLayer); REGISTER_LAYER(featmap_expand, FeatureMapExpandLayer);
......
...@@ -36,13 +36,14 @@ public: ...@@ -36,13 +36,14 @@ public:
explicit FullyConnectedLayer(const LayerConfig& config) : Layer(config) {} explicit FullyConnectedLayer(const LayerConfig& config) : Layer(config) {}
~FullyConnectedLayer() {} ~FullyConnectedLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
void prefetch(); void prefetch() override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -50,17 +50,18 @@ class GatedRecurrentLayer : public Layer, public GruCompute { ...@@ -50,17 +50,18 @@ class GatedRecurrentLayer : public Layer, public GruCompute {
public: public:
explicit GatedRecurrentLayer(const LayerConfig& config) : Layer(config) {} explicit GatedRecurrentLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
void resetState(); void resetState() override;
void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
void forwardSequence(int batchSize, void forwardSequence(int batchSize,
......
...@@ -22,17 +22,18 @@ public: ...@@ -22,17 +22,18 @@ public:
~GetOutputLayer() {} ~GetOutputLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
if (!Layer::init(layerMap, parameterMap)) return false; if (!Layer::init(layerMap, parameterMap)) return false;
CHECK_EQ(1U, inputLayers_.size()); CHECK_EQ(1U, inputLayers_.size());
CHECK_NE(inputArgument_[0], ""); CHECK_NE(inputArgument_[0], "");
return true; return true;
} }
void forward(PassType passType) { void forward(PassType passType) override {
output_ = getPrev(0)->getOutput(inputArgument_[0]); output_ = getPrev(0)->getOutput(inputArgument_[0]);
} }
void backward(const UpdateCallback& callback = nullptr) {} void backward(const UpdateCallback& callback = nullptr) override {}
}; };
REGISTER_LAYER(get_output, GetOutputLayer); REGISTER_LAYER(get_output, GetOutputLayer);
......
...@@ -55,10 +55,11 @@ public: ...@@ -55,10 +55,11 @@ public:
~GruStepLayer() {} ~GruStepLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(gru_step, GruStepLayer); REGISTER_LAYER(gru_step, GruStepLayer);
......
...@@ -61,9 +61,10 @@ class HierarchicalSigmoidLayer : public Layer { ...@@ -61,9 +61,10 @@ class HierarchicalSigmoidLayer : public Layer {
public: public:
explicit HierarchicalSigmoidLayer(const LayerConfig& config) explicit HierarchicalSigmoidLayer(const LayerConfig& config)
: Layer(config) {} : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
/** /**
......
...@@ -43,10 +43,11 @@ public: ...@@ -43,10 +43,11 @@ public:
~InterpolationLayer() {} ~InterpolationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(interpolation, InterpolationLayer); REGISTER_LAYER(interpolation, InterpolationLayer);
......
...@@ -74,17 +74,18 @@ class LstmLayer : public Layer, public LstmCompute { ...@@ -74,17 +74,18 @@ class LstmLayer : public Layer, public LstmCompute {
public: public:
explicit LstmLayer(const LayerConfig &config) : Layer(config) {} explicit LstmLayer(const LayerConfig &config) : Layer(config) {}
bool init(const LayerMap &layerMap, const ParameterMap &parameterMap); bool init(const LayerMap &layerMap,
const ParameterMap &parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback &callback); void backward(const UpdateCallback &callback) override;
void resetState(); void resetState() override;
void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
/** /**
......
...@@ -35,10 +35,11 @@ public: ...@@ -35,10 +35,11 @@ public:
~LstmStepLayer() {} ~LstmStepLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(lstm_step, LstmStepLayer); REGISTER_LAYER(lstm_step, LstmStepLayer);
......
...@@ -181,11 +181,12 @@ class MDLstmLayer : public LstmLayer { ...@@ -181,11 +181,12 @@ class MDLstmLayer : public LstmLayer {
public: public:
explicit MDLstmLayer(const LayerConfig& config) : LstmLayer(config) {} explicit MDLstmLayer(const LayerConfig& config) : LstmLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
protected: protected:
void forwardOneSequence(int start, CoordIterator& coordIter); void forwardOneSequence(int start, CoordIterator& coordIter);
......
...@@ -30,8 +30,8 @@ private: ...@@ -30,8 +30,8 @@ private:
public: public:
explicit MaxIdLayer(const LayerConfig& config) : Layer(config) {} explicit MaxIdLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size()); CHECK_EQ(1UL, inputLayers_.size());
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
return ret; return ret;
} }
virtual void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
const Argument& input = getInput(0); const Argument& input = getInput(0);
size_t batchSize = input.getBatchSize(); size_t batchSize = input.getBatchSize();
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
input.value->rowMax(*output_.ids, *output_.in); input.value->rowMax(*output_.ids, *output_.in);
} }
virtual void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
REGISTER_LAYER(maxid, MaxIdLayer); REGISTER_LAYER(maxid, MaxIdLayer);
......
...@@ -42,14 +42,13 @@ protected: ...@@ -42,14 +42,13 @@ protected:
public: public:
explicit MaxLayer(const LayerConfig& config) : SequencePoolLayer(config) {} explicit MaxLayer(const LayerConfig& config) : SequencePoolLayer(config) {}
~MaxLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
return SequencePoolLayer::init(layerMap, parameterMap); return SequencePoolLayer::init(layerMap, parameterMap);
} }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -45,10 +45,11 @@ public: ...@@ -45,10 +45,11 @@ public:
explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {} explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {}
virtual ~MaxOutLayer() {} virtual ~MaxOutLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -35,21 +35,22 @@ public: ...@@ -35,21 +35,22 @@ public:
~MixedLayer() {} ~MixedLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void prefetch(); void prefetch() override;
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
virtual void resetState(); void resetState() override;
/** /**
* setState() should be called after getState(). * setState() should be called after getState().
* Argument state consists of all projections states. * Argument state consists of all projections states.
*/ */
virtual void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
/** /**
* Return state which consists of all projections states. * Return state which consists of all projections states.
*/ */
virtual LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
std::vector<std::unique_ptr<Projection>> projections_; std::vector<std::unique_ptr<Projection>> projections_;
......
...@@ -69,10 +69,11 @@ public: ...@@ -69,10 +69,11 @@ public:
~MultiplexLayer() {} ~MultiplexLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
private: private:
/** /**
......
...@@ -61,7 +61,8 @@ public: ...@@ -61,7 +61,8 @@ public:
rand_(0, config.num_classes() - 1), rand_(0, config.num_classes() - 1),
prepared_(false) {} prepared_(false) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
...@@ -146,7 +147,7 @@ public: ...@@ -146,7 +147,7 @@ public:
prepared_ = true; prepared_ = true;
} }
void prefetch() { void prefetch() override {
prepareSamples(); prepareSamples();
IVector::resizeOrCreate(labelIds_, samples_.size(), useGpu_); IVector::resizeOrCreate(labelIds_, samples_.size(), useGpu_);
int* ids = labelIds_->getData(); int* ids = labelIds_->getData();
...@@ -163,7 +164,7 @@ public: ...@@ -163,7 +164,7 @@ public:
} }
} }
void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
CHECK(!useGpu_) << "GPU is not supported"; CHECK(!useGpu_) << "GPU is not supported";
...@@ -199,7 +200,7 @@ public: ...@@ -199,7 +200,7 @@ public:
forwardCost(); forwardCost();
} }
void backward(const UpdateCallback& callback) { void backward(const UpdateCallback& callback) override {
Matrix::resizeOrCreate(sampleOut_.grad, Matrix::resizeOrCreate(sampleOut_.grad,
1, 1,
samples_.size(), samples_.size(),
......
...@@ -30,7 +30,8 @@ class NormLayer : public Layer { ...@@ -30,7 +30,8 @@ class NormLayer : public Layer {
public: public:
explicit NormLayer(const LayerConfig& config) : Layer(config) {} explicit NormLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
return true; return true;
} }
...@@ -56,9 +57,10 @@ protected: ...@@ -56,9 +57,10 @@ protected:
public: public:
explicit ResponseNormLayer(const LayerConfig& config) : NormLayer(config) {} explicit ResponseNormLayer(const LayerConfig& config) : NormLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType) { LOG(FATAL) << "Not implemented"; } const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr) { void forward(PassType passType) override { LOG(FATAL) << "Not implemented"; }
void backward(const UpdateCallback& callback = nullptr) override {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
}; };
......
...@@ -36,9 +36,10 @@ public: ...@@ -36,9 +36,10 @@ public:
size_t getSize(); size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
TensorShape shape_; TensorShape shape_;
......
...@@ -38,10 +38,11 @@ public: ...@@ -38,10 +38,11 @@ public:
~OuterProdLayer() {} ~OuterProdLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(out_prod, OuterProdLayer); REGISTER_LAYER(out_prod, OuterProdLayer);
......
...@@ -29,9 +29,10 @@ public: ...@@ -29,9 +29,10 @@ public:
~PadLayer() {} ~PadLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
void forward(PassType passType); const ParameterMap& parameterMap) override;
void backward(const UpdateCallback& callback = nullptr); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected: protected:
void setOutDims(const size_t batchSize); void setOutDims(const size_t batchSize);
......
...@@ -56,9 +56,10 @@ public: ...@@ -56,9 +56,10 @@ public:
~ParameterReluLayer() {} ~ParameterReluLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -46,7 +46,8 @@ public: ...@@ -46,7 +46,8 @@ public:
*/ */
static Layer* create(const LayerConfig& config); static Layer* create(const LayerConfig& config);
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
size_t getSize(); size_t getSize();
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -40,10 +40,11 @@ public: ...@@ -40,10 +40,11 @@ public:
~PowerLayer() {} ~PowerLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(power, PowerLayer); REGISTER_LAYER(power, PowerLayer);
......
...@@ -19,8 +19,8 @@ namespace paddle { ...@@ -19,8 +19,8 @@ namespace paddle {
class PrintLayer : public Layer { class PrintLayer : public Layer {
public: public:
explicit PrintLayer(const LayerConfig& config) : Layer(config) {} explicit PrintLayer(const LayerConfig& config) : Layer(config) {}
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
void PrintLayer::forward(PassType passType) { void PrintLayer::forward(PassType passType) {
......
...@@ -30,10 +30,11 @@ namespace paddle { ...@@ -30,10 +30,11 @@ namespace paddle {
class PriorBoxLayer : public Layer { class PriorBoxLayer : public Layer {
public: public:
explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {} explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
protected: protected:
int numPriors_; int numPriors_;
......
...@@ -45,17 +45,18 @@ class RecurrentLayer : public Layer { ...@@ -45,17 +45,18 @@ class RecurrentLayer : public Layer {
public: public:
explicit RecurrentLayer(const LayerConfig& config) : Layer(config) {} explicit RecurrentLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
void resetState(); void resetState() override;
void setState(LayerStatePtr state); void setState(LayerStatePtr state) override;
LayerStatePtr getState(); LayerStatePtr getState() override;
protected: protected:
/** /**
......
...@@ -33,15 +33,15 @@ public: ...@@ -33,15 +33,15 @@ public:
void initSubNetwork(NeuralNetwork* rootNetwork, void initSubNetwork(NeuralNetwork* rootNetwork,
const ModelConfig& config, const ModelConfig& config,
const std::vector<ParameterType>& parameterTypes, const std::vector<ParameterType>& parameterTypes,
bool useGpu); bool useGpu) override;
void forward(PassType passType) { void forward(PassType passType) override {
REGISTER_TIMER_INFO("RecurrentGroupFwTime", getName().c_str()); REGISTER_TIMER_INFO("RecurrentGroupFwTime", getName().c_str());
const std::vector<Argument> inArgs; const std::vector<Argument> inArgs;
std::vector<Argument> outArgs; std::vector<Argument> outArgs;
network_->forward(inArgs, &outArgs, passType); network_->forward(inArgs, &outArgs, passType);
} }
void backward(const UpdateCallback& callback) { void backward(const UpdateCallback& callback) override {
REGISTER_TIMER_INFO("RecurrentGroupBwTime", getName().c_str()); REGISTER_TIMER_INFO("RecurrentGroupBwTime", getName().c_str());
network_->backward(nullptr); network_->backward(nullptr);
...@@ -53,7 +53,8 @@ public: ...@@ -53,7 +53,8 @@ public:
/** /**
* @see Layer.accessSubNetwork * @see Layer.accessSubNetwork
*/ */
void accessSubNetwork(const std::function<void(NeuralNetwork&)>& callback) { void accessSubNetwork(
const std::function<void(NeuralNetwork&)>& callback) override {
callback(*network_); callback(*network_);
} }
......
...@@ -20,18 +20,19 @@ namespace paddle { ...@@ -20,18 +20,19 @@ namespace paddle {
/** /**
* @brief A layer for resizing a minibatch matrix h*w to h'*w' * @brief A layer for resizing a minibatch matrix h*w to h'*w'
* @note * @note
* origin matrix height * witdth) * origin matrix height * width)
* resize matrix: (height * width / size) * size * resize matrix: (height * width / size) * size
*/ */
class ResizeLayer : public Layer { class ResizeLayer : public Layer {
public: public:
explicit ResizeLayer(const LayerConfig& config) : Layer(config) {} explicit ResizeLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback) override;
}; };
REGISTER_LAYER(resize, ResizeLayer); REGISTER_LAYER(resize, ResizeLayer);
......
...@@ -35,8 +35,8 @@ public: ...@@ -35,8 +35,8 @@ public:
explicit SamplingIdLayer(const LayerConfig& config) explicit SamplingIdLayer(const LayerConfig& config)
: Layer(config), rand1_(0, 1) {} : Layer(config), rand1_(0, 1) {}
virtual bool init(const LayerMap& layerMap, bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap); bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size()); CHECK_EQ(1UL, inputLayers_.size());
if (useGpu_) { if (useGpu_) {
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
return ret; return ret;
} }
void forward(PassType passType) { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
if (useGpu_) { if (useGpu_) {
for (size_t i = 0; i < inputLayers_.size(); i++) { for (size_t i = 0; i < inputLayers_.size(); i++) {
...@@ -83,7 +83,7 @@ public: ...@@ -83,7 +83,7 @@ public:
output_.ids->copyFrom(ids.data(), batchSize); output_.ids->copyFrom(ids.data(), batchSize);
} }
virtual void backward(const UpdateCallback& callback) {} void backward(const UpdateCallback& callback) override {}
}; };
REGISTER_LAYER(sampling_id, SamplingIdLayer); REGISTER_LAYER(sampling_id, SamplingIdLayer);
......
...@@ -37,10 +37,11 @@ public: ...@@ -37,10 +37,11 @@ public:
~ScalingLayer() {} ~ScalingLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(scaling, ScalingLayer); REGISTER_LAYER(scaling, ScalingLayer);
......
...@@ -65,9 +65,10 @@ public: ...@@ -65,9 +65,10 @@ public:
: Layer(config), selCols_(nullptr) {} : Layer(config), selCols_(nullptr) {}
~SelectiveFullyConnectedLayer() {} ~SelectiveFullyConnectedLayer() {}
void prefetch(); void prefetch() override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
...@@ -90,8 +91,8 @@ public: ...@@ -90,8 +91,8 @@ public:
void fillSelectiveData( void fillSelectiveData(
const std::shared_ptr<std::vector<std::pair<int*, size_t>>>& candidates); const std::shared_ptr<std::vector<std::pair<int*, size_t>>>& candidates);
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
private: private:
/** /**
......
...@@ -35,10 +35,11 @@ public: ...@@ -35,10 +35,11 @@ public:
~SequenceConcatLayer() {} ~SequenceConcatLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
REGISTER_LAYER(seqconcat, SequenceConcatLayer); REGISTER_LAYER(seqconcat, SequenceConcatLayer);
......
...@@ -42,12 +42,11 @@ public: ...@@ -42,12 +42,11 @@ public:
explicit SequenceLastInstanceLayer(const LayerConfig& config) explicit SequenceLastInstanceLayer(const LayerConfig& config)
: SequencePoolLayer(config) {} : SequencePoolLayer(config) {}
~SequenceLastInstanceLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer); REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer);
......
...@@ -46,12 +46,11 @@ protected: ...@@ -46,12 +46,11 @@ protected:
public: public:
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {} explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}
virtual ~SequencePoolLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
} // namespace paddle } // namespace paddle
...@@ -34,12 +34,11 @@ protected: ...@@ -34,12 +34,11 @@ protected:
public: public:
explicit SequenceReshapeLayer(const LayerConfig& config) : Layer(config) {} explicit SequenceReshapeLayer(const LayerConfig& config) : Layer(config) {}
~SequenceReshapeLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(seqreshape, SequenceReshapeLayer); REGISTER_LAYER(seqreshape, SequenceReshapeLayer);
......
...@@ -39,12 +39,11 @@ class SlopeInterceptLayer : public Layer { ...@@ -39,12 +39,11 @@ class SlopeInterceptLayer : public Layer {
public: public:
explicit SlopeInterceptLayer(const LayerConfig& config) : Layer(config) {} explicit SlopeInterceptLayer(const LayerConfig& config) : Layer(config) {}
~SlopeInterceptLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(slope_intercept, SlopeInterceptLayer); REGISTER_LAYER(slope_intercept, SlopeInterceptLayer);
......
...@@ -43,9 +43,8 @@ protected: ...@@ -43,9 +43,8 @@ protected:
public: public:
explicit SpatialPyramidPoolLayer(const LayerConfig& config) : Layer(config) {} explicit SpatialPyramidPoolLayer(const LayerConfig& config) : Layer(config) {}
~SpatialPyramidPoolLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
ProjectionConfig getConfig(size_t sizeX_, ProjectionConfig getConfig(size_t sizeX_,
size_t sizeY_, size_t sizeY_,
...@@ -54,7 +53,7 @@ public: ...@@ -54,7 +53,7 @@ public:
std::string& poolType_); std::string& poolType_);
size_t getSize(); size_t getSize();
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -35,12 +35,11 @@ protected: ...@@ -35,12 +35,11 @@ protected:
public: public:
explicit SubSequenceLayer(const LayerConfig& config) : Layer(config) {} explicit SubSequenceLayer(const LayerConfig& config) : Layer(config) {}
~SubSequenceLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(subseq, SubSequenceLayer); REGISTER_LAYER(subseq, SubSequenceLayer);
......
...@@ -41,12 +41,11 @@ protected: ...@@ -41,12 +41,11 @@ protected:
public: public:
explicit SumToOneNormLayer(const LayerConfig& config) : Layer(config) {} explicit SumToOneNormLayer(const LayerConfig& config) : Layer(config) {}
~SumToOneNormLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
}; };
REGISTER_LAYER(sum_to_one_norm, SumToOneNormLayer); REGISTER_LAYER(sum_to_one_norm, SumToOneNormLayer);
......
...@@ -44,13 +44,12 @@ protected: ...@@ -44,13 +44,12 @@ protected:
public: public:
explicit TensorLayer(const LayerConfig& config) : Layer(config) {} explicit TensorLayer(const LayerConfig& config) : Layer(config) {}
~TensorLayer() {} bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
Weight& getWeight(int idx) { return *weights_[idx]; } Weight& getWeight(int idx) { return *weights_[idx]; }
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -32,9 +32,10 @@ class TransLayer : public Layer { ...@@ -32,9 +32,10 @@ class TransLayer : public Layer {
public: public:
explicit TransLayer(const LayerConfig& config) : Layer(config) {} explicit TransLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType); void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
}; };
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,8 @@ class ValidationLayer : public Layer { ...@@ -26,7 +26,8 @@ class ValidationLayer : public Layer {
public: public:
explicit ValidationLayer(const LayerConfig& config) : Layer(config) {} explicit ValidationLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; } LayerPtr getOutputLayer() { return inputLayers_[0]; }
...@@ -37,13 +38,13 @@ public: ...@@ -37,13 +38,13 @@ public:
return inputLayers_[2]; return inputLayers_[2];
} }
virtual void forward(PassType passType); void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr) override;
virtual void validationImp(MatrixPtr outputValue, IVectorPtr label) = 0; virtual void validationImp(MatrixPtr outputValue, IVectorPtr label) = 0;
virtual void onPassEnd() = 0; void onPassEnd() override = 0;
}; };
/* /*
...@@ -57,11 +58,12 @@ public: ...@@ -57,11 +58,12 @@ public:
cpuLabel_(nullptr), cpuLabel_(nullptr),
cpuWeight_(nullptr) {} cpuWeight_(nullptr) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void validationImp(MatrixPtr outputValue, IVectorPtr label); void validationImp(MatrixPtr outputValue, IVectorPtr label) override;
void onPassEnd(); void onPassEnd() override;
struct PredictionResult { struct PredictionResult {
PredictionResult(real __out, int __label) : out(__out), label(__label) {} PredictionResult(real __out, int __label) : out(__out), label(__label) {}
...@@ -86,11 +88,12 @@ public: ...@@ -86,11 +88,12 @@ public:
explicit PnpairValidation(const LayerConfig& config) explicit PnpairValidation(const LayerConfig& config)
: ValidationLayer(config) {} : ValidationLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void validationImp(MatrixPtr outputValue, IVectorPtr label); void validationImp(MatrixPtr outputValue, IVectorPtr label) override;
void onPassEnd(); void onPassEnd() override;
private: private:
bool passBegin_; bool passBegin_;
......
...@@ -30,9 +30,10 @@ public: ...@@ -30,9 +30,10 @@ public:
explicit WarpCTCLayer(const LayerConfig& config) : Layer(config) {} explicit WarpCTCLayer(const LayerConfig& config) : Layer(config) {}
~WarpCTCLayer() {} ~WarpCTCLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap,
virtual void forward(PassType passType); const ParameterMap& parameterMap) override;
virtual void backward(const UpdateCallback& callback); void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected: protected:
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册