Commit 144b20c1 authored by N nhzlx

add batch norm op converter

Parent 14311bb0
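For context: at inference time batch norm is a pure per-channel affine transform, which is why the converter below can lower it onto a single TensorRT IScaleLayer. A minimal standalone sketch of the weight folding follows; the eps value here is an assumed placeholder (the converter reads the op's epsilon attribute, which sits in a collapsed part of this diff), everything else mirrors the folding loop shown below.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Sketch of the per-channel fold the converter performs:
    //   y = scale * (x - mean) / sqrt(variance + eps) + bias
    //     = combile_scale * x + combile_bias
    int main() {
      const float eps = 1e-5f;  // assumed value; the op's epsilon attribute is authoritative
      std::vector<float> scale{1.0f, 2.0f}, bias{0.5f, -0.5f};
      std::vector<float> mean{0.1f, 0.2f}, variance{1.0f, 4.0f};
      for (size_t i = 0; i < scale.size(); ++i) {
        float combile_scale = scale[i] / std::sqrt(variance[i] + eps);
        // same formula as the converter's loop below
        float combile_bias = bias[i] - mean[i] * combile_scale;
        std::printf("channel %zu: scale=%.6f shift=%.6f\n", i, combile_scale,
                    combile_bias);
      }
      return 0;
    }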
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include <math.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 
 namespace paddle {
 namespace inference {
@@ -23,15 +23,15 @@ class BatchNormOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-    LOG(INFO)
-        << "convert a fluid batch norm op to tensorrt batch_norm";
+    LOG(INFO) << "convert a fluid batch norm op to tensorrt batch_norm";
     framework::OpDesc op_desc(op, nullptr);
     PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
     PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1);  // Bias is a weight
     PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1);  // Mean is a weight
     PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1);  // Scale is a weight
-    PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(), 1);  // Variance is a weight
+    PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(),
+                      1);  // Variance is a weight
     PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
 
     auto* X = engine_->GetITensor(op_desc.Input("X").front());
@@ -53,7 +53,6 @@ class BatchNormOpConverter : public OpConverter {
     auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
     auto* Variance_t = Variance_v->GetMutable<framework::LoDTensor>();
-
     // create temp tensor for weights
     framework::LoDTensor bias_tensor;
     framework::LoDTensor mean_tensor;
@@ -75,21 +74,23 @@ class BatchNormOpConverter : public OpConverter {
     auto* bias_data = bias_tensor.mutable_data<float>(platform::CPUPlace());
     auto* mean_data = mean_tensor.mutable_data<float>(platform::CPUPlace());
     auto* scale_data = scale_tensor.mutable_data<float>(platform::CPUPlace());
-    auto* variance_data = variance_tensor.mutable_data<float>(platform::CPUPlace());
+    auto* variance_data =
+        variance_tensor.mutable_data<float>(platform::CPUPlace());
 
-    framework::LoDTensor *combile_scale_tensor = new framework::LoDTensor();
-    framework::LoDTensor *combile_bias_tensor = new framework::LoDTensor();
+    std::unique_ptr<framework::LoDTensor> combile_scale_tensor(
+        new framework::LoDTensor());
+    std::unique_ptr<framework::LoDTensor> combile_bias_tensor(
+        new framework::LoDTensor());
 
     combile_scale_tensor->Resize(scale_tensor.dims());
     combile_bias_tensor->Resize(bias_tensor.dims());
-    auto* combile_scale_data = combile_scale_tensor->mutable_data<float>(platform::CPUPlace());
-    auto* combile_bias_data = combile_bias_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* combile_scale_data =
+        combile_scale_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* combile_bias_data =
+        combile_bias_tensor->mutable_data<float>(platform::CPUPlace());
 
-    engine_->weight_map_[op_desc.Input("Bias").front()] = std::move(std::unique_ptr<framework::Tensor>(combile_bias_tensor));
-    engine_->weight_map_[op_desc.Input("Scale").front()] = std::move(std::unique_ptr<framework::Tensor>(combile_scale_tensor));
-    size_t ele_num = combile_scale_tensor->memory_size()/sizeof(float);
+    size_t ele_num = combile_scale_tensor->memory_size() / sizeof(float);
     for (size_t i = 0; i < ele_num; i++) {
       float scale = scale_data[i];
@@ -100,22 +101,26 @@ class BatchNormOpConverter : public OpConverter {
       combile_bias_data[i] = bias - mean * combile_scale_data[i];
     }
 
-    TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
-                                         static_cast<void*>(combile_scale_data),
-                                         combile_scale_tensor->memory_size() / sizeof(float)};
+    TensorRTEngine::Weight scale_weights{
+        nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_scale_data),
+        combile_scale_tensor->memory_size() / sizeof(float)};
-    TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
-                                         static_cast<void *>(combile_bias_data),
-                                         combile_bias_tensor->memory_size()/ sizeof(float)};
+    TensorRTEngine::Weight shift_weights{
+        nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_bias_data),
+        combile_bias_tensor->memory_size() / sizeof(float)};
     TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
                                          0};
 
-    nvinfer1::IScaleLayer* layer = TRT_ENGINE_ADD_LAYER(
-        engine_, Scale, *const_cast<nvinfer1::ITensor*>(X), nvinfer1::ScaleMode::kCHANNEL,
-        shift_weights.get(), scale_weights.get(), power_weights.get());
+    nvinfer1::IScaleLayer* layer =
+        TRT_ENGINE_ADD_LAYER(engine_, Scale, *const_cast<nvinfer1::ITensor*>(X),
+                             nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(),
+                             scale_weights.get(), power_weights.get());
 
     auto output_name = op_desc.Output("Y").front();
+    engine_->weight_map[op_desc.Input("Bias").front()] =
+        std::move(combile_bias_tensor);
+    engine_->weight_map[op_desc.Input("Scale").front()] =
+        std::move(combile_scale_tensor);
     engine_->SetITensor(output_name, layer->getOutput(0));
 
     if (test_mode) {
......
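A note on the ownership change above: the folded tensors are now held by std::unique_ptr from the moment they are allocated, rather than being raw pointers wrapped only at the point of registration, and they are moved into engine_->weight_map (the correctly named member, previously weight_map_) only after the IScaleLayer has been created. Parking them in the engine's weight map matters because TensorRT holds raw pointers into weight memory rather than copying it at layer-creation time, so the combile_scale/combile_bias buffers must stay alive until the engine is built.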
@@ -21,8 +21,9 @@ namespace inference {
 namespace tensorrt {
 
 TEST(batch_norm_op, test) {
-  std::unordered_set<std::string> parameters({"batch_norm_scale",
-      "batch_norm_bias", "batch_norm_mean", "batch_norm_variance" });
+  std::unordered_set<std::string> parameters(
+      {"batch_norm_scale", "batch_norm_bias", "batch_norm_mean",
+       "batch_norm_variance"});
   framework::Scope scope;
   TRTConvertValidation validator(5, parameters, scope, 1 << 15);
   std::vector<int> param_shape{2};
@@ -38,6 +39,7 @@ TEST(batch_norm_op, test) {
   // Prepare Op description
   framework::OpDesc desc;
+
   desc.SetType("batch_norm");
   desc.SetInput("X", {"batch_norm_X"});
   desc.SetInput("Scale", {"batch_norm_scale"});
@@ -57,7 +59,9 @@ TEST(batch_norm_op, test) {
   validator.SetOp(*desc.Proto());
 
-  std::unordered_set<std::string> neglected_output = {"batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean", "batch_norm_variance"};
+  std::unordered_set<std::string> neglected_output = {
+      "batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean",
+      "batch_norm_variance"};
 
   validator.Execute(3, neglected_output);
 }
......
@@ -98,11 +98,19 @@ class TRTConvertValidation {
     engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
   }
 
+  void DeclParamVar(const std::string& name, const std::vector<int> dim_vec) {
+    DeclVar(name, dim_vec);
+  }
+
   // Declare a parameter variable in the scope.
   void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
     DeclVar(name, dims, true);
   }
 
+  void DeclOutputVar(const std::string& name, const std::vector<int> dim_vec) {
+    DeclVar(name, dim_vec);
+  }
+
   void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
     DeclVar(name, dims);
   }
@@ -155,7 +163,8 @@ class TRTConvertValidation {
     }
   }
 
-  void Execute(int batch_size) {
+  void Execute(int batch_size,
+               std::unordered_set<std::string> neglected_output = {}) {
     // Execute Fluid Op
     PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
     platform::CUDAPlace place;
@@ -168,6 +177,7 @@ class TRTConvertValidation {
     ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
     const size_t output_space_size = 3000;
     for (const auto& output : op_desc_->OutputArgumentNames()) {
+      if (neglected_output.count(output)) continue;
       std::vector<float> fluid_out;
       std::vector<float> trt_out(output_space_size);
       engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
......
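Putting the validator changes together, a minimal usage sketch (it assumes a fully configured TRTConvertValidation named validator, as in the test above): outputs that exist only on the fluid side, such as batch norm's saved statistics, have no TensorRT counterpart and must be excluded from the comparison.

    // Sketch only; `validator` is assumed to be set up as in the test above.
    std::unordered_set<std::string> neglected_output = {
        "batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean",
        "batch_norm_variance"};
    // Compare fluid vs. TensorRT results at batch size 3, skipping the
    // fluid-only outputs listed above.
    validator.Execute(3, neglected_output);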