提交 ed808f5e 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1251 from reyoung/feature/add_override_to_layer_init

Using override keyword in Layer
......@@ -155,7 +155,8 @@ protected:
public:
explicit BootBiasLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
if (!Layer::init(layerMap, parameterMap)) return false;
if (biasParameter_) {
......@@ -174,7 +175,7 @@ public:
}
}
virtual void forward(PassType passType) {
void forward(PassType passType) override {
if (biases_) {
MatrixPtr outV = getOutputValue();
outV->addBias(*(biases_->getW()), 1);
......@@ -182,7 +183,7 @@ public:
}
}
virtual void backward(const UpdateCallback& callback) {
void backward(const UpdateCallback& callback) override {
if (biases_) {
backwardActivation();
biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
......
......@@ -44,19 +44,20 @@ public:
/**
* Intialization of AddtoLayer.
*/
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/**
* Forward propagation.
* @note There is no weight matrix for each input,
* because it just a simple add operation.
*/
void forward(PassType passType);
void forward(PassType passType) override;
/**
* Backward propagation.
*/
void backward(const UpdateCallback& callback = nullptr);
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -35,7 +35,8 @@ public:
~AgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
// if *numSamples* set,
// real layer output will only use first *numSamples* rows
......@@ -44,8 +45,8 @@ public:
numSamples_ = numSamples;
}
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr) {}
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override {}
};
/**
......@@ -56,8 +57,8 @@ public:
explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {}
~SequenceAgentLayer() {}
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr) {}
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override {}
};
/**
......@@ -78,7 +79,8 @@ public:
virtual ~GatherAgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
// call before addRealLayer
void copyIdAndSequenceInfo(const Argument& input,
......@@ -88,8 +90,8 @@ public:
// add one real layer, can call many times
void addRealLayer(LayerPtr layer) { realLayers_.push_back(layer); }
void forward(PassType passType);
void backward(const UpdateCallback& callback);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
/**
......@@ -133,7 +135,8 @@ public:
virtual ~ScatterAgentLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/**
* @brief set real layer in generation
......@@ -182,8 +185,8 @@ public:
numSequences_ = numSequences;
}
void forward(PassType passType);
void backward(const UpdateCallback& callback);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
/**
......
......@@ -38,12 +38,11 @@ public:
explicit AverageLayer(const LayerConfig& config)
: SequencePoolLayer(config) {}
~AverageLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
MatrixPtr outMtx_;
......
......@@ -52,7 +52,8 @@ public:
*/
static Layer* create(const LayerConfig& config);
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/**
* @brief Calculate feature map size. Some input uses frameHeight and
......
......@@ -33,9 +33,10 @@ public:
~BatchNormalizationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
/// Epsilon value used in the batch normalization formula.
......@@ -58,7 +59,7 @@ protected:
/// to batch, channels* imagePixels.
void shrinkMat(const MatrixPtr& in, MatrixPtr& out);
void onPassEnd() { firstTest_ = true; }
void onPassEnd() override { firstTest_ = true; }
MatrixPtr tmpMat_, tmpGrad_;
MatrixPtr expandedIn_, expandedOut_;
......
......@@ -38,9 +38,10 @@ public:
virtual ~BilinearInterpLayer() {}
size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -58,10 +58,11 @@ public:
~BlockExpandLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -32,9 +32,10 @@ namespace paddle {
class CRFDecodingLayer : public CRFLayer {
public:
explicit CRFDecodingLayer(const LayerConfig& config) : CRFLayer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected:
std::unique_ptr<LinearChainCRF> crf_;
......
......@@ -29,9 +29,10 @@ namespace paddle {
class CRFLayer : public Layer {
public:
explicit CRFLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected:
size_t numClasses_;
......
......@@ -22,10 +22,11 @@ namespace paddle {
class CTCLayer : public Layer {
public:
explicit CTCLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
virtual void forward(PassType passType);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void forwardImp(const Argument& softmaxSeqs, const Argument& labelSeqs);
virtual void backward(const UpdateCallback& callback);
void backward(const UpdateCallback& callback) override;
void backwardImp(const UpdateCallback& callback,
const Argument& softmaxSeqs,
const Argument& labelSeqs);
......
......@@ -28,10 +28,11 @@ public:
~ConcatenateLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(concat, ConcatenateLayer);
......@@ -101,10 +102,11 @@ public:
~ConcatenateLayer2() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
std::vector<std::unique_ptr<Projection>> projections_;
......
......@@ -80,7 +80,8 @@ protected:
public:
explicit ConvBaseLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/**
* imgSizeH_ and imgSizeW_ will be set according to the previous input layers
......
......@@ -47,10 +47,11 @@ public:
~ConvShiftLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(conv_shift, ConvShiftLayer);
......
......@@ -49,10 +49,11 @@ public:
~ConvexCombinationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(convex_comb, ConvexCombinationLayer);
......
......@@ -38,10 +38,11 @@ public:
~CosSimLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -46,10 +46,11 @@ public:
~CosSimVecMatLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(cos_vm, CosSimVecMatLayer);
......
......@@ -367,8 +367,6 @@ void LambdaCost::backward(const UpdateCallback& callback) {
getInputGrad(0)->add(*marginGrad_);
}
void LambdaCost::onPassEnd() {}
void LambdaCost::calcGrad(const real* outputScore,
const real* score,
real* gradData,
......@@ -611,14 +609,15 @@ class SumCostLayer : public Layer {
public:
explicit SumCostLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap);
if (!ret) return ret;
CHECK_EQ(inputLayers_.size(), 1UL);
return true;
}
virtual void forward(PassType passType) {
void forward(PassType passType) override {
Layer::forward(passType);
const MatrixPtr& input = getInputValue(0);
......@@ -629,7 +628,7 @@ public:
output_.value->sumRows(*input, /* scaleSum= */ 1, /* scaleDest= */ 0);
}
virtual void backward(const UpdateCallback& callback = nullptr) {
void backward(const UpdateCallback& callback = nullptr) override {
getInputGrad(0)->add((real)1);
}
};
......
......@@ -32,15 +32,16 @@ class CostLayer : public Layer {
public:
explicit CostLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; }
LayerPtr getLabelLayer() { return inputLayers_[1]; }
virtual void forward(PassType passType);
void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr);
void backward(const UpdateCallback& callback = nullptr) override;
virtual void forwardImp(Matrix& outputValue,
Argument& label,
......@@ -68,11 +69,14 @@ public:
explicit MultiClassCrossEntropy(const LayerConfig& config)
: CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost);
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad);
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
};
/**
......@@ -95,11 +99,14 @@ public:
explicit MultiClassCrossEntropyWithSelfNorm(const LayerConfig& config)
: CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost);
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad);
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
protected:
MatrixPtr sftMaxSum_;
......@@ -117,11 +124,14 @@ public:
explicit SoftBinaryClassCrossEntropy(const LayerConfig& config)
: CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost);
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad);
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
protected:
MatrixPtr targetPerDim_;
......@@ -139,11 +149,14 @@ public:
explicit SumOfSquaresCostLayer(const LayerConfig& config)
: CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost);
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad);
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
};
/**
......@@ -162,17 +175,18 @@ class RankingCost : public Layer {
public:
explicit RankingCost(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer(size_t i) { return inputLayers_[i]; }
LayerPtr getLabelLayer() { return inputLayers_[2]; }
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr);
void backward(const UpdateCallback& callback = nullptr) override;
void onPassEnd();
void onPassEnd() override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost) {
(void)output;
......@@ -214,17 +228,16 @@ class LambdaCost : public Layer {
public:
explicit LambdaCost(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; }
LayerPtr getScoreLayer() { return inputLayers_[1]; }
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr);
void onPassEnd();
void backward(const UpdateCallback& callback = nullptr) override;
real calcNDCG(const real* outputScore, const real* score, int size);
void calcGrad(const real* outputScore,
......@@ -256,11 +269,14 @@ public:
explicit MultiBinaryLabelCrossEntropy(const LayerConfig& config)
: CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost);
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad);
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
};
/**
......@@ -282,13 +298,16 @@ class HuberTwoClass : public CostLayer {
public:
explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forwardImp(Matrix& output, Argument& label, Matrix& cost);
void forwardImp(Matrix& output, Argument& label, Matrix& cost) override;
void forwardImpIn(Matrix& output, Argument& label, Matrix& cost);
void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad);
void backwardImp(Matrix& outputValue,
Argument& label,
Matrix& outputGrad) override;
void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad);
};
......
......@@ -35,14 +35,15 @@ public:
~CudnnBatchNormLayer();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/**
* reshape tensor of ioDesc_.
*/
void reshape(int batchSize);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
/**
......
......@@ -45,9 +45,10 @@ public:
~CudnnConvLayer();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
void addBiases();
void bpropBiases();
};
......
......@@ -45,7 +45,8 @@ public:
hl_pooling_mode_t* mode = nullptr);
explicit CudnnPoolLayer(const LayerConfig& config);
~CudnnPoolLayer();
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
/**
* Reshape input and output tensor descriptor.
......@@ -53,8 +54,8 @@ public:
* So reshaping is needed.
*/
void reshape(int batchSize);
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -33,13 +33,13 @@ public:
/**
* Prefetch sparse matrix/ids only.
*/
void prefetch() { output_ = data_; }
void prefetch() override { output_ = data_; }
/**
* Forward propagation. Copy data_ (value, in, grad, ids, cpuSequenceDims,
* sequenceStartPositions, subSequenceStartPositions, strs) to output_.
*/
virtual void forward(PassType passType) {
void forward(PassType passType) override {
Layer::forward(passType);
copyDataToOutput(output_);
if (FLAGS_show_layer_stat) {
......@@ -50,9 +50,9 @@ public:
/**
* Data layer's backward propagation do nothing.
*/
virtual void backward(const UpdateCallback& callback) { (void)callback; }
void backward(const UpdateCallback& callback) override { (void)callback; }
virtual void copyOutputToOtherDevice() {
void copyOutputToOtherDevice() override {
for (size_t i = 0; i != outputOtherDevice_.size(); i++) {
copyDataToOutput(outputOtherDevice_[i]);
}
......
......@@ -44,10 +44,11 @@ public:
~DataNormLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
int mode_;
......
......@@ -27,14 +27,14 @@ class EosIdCheckLayer : public Layer {
public:
explicit EosIdCheckLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size());
return ret;
}
virtual void forward(PassType passType) {
void forward(PassType passType) override {
Layer::forward(passType);
const Argument& input = getInput(0);
......@@ -42,7 +42,7 @@ public:
output_.ids->isEqualTo(*input.ids, config_.eos_id());
}
virtual void backward(const UpdateCallback& callback) {}
void backward(const UpdateCallback& callback) override {}
};
REGISTER_LAYER(eos_id, EosIdCheckLayer);
......
......@@ -48,7 +48,8 @@ public:
~ExpandConvBaseLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
size_t getOutputSize();
/**
......
......@@ -35,10 +35,11 @@ public:
~ExpandConvLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
} // namespace paddle
......@@ -34,10 +34,11 @@ public:
~ExpandConvTransLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
} // namespace paddle
......@@ -53,10 +53,11 @@ public:
~ExpandLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -46,10 +46,11 @@ public:
~FeatureMapExpandLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(featmap_expand, FeatureMapExpandLayer);
......
......@@ -36,13 +36,14 @@ public:
explicit FullyConnectedLayer(const LayerConfig& config) : Layer(config) {}
~FullyConnectedLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; }
void prefetch();
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void prefetch() override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -50,17 +50,18 @@ class GatedRecurrentLayer : public Layer, public GruCompute {
public:
explicit GatedRecurrentLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback);
void backward(const UpdateCallback& callback) override;
void resetState();
void resetState() override;
void setState(LayerStatePtr state);
void setState(LayerStatePtr state) override;
LayerStatePtr getState();
LayerStatePtr getState() override;
protected:
void forwardSequence(int batchSize,
......
......@@ -22,17 +22,18 @@ public:
~GetOutputLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
if (!Layer::init(layerMap, parameterMap)) return false;
CHECK_EQ(1U, inputLayers_.size());
CHECK_NE(inputArgument_[0], "");
return true;
}
void forward(PassType passType) {
void forward(PassType passType) override {
output_ = getPrev(0)->getOutput(inputArgument_[0]);
}
void backward(const UpdateCallback& callback = nullptr) {}
void backward(const UpdateCallback& callback = nullptr) override {}
};
REGISTER_LAYER(get_output, GetOutputLayer);
......
......@@ -55,10 +55,11 @@ public:
~GruStepLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(gru_step, GruStepLayer);
......
......@@ -61,9 +61,10 @@ class HierarchicalSigmoidLayer : public Layer {
public:
explicit HierarchicalSigmoidLayer(const LayerConfig& config)
: Layer(config) {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected:
/**
......
......@@ -43,10 +43,11 @@ public:
~InterpolationLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(interpolation, InterpolationLayer);
......
......@@ -74,17 +74,18 @@ class LstmLayer : public Layer, public LstmCompute {
public:
explicit LstmLayer(const LayerConfig &config) : Layer(config) {}
bool init(const LayerMap &layerMap, const ParameterMap &parameterMap);
bool init(const LayerMap &layerMap,
const ParameterMap &parameterMap) override;
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback &callback);
void backward(const UpdateCallback &callback) override;
void resetState();
void resetState() override;
void setState(LayerStatePtr state);
void setState(LayerStatePtr state) override;
LayerStatePtr getState();
LayerStatePtr getState() override;
protected:
/**
......
......@@ -35,10 +35,11 @@ public:
~LstmStepLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(lstm_step, LstmStepLayer);
......
......@@ -181,11 +181,12 @@ class MDLstmLayer : public LstmLayer {
public:
explicit MDLstmLayer(const LayerConfig& config) : LstmLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback);
void backward(const UpdateCallback& callback) override;
protected:
void forwardOneSequence(int start, CoordIterator& coordIter);
......
......@@ -30,8 +30,8 @@ private:
public:
explicit MaxIdLayer(const LayerConfig& config) : Layer(config) {}
virtual bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size());
......@@ -40,7 +40,7 @@ public:
return ret;
}
virtual void forward(PassType passType) {
void forward(PassType passType) override {
Layer::forward(passType);
const Argument& input = getInput(0);
size_t batchSize = input.getBatchSize();
......@@ -54,7 +54,7 @@ public:
input.value->rowMax(*output_.ids, *output_.in);
}
virtual void backward(const UpdateCallback& callback) {}
void backward(const UpdateCallback& callback) override {}
};
REGISTER_LAYER(maxid, MaxIdLayer);
......
......@@ -42,14 +42,13 @@ protected:
public:
explicit MaxLayer(const LayerConfig& config) : SequencePoolLayer(config) {}
~MaxLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
return SequencePoolLayer::init(layerMap, parameterMap);
}
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -45,10 +45,11 @@ public:
explicit MaxOutLayer(const LayerConfig& config) : Layer(config) {}
virtual ~MaxOutLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -35,21 +35,22 @@ public:
~MixedLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
virtual void prefetch();
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
virtual void resetState();
void prefetch() override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
void resetState() override;
/**
* setState() should be called after getState().
* Argument state consists of all projections states.
*/
virtual void setState(LayerStatePtr state);
void setState(LayerStatePtr state) override;
/**
* Return state which consists of all projections states.
*/
virtual LayerStatePtr getState();
LayerStatePtr getState() override;
protected:
std::vector<std::unique_ptr<Projection>> projections_;
......
......@@ -69,10 +69,11 @@ public:
~MultiplexLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
private:
/**
......
......@@ -61,7 +61,8 @@ public:
rand_(0, config.num_classes() - 1),
prepared_(false) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
......@@ -146,7 +147,7 @@ public:
prepared_ = true;
}
void prefetch() {
void prefetch() override {
prepareSamples();
IVector::resizeOrCreate(labelIds_, samples_.size(), useGpu_);
int* ids = labelIds_->getData();
......@@ -163,7 +164,7 @@ public:
}
}
void forward(PassType passType) {
void forward(PassType passType) override {
Layer::forward(passType);
CHECK(!useGpu_) << "GPU is not supported";
......@@ -199,7 +200,7 @@ public:
forwardCost();
}
void backward(const UpdateCallback& callback) {
void backward(const UpdateCallback& callback) override {
Matrix::resizeOrCreate(sampleOut_.grad,
1,
samples_.size(),
......
......@@ -30,7 +30,8 @@ class NormLayer : public Layer {
public:
explicit NormLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
Layer::init(layerMap, parameterMap);
return true;
}
......@@ -56,9 +57,10 @@ protected:
public:
explicit ResponseNormLayer(const LayerConfig& config) : NormLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType) { LOG(FATAL) << "Not implemented"; }
void backward(const UpdateCallback& callback = nullptr) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override { LOG(FATAL) << "Not implemented"; }
void backward(const UpdateCallback& callback = nullptr) override {
LOG(FATAL) << "Not implemented";
}
};
......
......@@ -36,9 +36,10 @@ public:
size_t getSize();
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
TensorShape shape_;
......
......@@ -38,10 +38,11 @@ public:
~OuterProdLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(out_prod, OuterProdLayer);
......
......@@ -29,9 +29,10 @@ public:
~PadLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
void setOutDims(const size_t batchSize);
......
......@@ -56,9 +56,10 @@ public:
~ParameterReluLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -46,7 +46,8 @@ public:
*/
static Layer* create(const LayerConfig& config);
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
};
} // namespace paddle
......@@ -40,7 +40,7 @@ public:
size_t getSize();
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -40,10 +40,11 @@ public:
~PowerLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(power, PowerLayer);
......
......@@ -19,8 +19,8 @@ namespace paddle {
class PrintLayer : public Layer {
public:
explicit PrintLayer(const LayerConfig& config) : Layer(config) {}
void forward(PassType passType);
void backward(const UpdateCallback& callback) {}
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override {}
};
void PrintLayer::forward(PassType passType) {
......
......@@ -30,10 +30,11 @@ namespace paddle {
class PriorBoxLayer : public Layer {
public:
explicit PriorBoxLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback) {}
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override {}
protected:
int numPriors_;
......
......@@ -45,17 +45,18 @@ class RecurrentLayer : public Layer {
public:
explicit RecurrentLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback);
void backward(const UpdateCallback& callback) override;
void resetState();
void resetState() override;
void setState(LayerStatePtr state);
void setState(LayerStatePtr state) override;
LayerStatePtr getState();
LayerStatePtr getState() override;
protected:
/**
......
......@@ -33,15 +33,15 @@ public:
void initSubNetwork(NeuralNetwork* rootNetwork,
const ModelConfig& config,
const std::vector<ParameterType>& parameterTypes,
bool useGpu);
bool useGpu) override;
void forward(PassType passType) {
void forward(PassType passType) override {
REGISTER_TIMER_INFO("RecurrentGroupFwTime", getName().c_str());
const std::vector<Argument> inArgs;
std::vector<Argument> outArgs;
network_->forward(inArgs, &outArgs, passType);
}
void backward(const UpdateCallback& callback) {
void backward(const UpdateCallback& callback) override {
REGISTER_TIMER_INFO("RecurrentGroupBwTime", getName().c_str());
network_->backward(nullptr);
......@@ -53,7 +53,8 @@ public:
/**
* @see Layer.accessSubNetwork
*/
void accessSubNetwork(const std::function<void(NeuralNetwork&)>& callback) {
void accessSubNetwork(
const std::function<void(NeuralNetwork&)>& callback) override {
callback(*network_);
}
......
......@@ -20,18 +20,19 @@ namespace paddle {
/**
* @brief A layer for resizing a minibatch matrix h*w to h'*w'
* @note
* origin matrix height * witdth)
* origin matrix height * width)
* resize matrix: (height * width / size) * size
*/
class ResizeLayer : public Layer {
public:
explicit ResizeLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback);
void backward(const UpdateCallback& callback) override;
};
REGISTER_LAYER(resize, ResizeLayer);
......
......@@ -35,8 +35,8 @@ public:
explicit SamplingIdLayer(const LayerConfig& config)
: Layer(config), rand1_(0, 1) {}
virtual bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override {
bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1UL, inputLayers_.size());
if (useGpu_) {
......@@ -48,7 +48,7 @@ public:
return ret;
}
void forward(PassType passType) {
void forward(PassType passType) override {
Layer::forward(passType);
if (useGpu_) {
for (size_t i = 0; i < inputLayers_.size(); i++) {
......@@ -83,7 +83,7 @@ public:
output_.ids->copyFrom(ids.data(), batchSize);
}
virtual void backward(const UpdateCallback& callback) {}
void backward(const UpdateCallback& callback) override {}
};
REGISTER_LAYER(sampling_id, SamplingIdLayer);
......
......@@ -37,10 +37,11 @@ public:
~ScalingLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(scaling, ScalingLayer);
......
......@@ -65,9 +65,10 @@ public:
: Layer(config), selCols_(nullptr) {}
~SelectiveFullyConnectedLayer() {}
void prefetch();
void prefetch() override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; }
......@@ -90,8 +91,8 @@ public:
void fillSelectiveData(
const std::shared_ptr<std::vector<std::pair<int*, size_t>>>& candidates);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
private:
/**
......
......@@ -35,10 +35,11 @@ public:
~SequenceConcatLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(seqconcat, SequenceConcatLayer);
......
......@@ -42,12 +42,11 @@ public:
explicit SequenceLastInstanceLayer(const LayerConfig& config)
: SequencePoolLayer(config) {}
~SequenceLastInstanceLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer);
......
......@@ -46,12 +46,11 @@ protected:
public:
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}
virtual ~SequencePoolLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -34,12 +34,11 @@ protected:
public:
explicit SequenceReshapeLayer(const LayerConfig& config) : Layer(config) {}
~SequenceReshapeLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(seqreshape, SequenceReshapeLayer);
......
......@@ -39,12 +39,11 @@ class SlopeInterceptLayer : public Layer {
public:
explicit SlopeInterceptLayer(const LayerConfig& config) : Layer(config) {}
~SlopeInterceptLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(slope_intercept, SlopeInterceptLayer);
......
......@@ -43,9 +43,8 @@ protected:
public:
explicit SpatialPyramidPoolLayer(const LayerConfig& config) : Layer(config) {}
~SpatialPyramidPoolLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
ProjectionConfig getConfig(size_t sizeX_,
size_t sizeY_,
......@@ -54,7 +53,7 @@ public:
std::string& poolType_);
size_t getSize();
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -35,12 +35,11 @@ protected:
public:
explicit SubSequenceLayer(const LayerConfig& config) : Layer(config) {}
~SubSequenceLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(subseq, SubSequenceLayer);
......
......@@ -41,12 +41,11 @@ protected:
public:
explicit SumToOneNormLayer(const LayerConfig& config) : Layer(config) {}
~SumToOneNormLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(sum_to_one_norm, SumToOneNormLayer);
......
......@@ -44,13 +44,12 @@ protected:
public:
explicit TensorLayer(const LayerConfig& config) : Layer(config) {}
~TensorLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
Weight& getWeight(int idx) { return *weights_[idx]; }
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -32,9 +32,10 @@ class TransLayer : public Layer {
public:
explicit TransLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -26,7 +26,8 @@ class ValidationLayer : public Layer {
public:
explicit ValidationLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
LayerPtr getOutputLayer() { return inputLayers_[0]; }
......@@ -37,13 +38,13 @@ public:
return inputLayers_[2];
}
virtual void forward(PassType passType);
void forward(PassType passType) override;
virtual void backward(const UpdateCallback& callback = nullptr);
void backward(const UpdateCallback& callback = nullptr) override;
virtual void validationImp(MatrixPtr outputValue, IVectorPtr label) = 0;
virtual void onPassEnd() = 0;
void onPassEnd() override = 0;
};
/*
......@@ -57,11 +58,12 @@ public:
cpuLabel_(nullptr),
cpuWeight_(nullptr) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void validationImp(MatrixPtr outputValue, IVectorPtr label);
void validationImp(MatrixPtr outputValue, IVectorPtr label) override;
void onPassEnd();
void onPassEnd() override;
struct PredictionResult {
PredictionResult(real __out, int __label) : out(__out), label(__label) {}
......@@ -86,11 +88,12 @@ public:
explicit PnpairValidation(const LayerConfig& config)
: ValidationLayer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void validationImp(MatrixPtr outputValue, IVectorPtr label);
void validationImp(MatrixPtr outputValue, IVectorPtr label) override;
void onPassEnd();
void onPassEnd() override;
private:
bool passBegin_;
......
......@@ -30,9 +30,10 @@ public:
explicit WarpCTCLayer(const LayerConfig& config) : Layer(config) {}
~WarpCTCLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback);
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
protected:
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册