提交 64a08f84 编写于 作者: N nhzlx

increase the test batch

上级 c13efe02
......@@ -23,7 +23,7 @@ namespace tensorrt {
TEST(elementwise_op, add_weight_test) {
std::unordered_set<std::string> parameters({"elementwise_add-Y"});
framework::Scope scope;
TRTConvertValidation validator(1, parameters, scope, 1 << 15);
TRTConvertValidation validator(10, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
......@@ -41,13 +41,13 @@ TEST(elementwise_op, add_weight_test) {
validator.SetOp(*desc.Proto());
validator.Execute(1);
validator.Execute(8);
}
TEST(elementwise_op, add_tensor_test) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(2, parameters, scope, 1 << 15);
TRTConvertValidation validator(8, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
......@@ -64,7 +64,7 @@ TEST(elementwise_op, add_tensor_test) {
validator.SetOp(*desc.Proto());
validator.Execute(1);
validator.Execute(8);
}
} // namespace tensorrt
......
......@@ -149,7 +149,7 @@ class TRTConvertValidation {
cudaStreamSynchronize(*engine_->stream());
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
const size_t output_space_size = 2000;
const size_t output_space_size = 3000;
for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out;
std::vector<float> trt_out(output_space_size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册