提交 830aa12c 编写于 作者: N nhzlx

add elementwise init code

上级 fb204fbf
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
DEPS tensorrt_engine mul_op) DEPS tensorrt_engine mul_op)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -16,3 +16,6 @@ nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc ...@@ -16,3 +16,6 @@ nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class ElementwiseWeightOpConverter : public OpConverter {
public:
ElementwiseWeightOpConverter() {}
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr);
LOG(INFO) << "convert a fluid elementwise op to tensorrt IScaleLayer";
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims dims_x = X->getDimensions();
PADDLE_ENFORCE(dims_x.nbDims >= 3);
auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
std::vector<int> dims_y = framework::vectorize2int(Y_t->dims());
if (static_cast<int>(dims_y.size()) == dims_x.nbDims + 1) {
if (dims_y[0] == 1) dims_y.erase(dims_y.begin());
}
if (static_cast<int>(dims_y.size()) == 1 && dims_y[0] == dims_x.d[0]) {
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
} else if (static_cast<int>(dims_y.size()) == dims_x.nbDims &&
dims_y[0] == dims_x.d[0]) {
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
for (int i = 1; i < dims_x.nbDims; i++) {
if (dims_y[i] != dims_x.d[i]) {
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
break;
}
}
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
for (int i = 1; i < dims_x.nbDims; i++) {
if (dims_y[i] != 1)
PADDLE_THROW(
"TensorRT unsupported weight shape for Elementwise op!");
}
}
} else {
PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!");
}
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)};
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
nvinfer1::IScaleLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *const_cast<nvinfer1::ITensor*>(X), scale_mode,
shift_weights.get(), scale_weights.get(), power_weights.get());
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
class ElementwiseTensorOpConverter : public OpConverter {
public:
ElementwiseTensorOpConverter() {}
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr);
LOG(INFO) << "convert a fluid elementwise op to tensorrt IScaleLayer";
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Y = engine_->GetITensor(op_desc.Input("Y").front());
nvinfer1::Dims dims_x = X->getDimensions();
nvinfer1::Dims dims_y = Y->getDimensions();
// only support the C * H * W input format
PADDLE_ENFORCE(dims_x.nbDims >= 3);
if (dims_x.nbDims == dims_y.nbDims) {
for (int i = 0; i < dims_x.nbDims; i++) {
if (dims_x.d[i] != dims_y.d[i])
PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
}
} else {
PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
}
auto op_pair = ops.find(op_type_);
if (op_pair == ops.end()) {
PADDLE_THROW("Wrong elementwise op type!");
}
nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
protected:
static const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
ops;
std::string op_type_;
};
const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
ElementwiseTensorOpConverter::ops = {
{"add", nvinfer1::ElementWiseOperation::kSUM},
{"mul", nvinfer1::ElementWiseOperation::kPROD},
{"sub", nvinfer1::ElementWiseOperation::kSUB},
{"div", nvinfer1::ElementWiseOperation::kDIV},
{"min", nvinfer1::ElementWiseOperation::kMIN},
{"pow", nvinfer1::ElementWiseOperation::kPOW},
{"max", nvinfer1::ElementWiseOperation::kMAX},
};
class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorAddOpConverter() { op_type_ = "add"; }
};
class ElementwiseTensorMulOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorMulOpConverter() { op_type_ = "mul"; }
};
class ElementwiseTensorSubOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorSubOpConverter() { op_type_ = "sub"; }
};
class ElementwiseTensorDivOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorDivOpConverter() { op_type_ = "div"; }
};
class ElementwiseTensorMinOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorMinOpConverter() { op_type_ = "min"; }
};
class ElementwiseTensorMaxOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorMaxOpConverter() { op_type_ = "max"; }
};
class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter {
public:
ElementwiseTensorPowOpConverter() { op_type_ = "pow"; }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, ElementwiseWeightOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor,
ElementwiseTensorAddOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_sub_tensor,
ElementwiseTensorSubOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_div_tensor,
ElementwiseTensorDivOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_mul_tensor,
ElementwiseTensorMulOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_max_tensor,
ElementwiseTensorMaxOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_min_tensor,
ElementwiseTensorMinOpConverter);
REGISTER_TRT_OP_CONVERTER(elementwise_pow_tensor,
ElementwiseTensorPowOpConverter);
...@@ -55,6 +55,31 @@ class OpConverter { ...@@ -55,6 +55,31 @@ class OpConverter {
it = Registry<OpConverter>::Lookup("fc"); it = Registry<OpConverter>::Lookup("fc");
} }
} }
if (op_desc.Type().find("elementwise") != std::string::npos) {
static std::unordered_set<std::string> add_tensor_op_set{
"add", "mul", "sub", "div", "max", "min", "pow"};
// TODO(xingzhaolong): all mul, sub, div
// static std::unordered_set<std::string> add_weight_op_set {"add", "mul",
// "sub", "div"};
static std::unordered_set<std::string> add_weight_op_set{"add"};
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
int op_type_len = op_desc.Type().size();
std::string op_type = op_desc.Type().substr(op_type_len - 3, op_type_len);
std::string Y = op_desc.Input("Y")[0];
if (parameters.count(Y)) {
PADDLE_ENFORCE(add_weight_op_set.count(op_type) > 0,
"Unsupported elementwise type" + op_type);
it =
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_weight");
} else {
PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0,
"Unsupported elementwise type" + op_type);
it =
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor");
}
}
if (!it) { if (!it) {
it = Registry<OpConverter>::Lookup(op_desc.Type()); it = Registry<OpConverter>::Lookup(op_desc.Type());
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(elementwise_op, add_weight_test) {
std::unordered_set<std::string> parameters({"elementwise_add-Y"});
framework::Scope scope;
TRTConvertValidation validator(1, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("elementwise_add");
desc.SetInput("X", {"elementwise_add-X"});
desc.SetInput("Y", {"elementwise_add-Y"});
desc.SetOutput("Out", {"elementwise_add-Out"});
int axis = 1;
desc.SetAttr("axis", axis);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(elementwise_op, add_tensor_test) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(1, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("elementwise_add");
desc.SetInput("X", {"elementwise_add-X"});
desc.SetInput("Y", {"elementwise_add-Y"});
desc.SetOutput("Out", {"elementwise_add-Out"});
int axis = 1;
desc.SetAttr("axis", axis);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(elementwise_add);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册