提交 ec790e10 编写于 作者: Y Yu Yang

Rename Status => Error.

* Also make ErrorF as a global method.
上级 8605544c
......@@ -69,13 +69,13 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
class IdentityActivation : public ActivationFunction {
public:
static const std::string name;
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
(void)act;
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
(void)act;
return Status();
return Error();
}
const std::string& getName() const { return name; }
};
......@@ -92,13 +92,13 @@ static InitFunction __reg_activation__identity([] {
* \f]
*/
BEGIN_DEFINE_ACTIVATION(sigmoid)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->sigmoid(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->sigmoidDerivative(*act.value);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(sigmoid)
......@@ -115,12 +115,12 @@ MatrixPtr sftMaxDot_;
MatrixPtr one_;
public:
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->softmax(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
MatrixPtr outputV = act.value;
MatrixPtr outputG = act.grad;
......@@ -152,7 +152,7 @@ Status __must_check backward(Argument& act) {
act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
}
return Status();
return Error();
}
END_DEFINE_ACTIVATION(softmax)
......@@ -167,9 +167,9 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
Argument argument_;
public:
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
if (act.value->getWidth() != 1UL) {
return Status(
return ErrorF(
"Input width for each timestep of sequence softmax should be 1");
}
......@@ -188,12 +188,12 @@ Status __must_check forward(Argument& act) {
auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
act.value->sequenceSoftmax(*act.value, *starts);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
if (act.value->getWidth() != 1UL) {
return Status(
return ErrorF(
"Input width for each timestep of sequence softmax should be 1");
}
......@@ -207,10 +207,10 @@ Status __must_check backward(Argument& act) {
argument_.value->setData(act.value->getData() + offset, 1UL, size);
argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
Status status = softmax_.backward(argument_);
Error status = softmax_.backward(argument_);
if (!status.isOK()) return status;
}
return Status();
return Error();
}
END_DEFINE_ACTIVATION(sequence_softmax)
......@@ -225,14 +225,14 @@ END_DEFINE_ACTIVATION(sequence_softmax)
* 0 otherwise.
*/
BEGIN_DEFINE_ACTIVATION(relu)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->relu(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->reluDerivative(*act.value);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(relu)
......@@ -250,14 +250,14 @@ END_DEFINE_ACTIVATION(relu)
* TODO(yuyang18): Remove magic number 24 or make it configuable.
*/
BEGIN_DEFINE_ACTIVATION(brelu)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->brelu(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->breluDerivative(*act.value);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(brelu)
......@@ -268,14 +268,14 @@ END_DEFINE_ACTIVATION(brelu)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(tanh)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->tanh(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->tanhDerivative(*act.value);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(tanh)
......@@ -291,14 +291,14 @@ real a, b;
public:
ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->scaledTanh(*act.value, a, b);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->scaledTanhDerivative(*act.value, a, b);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(stanh)
......@@ -309,14 +309,14 @@ END_DEFINE_ACTIVATION(stanh)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(softrelu)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->softrelu(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->softreluDerivative(*act.value);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(softrelu)
......@@ -333,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
* 0 if z=0
*/
BEGIN_DEFINE_ACTIVATION(abs)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in,
act.value->getHeight(),
......@@ -343,12 +343,12 @@ Status __must_check forward(Argument& act) {
act.in->copyFrom(*act.value);
act.value->abs2(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->absDerivative(*act.in);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(abs)
......@@ -359,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(square)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in,
act.value->getHeight(),
......@@ -369,12 +369,12 @@ Status __must_check forward(Argument& act) {
act.in->copyFrom(*act.value);
act.value->square2(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->squareDerivative(*act.in);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(square)
......@@ -385,14 +385,14 @@ END_DEFINE_ACTIVATION(square)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(exponential)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
act.value->exp2(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->expDerivative(*act.value);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(exponential)
......@@ -403,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(log)
Status __must_check forward(Argument& act) {
Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in,
act.value->getHeight(),
......@@ -413,12 +413,12 @@ Status __must_check forward(Argument& act) {
act.in->copyFrom(*act.value);
act.value->log2(*act.value);
return Status();
return Error();
}
Status __must_check backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->dotDiv(*act.grad, *act.in);
return Status();
return Error();
}
END_DEFINE_ACTIVATION(log)
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/utils/Status.h"
#include "paddle/utils/Error.h"
namespace paddle {
......@@ -49,7 +49,7 @@ public:
*
* Usually, act is Layer::output_
*/
virtual Status __must_check forward(Argument& act) = 0;
virtual Error __must_check forward(Argument& act) = 0;
/**
* @brief Backward propagaion
......@@ -58,7 +58,7 @@ public:
* - Before calling backward(), act.grad = dE / dy, where E is the error/cost
* - After backward() returns, act.grad = dE / dx = (dE/dy) * (dy/dx)
*/
virtual Status __must_check backward(Argument& act) = 0;
virtual Error __must_check backward(Argument& act) = 0;
virtual const std::string& getName() const = 0;
};
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/utils/Util.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/utils/Error.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Status.h"
#include "AddtoLayer.h"
#include "CRFLayer.h"
......
......@@ -34,9 +34,9 @@ namespace paddle {
* When method return a status, the return must use `__must_check` attribute.
* Example as below.
* @code{cpp}
* Status __must_check foo();
* Error __must_check foo();
*
* Status __must_check bar() {
* Error __must_check bar() {
* // do something.
* Status s = foo(); // invoke other method return status.
* if (!s.isOK()) return s;
......@@ -50,9 +50,9 @@ namespace paddle {
* Example as below.
*
* @code{cpp}
* Status bar();
* Error bar();
*
* int foo(Status* status) {
* int foo(Error* status) {
* // Do something.
* Status s = bar();
* if (!s.isOK()) {
......@@ -61,15 +61,15 @@ namespace paddle {
* }
* // Do something else.
* if (someInternalErrorHappend) {
* status->setByPrintf("Some dimension is too large, %d", dimension);
* *status = ErrorF("Some dimension is too large, %d", dimension);
* return 0;
* }
* // End of method.
* return someValue;
* }
*
* Status foobar() {
* Status s;
* Error foobar() {
* Error s;
* // do something.
* foo(&s);
* if (!s.isOK()) return s;
......@@ -81,48 +81,12 @@ namespace paddle {
* use log(FATAL) or CHECK to make program exit before. When we clean all
* log(FATAL) and CHECK in Paddle, 'check' method will be removed.
*/
class Status final : public std::exception {
class Error final : public std::exception {
public:
/**
* Default Status. OK
*/
Status() noexcept {}
/**
* @brief Create Status with error message
* @param msg
*/
explicit Status(const std::string& msg) : errMsg_(new std::string(msg)) {}
/**
* @brief set a error message for status.
* @param msg
*/
inline void set(const std::string& msg) noexcept {
errMsg_.reset(new std::string(msg));
}
/**
* @brief set a error message for status. Use C style printf
* @param fmt
*/
template <typename... ARGS>
inline void setByPrintf(const char* fmt, ARGS... args) noexcept {
constexpr size_t kBufferSize = 1024; // 1KB buffer
char buffer[kBufferSize];
snprintf(buffer, kBufferSize, fmt, args...);
errMsg_.reset(new std::string(buffer));
}
/**
* create a error status by C style printf.
*/
template <typename... ARGS>
inline static Status printf(const char* fmt, ARGS... args) noexcept {
Status s;
s.setByPrintf(fmt, args...);
return s;
}
Error() noexcept {}
/**
* @brief what will return the error message. If status is OK, return nullptr.
......@@ -148,8 +112,46 @@ public:
*/
inline void check() const { CHECK(isOK()) << what(); }
/**
* friend method to create Error.
*/
template <typename... ARGS>
friend Error __must_check ErrorF(const char* fmt, ARGS... args);
private:
std::shared_ptr<std::string> errMsg_;
};
/**
* ErrorF will create an Error by printf syntax.
*
* Specialize this method because clang will give a warning when use printf(fmt)
* without arguments.
*/
template <>
inline Error __must_check ErrorF(const char* msg) {
Error e;
e.errMsg_.reset(new std::string(msg));
return e;
}
/**
* ErrorF will create an Error by printf syntax.
*
* Examples:
* @code{cpp}
* auto err = ErrorF("SomeError");
* auto err2 = ErrorF("SomeErrorWithParameter %f %d", real_val, int_val);
* @endcode{cpp}
*/
template <typename... ARGS>
inline Error __must_check ErrorF(const char* fmt, ARGS... args) {
constexpr size_t kBufferSize = 1024;
char buffer[kBufferSize];
snprintf(buffer, kBufferSize, fmt, args...);
Error e;
e.errMsg_.reset(new std::string(buffer));
return e;
}
} // namespace paddle
......@@ -4,7 +4,7 @@ add_simple_unittest(test_CustomStackTrace)
add_simple_unittest(test_ThreadBarrier)
add_simple_unittest(test_SpinLock)
add_simple_unittest(test_SIMDFlags)
add_simple_unittest(test_Status)
add_simple_unittest(test_Error)
add_executable(
test_CustomStackTracePrint
......
......@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/utils/Status.h"
#include "paddle/utils/Error.h"
#include <gtest/gtest.h>
TEST(Status, testAll) {
paddle::Status status;
paddle::Error status;
ASSERT_TRUE(status.isOK());
status.set("I'm the error");
status = paddle::ErrorF("I'm the error");
ASSERT_FALSE(status.isOK());
ASSERT_STREQ("I'm the error", status.what());
paddle::Status status2("error2");
ASSERT_FALSE(status2.isOK());
ASSERT_STREQ("error2", status2.what());
status = paddle::ErrorF("error2");
ASSERT_FALSE(status.isOK());
ASSERT_STREQ("error2", status.what());
int i = 3;
auto status3 = paddle::Status::printf("error%d", i);
auto status3 = paddle::ErrorF("error%d", i);
ASSERT_FALSE(status3.isOK());
ASSERT_STREQ("error3", status3.what());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册