Commit 9bc12034 authored by Yu Yang

Add more comments, also add __must_check.

Parent 3d01c60e
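For context, `__must_check` here expands to GCC/Clang's `warn_unused_result` attribute (see the new `Compiler.h` added below), so a caller that silently drops a returned `Status` gets a compile-time warning instead of a silently ignored error. A minimal, self-contained sketch of that effect follows; the `Status` and `forward` below are stand-ins for illustration, not Paddle's actual classes:

```cpp
// Minimal sketch of what the __must_check attribute buys (stand-in types,
// not Paddle code). Build with g++ or clang++ to see the warning.
#include <string>
#include <utility>

#if defined(__GNUC__)
#define __must_check __attribute__((warn_unused_result))
#else
#define __must_check
#endif

class Status {
public:
  Status() = default;  // default-constructed Status means "OK"
  explicit Status(std::string msg) : msg_(std::move(msg)) {}
  bool isOK() const { return msg_.empty(); }

private:
  std::string msg_;
};

// Because of __must_check, ignoring this return value triggers
// -Wunused-result on GCC >= 3.4 and on Clang.
Status __must_check forward() { return Status("not implemented"); }

int main() {
  forward();             // warning: ignoring return value of 'forward'
  Status s = forward();  // fine: the Status is captured and inspected
  return s.isOK() ? 0 : 1;
}
```

The real macro lives in the `Compiler.h` header added by this commit and degrades to a no-op on compilers without the attribute.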
@@ -69,11 +69,11 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
 class IdentityActivation : public ActivationFunction {
 public:
   static const std::string name;
-  Status forward(Argument& act) {
+  Status __must_check forward(Argument& act) {
     (void)act;
     return Status();
   }
-  Status backward(Argument& act) {
+  Status __must_check backward(Argument& act) {
     (void)act;
     return Status();
   }
@@ -92,11 +92,11 @@ static InitFunction __reg_activation__identity([] {
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(sigmoid)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->sigmoid(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->sigmoidDerivative(*act.value);
   return Status();
 }
@@ -115,12 +115,12 @@ MatrixPtr sftMaxDot_;
 MatrixPtr one_;
 public:
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->softmax(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   MatrixPtr outputV = act.value;
   MatrixPtr outputG = act.grad;
@@ -167,7 +167,7 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
 Argument argument_;
 public:
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   if (act.value->getWidth() != 1UL) {
     return Status(
         "Input width for each timestep of sequence softmax should be 1");
@@ -191,7 +191,7 @@ Status forward(Argument& act) {
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   if (act.value->getWidth() != 1UL) {
     return Status(
         "Input width for each timestep of sequence softmax should be 1");
@@ -207,7 +207,8 @@ Status backward(Argument& act) {
     argument_.value->setData(act.value->getData() + offset, 1UL, size);
     argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
-    softmax_.backward(argument_);
+    Status status = softmax_.backward(argument_);
+    if (!status.isOK()) return status;
   }
   return Status();
 }
@@ -224,12 +225,12 @@ END_DEFINE_ACTIVATION(sequence_softmax)
  *      0 otherwise.
  */
 BEGIN_DEFINE_ACTIVATION(relu)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->relu(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->reluDerivative(*act.value);
   return Status();
 }
@@ -249,12 +250,12 @@ END_DEFINE_ACTIVATION(relu)
  * TODO(yuyang18): Remove magic number 24 or make it configuable.
  */
 BEGIN_DEFINE_ACTIVATION(brelu)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->brelu(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->breluDerivative(*act.value);
   return Status();
 }
@@ -267,12 +268,12 @@ END_DEFINE_ACTIVATION(brelu)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(tanh)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->tanh(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->tanhDerivative(*act.value);
   return Status();
 }
@@ -290,12 +291,12 @@ real a, b;
 public:
 ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->scaledTanh(*act.value, a, b);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->scaledTanhDerivative(*act.value, a, b);
   return Status();
 }
@@ -308,12 +309,12 @@ END_DEFINE_ACTIVATION(stanh)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(softrelu)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->softrelu(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->softreluDerivative(*act.value);
   return Status();
 }
@@ -332,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
  *           0                  if z=0
  */
 BEGIN_DEFINE_ACTIVATION(abs)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -345,7 +346,7 @@ Status forward(Argument& act) {
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->absDerivative(*act.in);
   return Status();
 }
@@ -358,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(square)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -371,7 +372,7 @@ Status forward(Argument& act) {
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->squareDerivative(*act.in);
   return Status();
 }
@@ -384,12 +385,12 @@ END_DEFINE_ACTIVATION(square)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(exponential)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   act.value->exp2(*act.value);
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->expDerivative(*act.value);
   return Status();
 }
@@ -402,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
  * \f]
 */
 BEGIN_DEFINE_ACTIVATION(log)
-Status forward(Argument& act) {
+Status __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -415,7 +416,7 @@ Status forward(Argument& act) {
   return Status();
 }
-Status backward(Argument& act) {
+Status __must_check backward(Argument& act) {
   act.grad->dotDiv(*act.grad, *act.in);
   return Status();
 }
...
@@ -49,7 +49,7 @@ public:
    *
    * Usually, act is Layer::output_
    */
-  virtual Status forward(Argument& act) = 0;
+  virtual Status __must_check forward(Argument& act) = 0;
   /**
    * @brief Backward propagaion
@@ -58,7 +58,7 @@ public:
    * - Before calling backward(), act.grad = dE / dy, where E is the error/cost
    * - After backward() returns, act.grad = dE / dx = (dE/dy) * (dy/dx)
    */
-  virtual Status backward(Argument& act) = 0;
+  virtual Status __must_check backward(Argument& act) = 0;
   virtual const std::string& getName() const = 0;
 };
...
@@ -336,7 +336,7 @@ void Layer::showOutputStats() {
 void Layer::forwardActivation() {
   /* activation */
   auto status = activation_->forward(output_);
-  CHECK(status.isOK()) << status.what();
+  status.check();
   /* dropout */
   if (config_.drop_rate() > 0) {
@@ -375,7 +375,7 @@ void Layer::backwardActivation() {
   }
   auto status = activation_->backward(output_);
-  CHECK(status.isOK()) << status.what();
+  status.check();
 }
 void Layer::forwardDropOut() {
...
@@ -506,9 +506,12 @@ void MDLstmLayer::forwardGate2OutputSequence(int start,
           *frameState_[start + preOffsetV[i]].value, *checkFgOneDim, 1.0, 1.0);
     }
   }
-  activationGate_->forward(frameInputGate_[idxCurr]);
-  activationGate_->forward(frameForgetGate_[idxCurr]);
-  activation_->forward(frameInputNode_[idxCurr]);
+  auto status = activationGate_->forward(frameInputGate_[idxCurr]);
+  status.check();
+  status = activationGate_->forward(frameForgetGate_[idxCurr]);
+  status.check();
+  status = activation_->forward(frameInputNode_[idxCurr]);
+  status.check();
   frameState_[idxCurr].value->zeroMem();
   for (int i = 0; i < numDims_; i++) {
@@ -530,10 +533,12 @@ void MDLstmLayer::forwardGate2OutputSequence(int start,
   frameOutputGate_[idxCurr].value->addDotMul(
       *frameState_[idxCurr].value, *checkOg_, 1.0, 1.0);
-  activationGate_->forward(frameOutputGate_[idxCurr]);
+  status = activationGate_->forward(frameOutputGate_[idxCurr]);
+  status.check();
   framePreOutput_[idxCurr].value->copyFrom(*(frameState_[idxCurr].value));
-  activationState_->forward(framePreOutput_[idxCurr]);
+  status = activationState_->forward(framePreOutput_[idxCurr]);
+  status.check();
   frameOutput_[idxCurr].value->dotMul(*framePreOutput_[idxCurr].value,
                                       *frameOutputGate_[idxCurr].value);
@@ -640,12 +645,12 @@ void MDLstmLayer::backwardGate2OutputSequence(int start,
   framePreOutput_[idxCurr].grad->dotMul(*frameOutput_[idxCurr].grad,
                                         *frameOutputGate_[idxCurr].value);
-  activationState_->backward(framePreOutput_[idxCurr]);
+  activationState_->backward(framePreOutput_[idxCurr]).check();
   frameState_[idxCurr].grad->copyFrom(*(framePreOutput_[idxCurr].grad));
   frameOutputGate_[idxCurr].grad->dotMul(*frameOutput_[idxCurr].grad,
                                          *framePreOutput_[idxCurr].value);
-  activationGate_->backward(frameOutputGate_[idxCurr]);
+  activationGate_->backward(frameOutputGate_[idxCurr]).check();
   frameState_[idxCurr].grad->addDotMul(
       *frameOutputGate_[idxCurr].grad, *checkOg_, 1.0, 1.0);
@@ -702,9 +707,9 @@ void MDLstmLayer::backwardGate2OutputSequence(int start,
     }
   }
-  activationGate_->backward(frameInputGate_[idxCurr]);
-  activationGate_->backward(frameForgetGate_[idxCurr]);
-  activation_->backward(frameInputNode_[idxCurr]);
+  activationGate_->backward(frameInputGate_[idxCurr]).check();
+  activationGate_->backward(frameForgetGate_[idxCurr]).check();
+  activation_->backward(frameInputNode_[idxCurr]).check();
   if (bias_->getWGrad()) {
     for (int i = 0; i < numDims_; i++) {
...
@@ -193,7 +193,8 @@ public:
       forwardOneInput(l);
     }
-    activation_->forward(sampleOut_);
+    auto status = activation_->forward(sampleOut_);
+    status.check();
     forwardCost();
   }
@@ -207,7 +208,8 @@ public:
     backwardCost();
-    activation_->backward(sampleOut_);
+    auto status = activation_->backward(sampleOut_);
+    status.check();
     if (biases_->getWGrad()) {
       backwardBias(callback);
...
@@ -217,21 +217,22 @@ void RecurrentLayer::forwardOneSequence(int start, int length) {
     if (prevOutput_) {
       frameOutput_[start].value->mul(*prevOutput_, *weight_->getW(), 1, 1);
     }
-    activation_->forward(frameOutput_[start]);
+    activation_->forward(frameOutput_[start]).check();
     for (int i = 1; i < length; ++i) {
       frameOutput_[start + i].value->mul(
           *frameOutput_[start + i - 1].value, *weight_->getW(), 1, 1);
-      activation_->forward(frameOutput_[start + i]);
+      activation_->forward(frameOutput_[start + i]).check();
     }
     if (prevOutput_) {
       prevOutput_->assign(*frameOutput_[start + length - 1].value);
     }
   } else {
-    activation_->forward(frameOutput_[start + length - 1]);
+    activation_->forward(frameOutput_[start + length - 1]).check();
     for (int i = length - 2; i >= 0; --i) {
       frameOutput_[start + i].value->mul(
           *frameOutput_[start + i + 1].value, *weight_->getW(), 1, 1);
-      activation_->forward(frameOutput_[start + i]);
+      activation_->forward(frameOutput_[start + i]).check();
     }
   }
 }
@@ -280,11 +281,11 @@ void RecurrentLayer::backwardOneSequence(int start, int length) {
   MatrixPtr weightT = weight_->getW()->getTranspose();
   if (!reversed_) {
     for (int i = length - 1; i > 0; --i) {
-      activation_->backward(frameOutput_[start + i]);
+      activation_->backward(frameOutput_[start + i]).check();
       frameOutput_[start + i - 1].grad->mul(
           *frameOutput_[start + i].grad, *weightT, 1, 1);
     }
-    activation_->backward(frameOutput_[start]);
+    activation_->backward(frameOutput_[start]).check();
     if (weight_->getWGrad()) {
       weight_->getWGrad()->mul(
           *output_.value->subMatrix(start, length - 1)->getTranspose(),
@@ -294,11 +295,11 @@ void RecurrentLayer::backwardOneSequence(int start, int length) {
     }
   } else {
     for (int i = 0; i < length - 1; ++i) {
-      activation_->backward(frameOutput_[start + i]);
+      activation_->backward(frameOutput_[start + i]).check();
       frameOutput_[start + i + 1].grad->mul(
           *frameOutput_[start + i].grad, *weightT, 1, 1);
     }
-    activation_->backward(frameOutput_[start + length - 1]);
+    activation_->backward(frameOutput_[start + length - 1]).check();
     if (weight_->getWGrad()) {
       weight_->getWGrad()->mul(
           *output_.value->subMatrix(start + 1, length - 1)->getTranspose(),
@@ -333,7 +334,7 @@ void RecurrentLayer::forwardBatch(int batchSize,
       }
       Argument arg;
       arg.value = batch2;
-      activation_->forward(arg);
+      activation_->forward(arg).check();
     }
   }
   batchValue_->copyBackSeq(*output_.value);
@@ -363,7 +364,7 @@ void RecurrentLayer::backwardBatch(int batchSize,
       Argument arg;
       arg.value = batch1;
       arg.grad = batch2;
-      activation_->backward(arg);
+      activation_->backward(arg).check();
       if (n != 0) {
         batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight());
...
@@ -192,7 +192,8 @@ void SelectiveFullyConnectedLayer::forward(PassType passType) {
         nnz,
         /*trans=*/false,
         /*useGpu=*/useGpu_);
-    activation_->forward(arg);
+    //! TODO(yuyang18): Why we cannot invoke forwardActivation here?
+    activation_->forward(arg).check();
   } else /* train and test in train, not generating */ {
     // during training, this layer output value is *Matrix*, which is input of
     // eg. multi-class-cross-entropy
...
@@ -148,11 +148,11 @@ LayerPtr createCTCLayer(string name,
   ActivationFunction* softmaxActivation = ActivationFunction::create("softmax");
-  softmaxActivation->forward(dataLayer->getOutput());
+  softmaxActivation->forward(dataLayer->getOutput()).check();
   layer->forward(PASS_GC);
   layer->backward();
-  softmaxActivation->backward(dataLayer->getOutput());
+  softmaxActivation->backward(dataLayer->getOutput()).check();
   return layer;
 }
...
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef __GNUC__
#define GCC_VERSION \
(__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#else
#define GCC_VERSION
#endif
#if GCC_VERSION >= 30400
#define __must_check __attribute__((warn_unused_result))
#else
#define __must_check
#endif
@@ -14,9 +14,11 @@ limitations under the License. */
 #pragma once
+#include <glog/logging.h>
 #include <stdio.h>
 #include <memory>
 #include <string>
+#include "Compiler.h"
 namespace paddle {
@@ -29,8 +31,55 @@ namespace paddle {
  * There are two styles to return status in Paddle.
  *
  * 1. Return Status
+ * When method return a status, the return must use `__must_check` attribute.
+ * Example as below.
+ * @code{cpp}
+ * Status __must_check foo();
  *
+ * Status __must_check bar() {
+ *    // do something.
+ *    Status s = foo();  // invoke other method return status.
+ *    if (!s.isOK()) return s;
+ *    // do something else.
+ *    return Status();
+ * }
+ * @endcode{cpp}
  *
+ * 2. Return by parameter.
+ * It is another way to return a status, by using a pointer parameter.
+ * Example as below.
+ *
+ * @code{cpp}
+ * Status bar();
+ *
+ * int foo(Status* status) {
+ *   // Do something.
+ *   Status s = bar();
+ *   if (!s.isOK()) {
+ *     *status = s;
+ *     return 0;
+ *   }
+ *   // Do something else.
+ *   if (someInternalErrorHappend) {
+ *     status->setByPrintf("Some dimension is too large, %d", dimension);
+ *     return 0;
+ *   }
+ *   // End of method.
+ *   return someValue;
+ * }
+ *
+ * Status foobar() {
+ *   Status s;
+ *   // do something.
+ *   foo(&s);
+ *   if (!s.isOK()) return s;
+ * }
+ * @endcode{cpp}
+ *
+ *
+ * Currently there is a helper method 'check' in status, because Paddle always
+ * use log(FATAL) or CHECK to make program exit before. When we clean all
+ * log(FATAL) and CHECK in Paddle, 'check' method will be removed.
  */
 class Status final : public std::exception {
 public:
@@ -92,6 +141,13 @@ public:
    */
   inline bool isOK() const noexcept { return errMsg_ == nullptr; }
+  /**
+   * @brief check this status by glog.
+   * @note  It is a temp method used during cleaning Paddle code. It will be
+   *        removed later.
+   */
+  inline void check() const { CHECK(isOK()) << what(); }
 private:
   std::shared_ptr<std::string> errMsg_;
 };
...
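As a usage note, the diff above shows two calling conventions for the new `Status`-returning `forward()`/`backward()`: propagate the error where the caller can itself return a `Status` (as the `sequence_softmax` activation does), or fall back to the temporary `Status::check()` helper, which CHECK-fails via glog, where the surrounding code cannot yet report errors (as the Layer, LSTM, and RNN call sites do). A rough, self-contained sketch of the two patterns follows; the simplified `Status` and `runActivation` here are stand-ins for illustration, not Paddle's classes:

```cpp
// Sketch of the two error-handling styles used in this commit.
// Everything here is a simplified stand-in, not Paddle code.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

class Status {
public:
  Status() = default;
  explicit Status(std::string msg)
      : errMsg_(std::make_shared<std::string>(std::move(msg))) {}
  bool isOK() const noexcept { return errMsg_ == nullptr; }
  const char* what() const noexcept { return errMsg_ ? errMsg_->c_str() : ""; }
  // Mirrors the temporary Status::check() helper: abort loudly on error.
  void check() const {
    if (!isOK()) {
      std::cerr << what() << std::endl;
      std::abort();
    }
  }

private:
  std::shared_ptr<std::string> errMsg_;
};

// Hypothetical callee; in the real code this would be declared
// Status __must_check forward(Argument& act).
Status runActivation(bool ok) {
  return ok ? Status() : Status("activation failed");
}

// Style 1: the caller also returns Status, so errors are propagated upward
// (this is what the sequence_softmax backward() change does).
Status forwardOneStep() {
  Status s = runActivation(true);
  if (!s.isOK()) return s;
  return Status();
}

// Style 2: the caller cannot return Status yet, so it calls check(),
// matching Layer::forwardActivation() and the LSTM/RNN call sites.
void legacyForward() { runActivation(true).check(); }

int main() {
  forwardOneStep().check();
  legacyForward();
  return 0;
}
```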