提交 bc09551e 编写于 作者: Y Yu Yang

Fix unittest

上级 e3f5fdcc
...@@ -69,7 +69,7 @@ TEST(OpKernel, all) { ...@@ -69,7 +69,7 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx); net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
} }
TEST(AddBackwardOp, TestGradOp) { TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>(); auto net = std::make_shared<PlainNet>();
......
...@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
caught = false; caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) { ...@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
...@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) { ...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false; bool caught = false;
try { try {
src_tensor.data<double>(); src_tensor.data<double>();
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
...@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) {
bool caught = false; bool caught = false;
try { try {
dst_tensor.ShareDataWith<float>(src_tensor); dst_tensor.ShareDataWith<float>(src_tensor);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册