test_op_converter.cc 2.2 KB
Newer Older
L
Luo Tao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Z
Zhaolong Xing 已提交
15
#include <gtest/gtest.h>  // NOLINT
16

L
Luo Tao 已提交
17
#include "paddle/fluid/framework/program_desc.h"
18
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
L
Luo Tao 已提交
19 20 21 22 23

namespace paddle {
namespace inference {
namespace tensorrt {

L
Luo Tao 已提交
24
TEST(OpConverter, ConvertBlock) {
L
Luo Tao 已提交
25 26 27
  framework::ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  auto* conv2d_op = block->AppendOp();
N
nhzlx 已提交
28 29 30

  // init trt engine
  std::unique_ptr<TensorRTEngine> engine_;
Z
Zhaolong Xing 已提交
31
  engine_.reset(new TensorRTEngine(5, 1 << 15));
N
nhzlx 已提交
32
  engine_->InitNetwork();
N
nhzlx 已提交
33

34 35
  engine_->DeclareInput(
      "conv2d-X", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3(2, 5, 5));
N
nhzlx 已提交
36

L
Luo Tao 已提交
37
  conv2d_op->SetType("conv2d");
N
nhzlx 已提交
38 39 40
  conv2d_op->SetInput("Input", {"conv2d-X"});
  conv2d_op->SetInput("Filter", {"conv2d-Y"});
  conv2d_op->SetOutput("Output", {"conv2d-Out"});
L
Luo Tao 已提交
41

N
nhzlx 已提交
42 43 44 45 46 47 48 49 50 51 52
  const std::vector<int> strides({1, 1});
  const std::vector<int> paddings({1, 1});
  const std::vector<int> dilations({1, 1});
  const int groups = 1;

  conv2d_op->SetAttr("strides", strides);
  conv2d_op->SetAttr("paddings", paddings);
  conv2d_op->SetAttr("dilations", dilations);
  conv2d_op->SetAttr("groups", groups);

  // init scope
53
  framework::Scope scope;
N
nhzlx 已提交
54 55 56
  std::vector<int> dim_vec = {3, 2, 3, 3};
  auto* x = scope.Var("conv2d-Y");
  auto* x_tensor = x->GetMutable<framework::LoDTensor>();
57
  x_tensor->Resize(phi::make_ddim(dim_vec));
N
nhzlx 已提交
58
  x_tensor->mutable_data<float>(platform::CUDAPlace(0));
N
nhzlx 已提交
59

60
  OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default);
N
nhzlx 已提交
61
  OpConverter converter;
62 63
  converter.ConvertBlock(
      *block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/);
L
Luo Tao 已提交
64 65 66 67 68
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
69 70

USE_TRT_CONVERTER(conv2d)