提交 64a08f84 编写于 作者: N nhzlx

increase the test batch

上级 c13efe02
...@@ -23,7 +23,7 @@ namespace tensorrt { ...@@ -23,7 +23,7 @@ namespace tensorrt {
TEST(elementwise_op, add_weight_test) { TEST(elementwise_op, add_weight_test) {
std::unordered_set<std::string> parameters({"elementwise_add-Y"}); std::unordered_set<std::string> parameters({"elementwise_add-Y"});
framework::Scope scope; framework::Scope scope;
TRTConvertValidation validator(1, parameters, scope, 1 << 15); TRTConvertValidation validator(10, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3)); validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1)); validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2)); // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
...@@ -41,13 +41,13 @@ TEST(elementwise_op, add_weight_test) { ...@@ -41,13 +41,13 @@ TEST(elementwise_op, add_weight_test) {
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
validator.Execute(1); validator.Execute(8);
} }
TEST(elementwise_op, add_tensor_test) { TEST(elementwise_op, add_tensor_test) {
std::unordered_set<std::string> parameters; std::unordered_set<std::string> parameters;
framework::Scope scope; framework::Scope scope;
TRTConvertValidation validator(2, parameters, scope, 1 << 15); TRTConvertValidation validator(8, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3)); validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3)); validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2)); // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
...@@ -64,7 +64,7 @@ TEST(elementwise_op, add_tensor_test) { ...@@ -64,7 +64,7 @@ TEST(elementwise_op, add_tensor_test) {
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
validator.Execute(1); validator.Execute(8);
} }
} // namespace tensorrt } // namespace tensorrt
......
...@@ -149,7 +149,7 @@ class TRTConvertValidation { ...@@ -149,7 +149,7 @@ class TRTConvertValidation {
cudaStreamSynchronize(*engine_->stream()); cudaStreamSynchronize(*engine_->stream());
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty()); ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
const size_t output_space_size = 2000; const size_t output_space_size = 3000;
for (const auto& output : op_desc_->OutputArgumentNames()) { for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out; std::vector<float> fluid_out;
std::vector<float> trt_out(output_space_size); std::vector<float> trt_out(output_space_size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册