提交 ec790e10 编写于 作者: Y Yu Yang

Rename Status => Error.

* Also make ErrorF as a global method.
上级 8605544c
...@@ -69,13 +69,13 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar; ...@@ -69,13 +69,13 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
class IdentityActivation : public ActivationFunction { class IdentityActivation : public ActivationFunction {
public: public:
static const std::string name; static const std::string name;
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
(void)act; (void)act;
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
(void)act; (void)act;
return Status(); return Error();
} }
const std::string& getName() const { return name; } const std::string& getName() const { return name; }
}; };
...@@ -92,13 +92,13 @@ static InitFunction __reg_activation__identity([] { ...@@ -92,13 +92,13 @@ static InitFunction __reg_activation__identity([] {
* \f] * \f]
*/ */
BEGIN_DEFINE_ACTIVATION(sigmoid) BEGIN_DEFINE_ACTIVATION(sigmoid)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->sigmoid(*act.value); act.value->sigmoid(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->sigmoidDerivative(*act.value); act.grad->sigmoidDerivative(*act.value);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(sigmoid) END_DEFINE_ACTIVATION(sigmoid)
...@@ -115,12 +115,12 @@ MatrixPtr sftMaxDot_; ...@@ -115,12 +115,12 @@ MatrixPtr sftMaxDot_;
MatrixPtr one_; MatrixPtr one_;
public: public:
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->softmax(*act.value); act.value->softmax(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
MatrixPtr outputV = act.value; MatrixPtr outputV = act.value;
MatrixPtr outputG = act.grad; MatrixPtr outputG = act.grad;
...@@ -152,7 +152,7 @@ Status __must_check backward(Argument& act) { ...@@ -152,7 +152,7 @@ Status __must_check backward(Argument& act) {
act.grad->softmaxDerivative(*act.value, *sftMaxSum_); act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
} }
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(softmax) END_DEFINE_ACTIVATION(softmax)
...@@ -167,9 +167,9 @@ ACTIVATION_CLASS_NAME(softmax) softmax_; ...@@ -167,9 +167,9 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
Argument argument_; Argument argument_;
public: public:
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
if (act.value->getWidth() != 1UL) { if (act.value->getWidth() != 1UL) {
return Status( return ErrorF(
"Input width for each timestep of sequence softmax should be 1"); "Input width for each timestep of sequence softmax should be 1");
} }
...@@ -188,12 +188,12 @@ Status __must_check forward(Argument& act) { ...@@ -188,12 +188,12 @@ Status __must_check forward(Argument& act) {
auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId)); auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
act.value->sequenceSoftmax(*act.value, *starts); act.value->sequenceSoftmax(*act.value, *starts);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
if (act.value->getWidth() != 1UL) { if (act.value->getWidth() != 1UL) {
return Status( return ErrorF(
"Input width for each timestep of sequence softmax should be 1"); "Input width for each timestep of sequence softmax should be 1");
} }
...@@ -207,10 +207,10 @@ Status __must_check backward(Argument& act) { ...@@ -207,10 +207,10 @@ Status __must_check backward(Argument& act) {
argument_.value->setData(act.value->getData() + offset, 1UL, size); argument_.value->setData(act.value->getData() + offset, 1UL, size);
argument_.grad->setData(act.grad->getData() + offset, 1UL, size); argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
Status status = softmax_.backward(argument_); Error status = softmax_.backward(argument_);
if (!status.isOK()) return status; if (!status.isOK()) return status;
} }
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(sequence_softmax) END_DEFINE_ACTIVATION(sequence_softmax)
...@@ -225,14 +225,14 @@ END_DEFINE_ACTIVATION(sequence_softmax) ...@@ -225,14 +225,14 @@ END_DEFINE_ACTIVATION(sequence_softmax)
* 0 otherwise. * 0 otherwise.
*/ */
BEGIN_DEFINE_ACTIVATION(relu) BEGIN_DEFINE_ACTIVATION(relu)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->relu(*act.value); act.value->relu(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->reluDerivative(*act.value); act.grad->reluDerivative(*act.value);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(relu) END_DEFINE_ACTIVATION(relu)
...@@ -250,14 +250,14 @@ END_DEFINE_ACTIVATION(relu) ...@@ -250,14 +250,14 @@ END_DEFINE_ACTIVATION(relu)
* TODO(yuyang18): Remove magic number 24 or make it configuable. * TODO(yuyang18): Remove magic number 24 or make it configuable.
*/ */
BEGIN_DEFINE_ACTIVATION(brelu) BEGIN_DEFINE_ACTIVATION(brelu)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->brelu(*act.value); act.value->brelu(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->breluDerivative(*act.value); act.grad->breluDerivative(*act.value);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(brelu) END_DEFINE_ACTIVATION(brelu)
...@@ -268,14 +268,14 @@ END_DEFINE_ACTIVATION(brelu) ...@@ -268,14 +268,14 @@ END_DEFINE_ACTIVATION(brelu)
* \f] * \f]
*/ */
BEGIN_DEFINE_ACTIVATION(tanh) BEGIN_DEFINE_ACTIVATION(tanh)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->tanh(*act.value); act.value->tanh(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->tanhDerivative(*act.value); act.grad->tanhDerivative(*act.value);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(tanh) END_DEFINE_ACTIVATION(tanh)
...@@ -291,14 +291,14 @@ real a, b; ...@@ -291,14 +291,14 @@ real a, b;
public: public:
ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {} ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->scaledTanh(*act.value, a, b); act.value->scaledTanh(*act.value, a, b);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->scaledTanhDerivative(*act.value, a, b); act.grad->scaledTanhDerivative(*act.value, a, b);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(stanh) END_DEFINE_ACTIVATION(stanh)
...@@ -309,14 +309,14 @@ END_DEFINE_ACTIVATION(stanh) ...@@ -309,14 +309,14 @@ END_DEFINE_ACTIVATION(stanh)
* \f] * \f]
*/ */
BEGIN_DEFINE_ACTIVATION(softrelu) BEGIN_DEFINE_ACTIVATION(softrelu)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->softrelu(*act.value); act.value->softrelu(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->softreluDerivative(*act.value); act.grad->softreluDerivative(*act.value);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(softrelu) END_DEFINE_ACTIVATION(softrelu)
...@@ -333,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu) ...@@ -333,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
* 0 if z=0 * 0 if z=0
*/ */
BEGIN_DEFINE_ACTIVATION(abs) BEGIN_DEFINE_ACTIVATION(abs)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId); SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in, Matrix::resizeOrCreate(act.in,
act.value->getHeight(), act.value->getHeight(),
...@@ -343,12 +343,12 @@ Status __must_check forward(Argument& act) { ...@@ -343,12 +343,12 @@ Status __must_check forward(Argument& act) {
act.in->copyFrom(*act.value); act.in->copyFrom(*act.value);
act.value->abs2(*act.value); act.value->abs2(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->absDerivative(*act.in); act.grad->absDerivative(*act.in);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(abs) END_DEFINE_ACTIVATION(abs)
...@@ -359,7 +359,7 @@ END_DEFINE_ACTIVATION(abs) ...@@ -359,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
* \f] * \f]
*/ */
BEGIN_DEFINE_ACTIVATION(square) BEGIN_DEFINE_ACTIVATION(square)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId); SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in, Matrix::resizeOrCreate(act.in,
act.value->getHeight(), act.value->getHeight(),
...@@ -369,12 +369,12 @@ Status __must_check forward(Argument& act) { ...@@ -369,12 +369,12 @@ Status __must_check forward(Argument& act) {
act.in->copyFrom(*act.value); act.in->copyFrom(*act.value);
act.value->square2(*act.value); act.value->square2(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->squareDerivative(*act.in); act.grad->squareDerivative(*act.in);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(square) END_DEFINE_ACTIVATION(square)
...@@ -385,14 +385,14 @@ END_DEFINE_ACTIVATION(square) ...@@ -385,14 +385,14 @@ END_DEFINE_ACTIVATION(square)
* \f] * \f]
*/ */
BEGIN_DEFINE_ACTIVATION(exponential) BEGIN_DEFINE_ACTIVATION(exponential)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
act.value->exp2(*act.value); act.value->exp2(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->expDerivative(*act.value); act.grad->expDerivative(*act.value);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(exponential) END_DEFINE_ACTIVATION(exponential)
...@@ -403,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential) ...@@ -403,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
* \f] * \f]
*/ */
BEGIN_DEFINE_ACTIVATION(log) BEGIN_DEFINE_ACTIVATION(log)
Status __must_check forward(Argument& act) { Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId); SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in, Matrix::resizeOrCreate(act.in,
act.value->getHeight(), act.value->getHeight(),
...@@ -413,12 +413,12 @@ Status __must_check forward(Argument& act) { ...@@ -413,12 +413,12 @@ Status __must_check forward(Argument& act) {
act.in->copyFrom(*act.value); act.in->copyFrom(*act.value);
act.value->log2(*act.value); act.value->log2(*act.value);
return Status(); return Error();
} }
Status __must_check backward(Argument& act) { Error __must_check backward(Argument& act) {
act.grad->dotDiv(*act.grad, *act.in); act.grad->dotDiv(*act.grad, *act.in);
return Status(); return Error();
} }
END_DEFINE_ACTIVATION(log) END_DEFINE_ACTIVATION(log)
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/utils/Status.h" #include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
...@@ -49,7 +49,7 @@ public: ...@@ -49,7 +49,7 @@ public:
* *
* Usually, act is Layer::output_ * Usually, act is Layer::output_
*/ */
virtual Status __must_check forward(Argument& act) = 0; virtual Error __must_check forward(Argument& act) = 0;
/** /**
* @brief Backward propagaion * @brief Backward propagaion
...@@ -58,7 +58,7 @@ public: ...@@ -58,7 +58,7 @@ public:
* - Before calling backward(), act.grad = dE / dy, where E is the error/cost * - Before calling backward(), act.grad = dE / dy, where E is the error/cost
* - After backward() returns, act.grad = dE / dx = (dE/dy) * (dy/dx) * - After backward() returns, act.grad = dE / dx = (dE/dy) * (dy/dx)
*/ */
virtual Status __must_check backward(Argument& act) = 0; virtual Error __must_check backward(Argument& act) = 0;
virtual const std::string& getName() const = 0; virtual const std::string& getName() const = 0;
}; };
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/utils/Util.h" #include "paddle/utils/Util.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
#include "paddle/utils/Error.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Status.h"
#include "AddtoLayer.h" #include "AddtoLayer.h"
#include "CRFLayer.h" #include "CRFLayer.h"
......
...@@ -34,9 +34,9 @@ namespace paddle { ...@@ -34,9 +34,9 @@ namespace paddle {
* When method return a status, the return must use `__must_check` attribute. * When method return a status, the return must use `__must_check` attribute.
* Example as below. * Example as below.
* @code{cpp} * @code{cpp}
* Status __must_check foo(); * Error __must_check foo();
* *
* Status __must_check bar() { * Error __must_check bar() {
* // do something. * // do something.
* Status s = foo(); // invoke other method return status. * Status s = foo(); // invoke other method return status.
* if (!s.isOK()) return s; * if (!s.isOK()) return s;
...@@ -50,9 +50,9 @@ namespace paddle { ...@@ -50,9 +50,9 @@ namespace paddle {
* Example as below. * Example as below.
* *
* @code{cpp} * @code{cpp}
* Status bar(); * Error bar();
* *
* int foo(Status* status) { * int foo(Error* status) {
* // Do something. * // Do something.
* Status s = bar(); * Status s = bar();
* if (!s.isOK()) { * if (!s.isOK()) {
...@@ -61,15 +61,15 @@ namespace paddle { ...@@ -61,15 +61,15 @@ namespace paddle {
* } * }
* // Do something else. * // Do something else.
* if (someInternalErrorHappend) { * if (someInternalErrorHappend) {
* status->setByPrintf("Some dimension is too large, %d", dimension); * *status = ErrorF("Some dimension is too large, %d", dimension);
* return 0; * return 0;
* } * }
* // End of method. * // End of method.
* return someValue; * return someValue;
* } * }
* *
* Status foobar() { * Error foobar() {
* Status s; * Error s;
* // do something. * // do something.
* foo(&s); * foo(&s);
* if (!s.isOK()) return s; * if (!s.isOK()) return s;
...@@ -81,48 +81,12 @@ namespace paddle { ...@@ -81,48 +81,12 @@ namespace paddle {
* use log(FATAL) or CHECK to make program exit before. When we clean all * use log(FATAL) or CHECK to make program exit before. When we clean all
* log(FATAL) and CHECK in Paddle, 'check' method will be removed. * log(FATAL) and CHECK in Paddle, 'check' method will be removed.
*/ */
class Status final : public std::exception { class Error final : public std::exception {
public: public:
/** /**
* Default Status. OK * Default Status. OK
*/ */
Status() noexcept {} Error() noexcept {}
/**
* @brief Create Status with error message
* @param msg
*/
explicit Status(const std::string& msg) : errMsg_(new std::string(msg)) {}
/**
* @brief set a error message for status.
* @param msg
*/
inline void set(const std::string& msg) noexcept {
errMsg_.reset(new std::string(msg));
}
/**
* @brief set a error message for status. Use C style printf
* @param fmt
*/
template <typename... ARGS>
inline void setByPrintf(const char* fmt, ARGS... args) noexcept {
constexpr size_t kBufferSize = 1024; // 1KB buffer
char buffer[kBufferSize];
snprintf(buffer, kBufferSize, fmt, args...);
errMsg_.reset(new std::string(buffer));
}
/**
* create a error status by C style printf.
*/
template <typename... ARGS>
inline static Status printf(const char* fmt, ARGS... args) noexcept {
Status s;
s.setByPrintf(fmt, args...);
return s;
}
/** /**
* @brief what will return the error message. If status is OK, return nullptr. * @brief what will return the error message. If status is OK, return nullptr.
...@@ -148,8 +112,46 @@ public: ...@@ -148,8 +112,46 @@ public:
*/ */
inline void check() const { CHECK(isOK()) << what(); } inline void check() const { CHECK(isOK()) << what(); }
/**
* friend method to create Error.
*/
template <typename... ARGS>
friend Error __must_check ErrorF(const char* fmt, ARGS... args);
private: private:
std::shared_ptr<std::string> errMsg_; std::shared_ptr<std::string> errMsg_;
}; };
/**
* ErrorF will create an Error by printf syntax.
*
* Specialize this method because clang will give a warning when use printf(fmt)
* without arguments.
*/
template <>
inline Error __must_check ErrorF(const char* msg) {
Error e;
e.errMsg_.reset(new std::string(msg));
return e;
}
/**
* ErrorF will create an Error by printf syntax.
*
* Examples:
* @code{cpp}
* auto err = ErrorF("SomeError");
* auto err2 = ErrorF("SomeErrorWithParameter %f %d", real_val, int_val);
* @endcode{cpp}
*/
template <typename... ARGS>
inline Error __must_check ErrorF(const char* fmt, ARGS... args) {
constexpr size_t kBufferSize = 1024;
char buffer[kBufferSize];
snprintf(buffer, kBufferSize, fmt, args...);
Error e;
e.errMsg_.reset(new std::string(buffer));
return e;
}
} // namespace paddle } // namespace paddle
...@@ -4,7 +4,7 @@ add_simple_unittest(test_CustomStackTrace) ...@@ -4,7 +4,7 @@ add_simple_unittest(test_CustomStackTrace)
add_simple_unittest(test_ThreadBarrier) add_simple_unittest(test_ThreadBarrier)
add_simple_unittest(test_SpinLock) add_simple_unittest(test_SpinLock)
add_simple_unittest(test_SIMDFlags) add_simple_unittest(test_SIMDFlags)
add_simple_unittest(test_Status) add_simple_unittest(test_Error)
add_executable( add_executable(
test_CustomStackTracePrint test_CustomStackTracePrint
......
...@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/utils/Status.h" #include "paddle/utils/Error.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
TEST(Status, testAll) { TEST(Status, testAll) {
paddle::Status status; paddle::Error status;
ASSERT_TRUE(status.isOK()); ASSERT_TRUE(status.isOK());
status.set("I'm the error"); status = paddle::ErrorF("I'm the error");
ASSERT_FALSE(status.isOK()); ASSERT_FALSE(status.isOK());
ASSERT_STREQ("I'm the error", status.what()); ASSERT_STREQ("I'm the error", status.what());
paddle::Status status2("error2"); status = paddle::ErrorF("error2");
ASSERT_FALSE(status2.isOK()); ASSERT_FALSE(status.isOK());
ASSERT_STREQ("error2", status2.what()); ASSERT_STREQ("error2", status.what());
int i = 3; int i = 3;
auto status3 = paddle::Status::printf("error%d", i); auto status3 = paddle::ErrorF("error%d", i);
ASSERT_FALSE(status3.isOK()); ASSERT_FALSE(status3.isOK());
ASSERT_STREQ("error3", status3.what()); ASSERT_STREQ("error3", status3.what());
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册