未验证 提交 c7c5635e 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] The Graph uses OpConverterType for op converter (#53214)

* add ```converter_type``` for op converter
上级 46951224
......@@ -56,8 +56,9 @@ class OpConverter {
OpConverter* it{nullptr};
auto op_converter_type_map = OpTeller::Global().GetOpConverterTypeMap();
switch (op_converter_type_map.at(op_desc.Type())) {
auto converter_type = static_cast<OpConverterType>(
PADDLE_GET_CONST(int, op_desc.GetAttr("converter_type")));
switch (converter_type) {
case OpConverterType::Default:
if (op_desc.Type().find("elementwise") != std::string::npos) {
static std::unordered_set<std::string> add_tensor_op_set{
......
......@@ -109,7 +109,7 @@ TEST(CustomPluginCreater, StaticShapePlugin) {
framework::OpDesc custom_op(*op_desc, nullptr);
CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true);
OpTeller::Global().SetOpConverterType("custom_op",
OpTeller::Global().SetOpConverterType(&custom_op,
OpConverterType::CustomPluginCreater);
OpConverter converter;
......@@ -196,7 +196,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) {
framework::OpDesc custom_op(*op_desc, nullptr);
CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true);
OpTeller::Global().SetOpConverterType("custom_op",
OpTeller::Global().SetOpConverterType(&custom_op,
OpConverterType::CustomPluginCreater);
OpConverter converter;
......
......@@ -57,7 +57,7 @@ TEST(OpConverter, ConvertBlock) {
x_tensor->Resize(phi::make_ddim(dim_vec));
x_tensor->mutable_data<float>(platform::CUDAPlace(0));
OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default);
OpTeller::Global().SetOpConverterType(conv2d_op, OpConverterType::Default);
OpConverter converter;
converter.ConvertBlock(
*block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/);
......
......@@ -3080,17 +3080,17 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
auto& default_teller = GetDefaultTeller();
if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::Default);
SetOpConverterType(node->Op(), OpConverterType::Default);
return true;
}
auto& generic_plugin_teller = GetGenericPluginTeller();
if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::GenericPluginCreater);
SetOpConverterType(node->Op(), OpConverterType::GenericPluginCreater);
return true;
}
auto& custom_plugin_teller = GetCustomPluginTeller();
if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::CustomPluginCreater);
SetOpConverterType(node->Op(), OpConverterType::CustomPluginCreater);
return true;
}
return false;
......
......@@ -82,12 +82,8 @@ class OpTeller {
std::unique_ptr<Teller>& GetCustomPluginTeller() { return tellers_.at(2); }
void SetOpConverterType(std::string name, OpConverterType type) {
op_converter_type_map_[name] = type;
}
const std::map<std::string, OpConverterType>& GetOpConverterTypeMap() const {
return op_converter_type_map_;
void SetOpConverterType(framework::OpDesc* op_desc, OpConverterType type) {
op_desc->SetAttr("converter_type", static_cast<int>(type));
}
private:
......@@ -95,7 +91,6 @@ class OpTeller {
private:
std::vector<std::unique_ptr<Teller>> tellers_;
std::map<std::string, OpConverterType> op_converter_type_map_;
};
} // namespace tensorrt
......
......@@ -94,6 +94,11 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
"Out", std::vector<std::string>({"z0"})); // 2 x 4 x 4 x 4
elementwise_add1->SetAttr("axis", static_cast<int32_t>(0));
inference::tensorrt::OpTeller::Global().SetOpConverterType(
elementwise_add0, inference::tensorrt::OpConverterType::Default);
inference::tensorrt::OpTeller::Global().SetOpConverterType(
elementwise_add1, inference::tensorrt::OpConverterType::Default);
// Set inputs' variable shape in BlockDesc
// the batch size is 2, so the dims of 'x' is {2, 4}
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 4, 4}));
......@@ -170,8 +175,6 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
// Execute them.
LOG(INFO) << "engine_op run";
inference::tensorrt::OpTeller::Global().SetOpConverterType(
"elementwise_add", inference::tensorrt::OpConverterType::Default);
engine_op->Run(scope, place);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册