提交 82527696 编写于 作者: N nhzlx

1. we delelte mul op, 2.modify fc and action op 3. modify the test inferface

上级 2372daff
# Add TRT tests
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc
SRCS conv2d_op.cc fc_op.cc
DEPS tensorrt_engine mul_op)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
......
......@@ -32,13 +32,13 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
for (int h = 0; h < shape.h(); ++h) {
for (int w = 0; w < shape.w(); ++w) {
odata[h * ostrides.h() + w * ostrides.w()] =
idata[h * ostrides.h() + w * ostrides.w()];
idata[h * istrides.h() + w * istrides.w()];
}
}
}
// indata c * k
// Reorder the data layout from CK to KC.
void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
void ReorderCKtoKC(const TensorRTEngine::Weight& iweights,
TensorRTEngine::Weight* oweights) {
int c = iweights.dims[0];
int k = iweights.dims[1];
......@@ -79,9 +79,8 @@ class FcOpConverter : public OpConverter {
framework::LoDTensor tmp;
tmp.Resize(Y_t->dims());
memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
Y_t->dims()[0] * Y_t->dims()[1]);
memcpy(tmp.mutable_data<float>(platform::CPUPlace()), weight_data,
Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)};
......@@ -93,7 +92,7 @@ class FcOpConverter : public OpConverter {
// The data layout of TRT FC layer's weight is different from fluid's FC,
// need to reorder the elements.
ReorderCKtoKC(tmp_weight, &weight);
ReorderCKtoKC(weight, &tmp_weight);
// Currently, the framework can only handle one fluid op -> one TRT layer,
// but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
......@@ -103,7 +102,7 @@ class FcOpConverter : public OpConverter {
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
*const_cast<nvinfer1::ITensor*>(X),
n_output, weight.get(), bias.get());
n_output, tmp_weight.get(), bias.get());
auto output_name = op_desc.Output("Out").front();
engine_->SetITensor(output_name, layer->getOutput(0));
......@@ -117,5 +116,5 @@ class FcOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
REGISTER_TRT_OP_CONVERTER(mul, FcOpConverter);
USE_OP(mul);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*/
class MulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
// Both the input1 and input2 do not need transpose.
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
*const_cast<nvinfer1::ITensor*>(input2), false);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(mul);
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
......@@ -23,7 +23,7 @@ namespace tensorrt {
TEST(ReluOpConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
TRTConvertValidation validator(1, parameters, scope, 1000);
validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6));
validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6));
......@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) {
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(10);
validator.Execute(1);
}
} // namespace tensorrt
......
......@@ -23,11 +23,12 @@ namespace tensorrt {
TEST(fc_op, test) {
std::unordered_set<std::string> parameters({"mul-Y"});
framework::Scope scope;
TRTConvertValidation validator(20, parameters, scope, 1000);
TRTConvertValidation validator(1, parameters, scope, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));
validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1));
validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(1, 2));
// Prepare Op description
framework::OpDesc desc;
......@@ -38,7 +39,7 @@ TEST(fc_op, test) {
validator.SetOp(*desc.Proto());
validator.Execute(10);
validator.Execute(1);
}
} // namespace tensorrt
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(MulOpConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {"mul-X"});
desc.SetInput("Y", {"mul-Y"});
desc.SetOutput("Out", {"mul-Out"});
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(10);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(mul);
......@@ -39,7 +39,7 @@ namespace tensorrt {
float random(float low, float high) {
static std::random_device rd;
static std::mt19937 mt(rd());
std::uniform_real_distribution<double> dist(1.0, 10.0);
std::uniform_real_distribution<double> dist(low, high);
return dist(mt);
}
......@@ -49,6 +49,7 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
size_t num_elements = analysis::AccuDims(dims, dims.size());
PADDLE_ENFORCE_GT(num_elements, 0);
auto* data = tensor->mutable_data<float>(place);
for (size_t i = 0; i < num_elements; i++) {
*(data + i) = random(0., 1.);
}
......@@ -68,7 +69,7 @@ class TRTConvertValidation {
int workspace_size = 1 << 10)
: parameters_(parameters), scope_(scope) {
// create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_.reset(new TensorRTEngine(batch_size, workspace_size, &stream_));
engine_->InitNetwork();
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
......@@ -142,8 +143,7 @@ class TRTConvertValidation {
for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out;
std::vector<float> trt_out(output_space_size);
engine_->GetOutputInCPU(output, &trt_out[0],
output_space_size * sizeof(float));
engine_->GetOutputInCPU(output, &trt_out[0]);
cudaStreamSynchronize(*engine_->stream());
auto* var = scope_.FindVar(output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册