/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

#include <gtest/gtest.h>  // NOLINT
#include <memory>
#include <vector>

#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace inference {
namespace tensorrt {

L
Luo Tao 已提交
25
TEST(OpConverter, ConvertBlock) {
L
Luo Tao 已提交
26 27 28
  framework::ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  auto* conv2d_op = block->AppendOp();
N
nhzlx 已提交
29 30 31

  // init trt engine
  std::unique_ptr<TensorRTEngine> engine_;
Z
Zhaolong Xing 已提交
32
  engine_.reset(new TensorRTEngine(5, 1 << 15));
N
nhzlx 已提交
33
  engine_->InitNetwork();
N
nhzlx 已提交
34 35 36 37

  engine_->DeclareInput("conv2d-X", nvinfer1::DataType::kFLOAT,
                        nvinfer1::Dims3(2, 5, 5));

L
Luo Tao 已提交
38
  conv2d_op->SetType("conv2d");
N
nhzlx 已提交
39 40 41
  conv2d_op->SetInput("Input", {"conv2d-X"});
  conv2d_op->SetInput("Filter", {"conv2d-Y"});
  conv2d_op->SetOutput("Output", {"conv2d-Out"});
L
Luo Tao 已提交
42

N
nhzlx 已提交
43 44 45 46 47 48 49 50 51 52 53
  const std::vector<int> strides({1, 1});
  const std::vector<int> paddings({1, 1});
  const std::vector<int> dilations({1, 1});
  const int groups = 1;

  conv2d_op->SetAttr("strides", strides);
  conv2d_op->SetAttr("paddings", paddings);
  conv2d_op->SetAttr("dilations", dilations);
  conv2d_op->SetAttr("groups", groups);

  // init scope
54
  framework::Scope scope;
N
nhzlx 已提交
55 56 57
  std::vector<int> dim_vec = {3, 2, 3, 3};
  auto* x = scope.Var("conv2d-Y");
  auto* x_tensor = x->GetMutable<framework::LoDTensor>();
58
  x_tensor->Resize(pten::make_ddim(dim_vec));
N
nhzlx 已提交
59
  x_tensor->mutable_data<float>(platform::CUDAPlace(0));
N
nhzlx 已提交
60 61 62 63

  OpConverter converter;
  converter.ConvertBlock(*block->Proto(), {"conv2d-Y"}, scope,
                         engine_.get() /*TensorRTEngine*/);
L
Luo Tao 已提交
64 65 66 67 68
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
69 70

USE_TRT_CONVERTER(conv2d)