提交 144b20c1 编写于 作者: N nhzlx

add batch norm op converter

上级 14311bb0
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include <math.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
......@@ -23,15 +23,15 @@ class BatchNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
LOG(INFO)
<< "convert a fluid batch norm op to tensorrt batch_norm";
LOG(INFO) << "convert a fluid batch norm op to tensorrt batch_norm";
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(), 1); // Variance is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(),
1); // Variance is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
......@@ -53,7 +53,6 @@ class BatchNormOpConverter : public OpConverter {
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
auto* Variance_t = Variance_v->GetMutable<framework::LoDTensor>();
// create temp tensor for weights
framework::LoDTensor bias_tensor;
framework::LoDTensor mean_tensor;
......@@ -64,9 +63,9 @@ class BatchNormOpConverter : public OpConverter {
mean_tensor.Resize(Mean_t->dims());
scale_tensor.Resize(Scale_t->dims());
variance_tensor.Resize(Variance_t->dims());
platform::CPUPlace cpu_place;
// copy data from gpu to cpu
// copy data from gpu to cpu
TensorCopySync((*Bias_t), cpu_place, &bias_tensor);
TensorCopySync((*Mean_t), cpu_place, &mean_tensor);
TensorCopySync((*Scale_t), cpu_place, &scale_tensor);
......@@ -75,47 +74,53 @@ class BatchNormOpConverter : public OpConverter {
auto* bias_data = bias_tensor.mutable_data<float>(platform::CPUPlace());
auto* mean_data = mean_tensor.mutable_data<float>(platform::CPUPlace());
auto* scale_data = scale_tensor.mutable_data<float>(platform::CPUPlace());
auto* variance_data = variance_tensor.mutable_data<float>(platform::CPUPlace());
framework::LoDTensor *combile_scale_tensor = new framework::LoDTensor();
framework::LoDTensor *combile_bias_tensor = new framework::LoDTensor();
auto* variance_data =
variance_tensor.mutable_data<float>(platform::CPUPlace());
std::unique_ptr<framework::LoDTensor> combile_scale_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> combile_bias_tensor(
new framework::LoDTensor());
combile_scale_tensor->Resize(scale_tensor.dims());
combile_bias_tensor->Resize(bias_tensor.dims());
auto* combile_scale_data = combile_scale_tensor->mutable_data<float>(platform::CPUPlace());
auto* combile_bias_data = combile_bias_tensor->mutable_data<float>(platform::CPUPlace());
auto* combile_scale_data =
combile_scale_tensor->mutable_data<float>(platform::CPUPlace());
auto* combile_bias_data =
combile_bias_tensor->mutable_data<float>(platform::CPUPlace());
size_t ele_num = combile_scale_tensor->memory_size() / sizeof(float);
engine_->weight_map_[op_desc.Input("Bias").front()] = std::move(std::unique_ptr<framework::Tensor>(combile_bias_tensor));
engine_->weight_map_[op_desc.Input("Scale").front()] = std::move(std::unique_ptr<framework::Tensor>(combile_scale_tensor));
size_t ele_num = combile_scale_tensor->memory_size()/sizeof(float);
for (size_t i = 0; i < ele_num; i++) {
float scale = scale_data[i];
float bias = bias_data[i];
float mean = mean_data[i];
float variance = variance_data[i];
combile_scale_data[i] = scale / sqrtf(variance + eps);
combile_bias_data[i] = bias - mean * combile_scale_data[i];
float scale = scale_data[i];
float bias = bias_data[i];
float mean = mean_data[i];
float variance = variance_data[i];
combile_scale_data[i] = scale / sqrtf(variance + eps);
combile_bias_data[i] = bias - mean * combile_scale_data[i];
}
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(combile_scale_data),
combile_scale_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
static_cast<void *>(combile_bias_data),
combile_bias_tensor->memory_size()/ sizeof(float)};
TensorRTEngine::Weight scale_weights{
nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_scale_data),
combile_scale_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight shift_weights{
nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_bias_data),
combile_bias_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
nvinfer1::IScaleLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *const_cast<nvinfer1::ITensor*>(X), nvinfer1::ScaleMode::kCHANNEL,
shift_weights.get(), scale_weights.get(), power_weights.get());
nvinfer1::IScaleLayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Scale, *const_cast<nvinfer1::ITensor*>(X),
nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(),
scale_weights.get(), power_weights.get());
auto output_name = op_desc.Output("Y").front();
engine_->weight_map[op_desc.Input("Bias").front()] =
std::move(combile_bias_tensor);
engine_->weight_map[op_desc.Input("Scale").front()] =
std::move(combile_scale_tensor);
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
......
......@@ -21,8 +21,9 @@ namespace inference {
namespace tensorrt {
TEST(batch_norm_op, test) {
std::unordered_set<std::string> parameters({"batch_norm_scale",
"batch_norm_bias", "batch_norm_mean", "batch_norm_variance" });
std::unordered_set<std::string> parameters(
{"batch_norm_scale", "batch_norm_bias", "batch_norm_mean",
"batch_norm_variance"});
framework::Scope scope;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
std::vector<int> param_shape{2};
......@@ -38,6 +39,7 @@ TEST(batch_norm_op, test) {
// Prepare Op description
framework::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"batch_norm_X"});
desc.SetInput("Scale", {"batch_norm_scale"});
......@@ -54,10 +56,12 @@ TEST(batch_norm_op, test) {
bool is_test = true;
desc.SetAttr("epsilon", eps);
desc.SetAttr("is_test", is_test);
validator.SetOp(*desc.Proto());
std::unordered_set<std::string> neglected_output = {"batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean", "batch_norm_variance"};
std::unordered_set<std::string> neglected_output = {
"batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean",
"batch_norm_variance"};
validator.Execute(3, neglected_output);
}
......
......@@ -98,11 +98,19 @@ class TRTConvertValidation {
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
}
void DeclParamVar(const std::string& name, const std::vector<int> dim_vec) {
DeclVar(name, dim_vec);
}
// Declare a parameter varaible in the scope.
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims, true);
}
void DeclOutputVar(const std::string& name, const std::vector<int> dim_vec) {
DeclVar(name, dim_vec);
}
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}
......@@ -155,7 +163,8 @@ class TRTConvertValidation {
}
}
void Execute(int batch_size) {
void Execute(int batch_size,
std::unordered_set<std::string> neglected_output = {}) {
// Execute Fluid Op
PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
platform::CUDAPlace place;
......@@ -168,6 +177,7 @@ class TRTConvertValidation {
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
const size_t output_space_size = 3000;
for (const auto& output : op_desc_->OutputArgumentNames()) {
if (neglected_output.count(output)) continue;
std::vector<float> fluid_out;
std::vector<float> trt_out(output_space_size);
engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册