Unverified commit 1a1d596b, committed via GitHub

[inference]add trt act layer convert (#43504)

* add activation layer
Parent ce704ee9
@@ -49,14 +49,30 @@ class ActivationOpConverter : public OpConverter {
         << "convert a fluid Activation op to tensorrt activation layer whose "
            "type is "
         << op_type_;
-    const nvinfer1::ITensor* input_tensor =
-        engine_->GetITensor(op_desc.Input("X")[0]);
+    auto* input_tensor = engine_->GetITensor(op_desc.Input("X")[0]);
     auto op_pair = ops.find(op_type_);
-    nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
-        engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
-        op_pair->second);
+    nvinfer1::IActivationLayer* layer = nullptr;
+    if (op_type_ == "softplus") {
+      const float beta = op_desc.HasAttr("beta")
+                             ? BOOST_GET_CONST(float, op_desc.GetAttr("beta"))
+                             : 1.0f;
+      const float threshold =
+          op_desc.HasAttr("threshold")
+              ? BOOST_GET_CONST(float, op_desc.GetAttr("threshold"))
+              : 20.0f;
+      auto* layer_clip = TRT_ENGINE_ADD_LAYER(
+          engine_, Activation, *input_tensor, nvinfer1::ActivationType::kCLIP);
+      layer_clip->setAlpha(-3.40282e+038);
+      layer_clip->setBeta(threshold / beta);
+      layer = TRT_ENGINE_ADD_LAYER(engine_, Activation,
+                                   *layer_clip->getOutput(0), op_pair->second);
+      layer->setAlpha(1.0f / beta);
+      layer->setBeta(beta);
+    } else {
+      layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input_tensor,
+                                   op_pair->second);
+    }
 
 #if IS_TRT_VERSION_GE(5130)
     // max(alpha, min(beta, x))
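A quick aside on the softplus branch above (not part of the patch): TensorRT documents kSOFTPLUS as alpha * log(exp(beta * x) + 1) and kCLIP as max(alpha, min(beta, x)), so choosing alpha = 1/beta reproduces Paddle's (1/beta) * log(1 + exp(beta * x)), while the clip layer in front caps the softplus argument at threshold / beta. A minimal numpy sketch of that identity for inputs below the threshold (helper names are illustrative only):

import numpy as np

def paddle_softplus_formula(x, beta=1.0):
    # (1 / beta) * log(1 + exp(beta * x)): the unsaturated Paddle definition
    return np.log1p(np.exp(beta * x)) / beta

def trt_graph_softplus(x, beta=1.0, threshold=20.0):
    clipped = np.minimum(x, threshold / beta)            # kCLIP with alpha = -FLT_MAX
    alpha = 1.0 / beta
    return alpha * np.log(np.exp(beta * clipped) + 1.0)  # kSOFTPLUS semantics

x = np.linspace(-5.0, 5.0, 11).astype(np.float32)
assert np.allclose(paddle_softplus_formula(x, beta=0.67),
                   trt_graph_softplus(x, beta=0.67), atol=1e-5)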
@@ -64,6 +80,41 @@ class ActivationOpConverter : public OpConverter {
       layer->setAlpha(0.);
       layer->setBeta(6.);
     }
+    if (op_type_ == "elu") {
+      const float alpha = op_desc.HasAttr("alpha")
+                              ? BOOST_GET_CONST(float, op_desc.GetAttr("alpha"))
+                              : 1.0f;
+      layer->setAlpha(alpha);
+    }
+    if (op_type_ == "selu") {
+      const float alpha = op_desc.HasAttr("alpha")
+                              ? BOOST_GET_CONST(float, op_desc.GetAttr("alpha"))
+                              : 1.0507009873554804934193349852946;
+      const float scale = op_desc.HasAttr("scale")
+                              ? BOOST_GET_CONST(float, op_desc.GetAttr("scale"))
+                              : 1.6732632423543772848170429916717;
+      layer->setAlpha(alpha);
+      layer->setBeta(scale);
+    }
+    if (op_type_ == "stanh") {
+      const float scale_a =
+          op_desc.HasAttr("scale_a")
+              ? BOOST_GET_CONST(float, op_desc.GetAttr("scale_a"))
+              : 0.67f;
+      const float scale_b =
+          op_desc.HasAttr("scale_b")
+              ? BOOST_GET_CONST(float, op_desc.GetAttr("scale_b"))
+              : 1.7159f;
+      layer->setAlpha(scale_b);
+      layer->setBeta(scale_a);
+    }
+    if (op_type_ == "thresholded_relu") {
+      const float threshold =
+          op_desc.HasAttr("threshold")
+              ? BOOST_GET_CONST(float, op_desc.GetAttr("threshold"))
+              : 1.0f;
+      layer->setAlpha(threshold);
+    }
 #endif
 
     auto output_name = op_desc.Output("Out")[0];
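For reference (again outside the patch), the alpha/beta assignments above follow from how TensorRT parameterizes these activations: kSCALED_TANH computes alpha * tanh(beta * x) and kTHRESHOLDED_RELU returns x when x > alpha and 0 otherwise, while elu/selu pass their Paddle attributes through unchanged and softsign needs no attributes at all because kSOFTSIGN already computes x / (1 + |x|). A small numpy check of the two less obvious mappings, assuming the standard Paddle definitions:

import numpy as np

def paddle_stanh(x, scale_a=0.67, scale_b=1.7159):
    # Paddle: stanh(x) = scale_b * tanh(scale_a * x)
    return scale_b * np.tanh(scale_a * x)

def trt_scaled_tanh(x, alpha, beta):
    # TensorRT kSCALED_TANH: alpha * tanh(beta * x)
    return alpha * np.tanh(beta * x)

def paddle_thresholded_relu(x, threshold=1.0):
    return np.where(x > threshold, x, 0.0)

def trt_thresholded_relu(x, alpha):
    # TensorRT kTHRESHOLDED_RELU: x if x > alpha else 0
    return np.where(x > alpha, x, 0.0)

x = np.random.randn(8).astype(np.float32)
# setAlpha(scale_b) / setBeta(scale_a) in the converter corresponds to:
assert np.allclose(paddle_stanh(x), trt_scaled_tanh(x, alpha=1.7159, beta=0.67))
# setAlpha(threshold) corresponds to:
assert np.allclose(paddle_thresholded_relu(x), trt_thresholded_relu(x, alpha=1.0))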
@@ -83,8 +134,13 @@ const std::unordered_map<std::string, nvinfer1::ActivationType>
         {"tanh", nvinfer1::ActivationType::kTANH},
 #if IS_TRT_VERSION_GE(5130)
         {"relu6", nvinfer1::ActivationType::kCLIP},
+        {"elu", nvinfer1::ActivationType::kELU},
+        {"selu", nvinfer1::ActivationType::kSELU},
+        {"softsign", nvinfer1::ActivationType::kSOFTSIGN},
+        {"softplus", nvinfer1::ActivationType::kSOFTPLUS},
+        {"stanh", nvinfer1::ActivationType::kSCALED_TANH},
+        {"thresholded_relu", nvinfer1::ActivationType::kTHRESHOLDED_RELU}};
 #endif
-};
 
 class ReluOpConverter : public ActivationOpConverter {
  public:
@@ -101,11 +157,43 @@ class TanhOpConverter : public ActivationOpConverter {
   TanhOpConverter() { op_type_ = "tanh"; }
 };
 
+#if IS_TRT_VERSION_GE(5130)
 class Relu6OpConverter : public ActivationOpConverter {
  public:
   Relu6OpConverter() { op_type_ = "relu6"; }
 };
+
+class EluOpConverter : public ActivationOpConverter {
+ public:
+  EluOpConverter() { op_type_ = "elu"; }
+};
+
+class SeluOpConverter : public ActivationOpConverter {
+ public:
+  SeluOpConverter() { op_type_ = "selu"; }
+};
+
+class SoftsignOpConverter : public ActivationOpConverter {
+ public:
+  SoftsignOpConverter() { op_type_ = "softsign"; }
+};
+
+class SoftplusOpConverter : public ActivationOpConverter {
+ public:
+  SoftplusOpConverter() { op_type_ = "softplus"; }
+};
+
+class STanhOpConverter : public ActivationOpConverter {
+ public:
+  STanhOpConverter() { op_type_ = "stanh"; }
+};
+
+class ThreasholdedReluOpConverter : public ActivationOpConverter {
+ public:
+  ThreasholdedReluOpConverter() { op_type_ = "thresholded_relu"; }
+};
+#endif
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -113,4 +201,12 @@ class Relu6OpConverter : public ActivationOpConverter {
 REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
 REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
 REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
+#if IS_TRT_VERSION_GE(5130)
 REGISTER_TRT_OP_CONVERTER(relu6, Relu6OpConverter);
+REGISTER_TRT_OP_CONVERTER(elu, EluOpConverter);
+REGISTER_TRT_OP_CONVERTER(selu, SeluOpConverter);
+REGISTER_TRT_OP_CONVERTER(softsign, SoftsignOpConverter);
+REGISTER_TRT_OP_CONVERTER(softplus, SoftplusOpConverter);
+REGISTER_TRT_OP_CONVERTER(stanh, STanhOpConverter);
+REGISTER_TRT_OP_CONVERTER(thresholded_relu, ThreasholdedReluOpConverter);
+#endif
@@ -73,6 +73,12 @@ struct SimpleOpTypeSetTeller : public Teller {
       "conv2d_fusion",
       "pool2d",
       "relu",
+      "elu",
+      "selu",
+      "softsign",
+      "softplus",
+      "stanh",
+      "thresholded_relu",
       "exp",
       "log",
       "sqrt",
@@ -163,6 +169,12 @@ struct SimpleOpTypeSetTeller : public Teller {
       "conv2d_fusion",
       "pool2d",
       "relu",
+      "elu",
+      "selu",
+      "softsign",
+      "softplus",
+      "stanh",
+      "thresholded_relu",
      "exp",
      "log",
      "sqrt",
@@ -261,30 +273,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     return false;
 
   for (auto& teller : tellers_) {
-    std::unordered_set<std::string> act_op_list = {"relu",
-                                                   "elu",
-                                                   "selu",
-                                                   "softsign",
-                                                   "softplus",
-                                                   "stanh",
-                                                   "thresholded_relu",
-                                                   "exp",
-                                                   "log",
-                                                   "sqrt",
-                                                   "abs",
-                                                   "sin",
-                                                   "cos",
-                                                   "tan",
-                                                   "sinh",
-                                                   "cosh",
-                                                   "asin",
-                                                   "acos",
-                                                   "atan",
-                                                   "asinh",
-                                                   "atanh",
-                                                   "ceil",
-                                                   "floor",
-                                                   "erf"};
+    std::unordered_set<std::string> act_op_list = {
+        "relu",     "relu6",  "sigmoid",
+        "elu",      "selu",   "softsign",
+        "softplus", "stanh",  "thresholded_relu",
+        "exp",      "log",    "sqrt",
+        "abs",      "sin",    "cos",
+        "tan",      "tanh",   "sinh",
+        "cosh",     "asin",   "acos",
+        "atan",     "asinh",  "atanh",
+        "ceil",     "floor",  "erf"};
     if (act_op_list.find(op_type) != act_op_list.end()) {
       auto* block = desc.Block();
       if (block == nullptr) {
......
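Beyond the diff itself, listing an op in the teller is what lets Paddle's TensorRT subgraph pass offload it at inference time. A rough usage sketch follows; the model file names are placeholders and the keyword arguments reflect the Python inference API as commonly documented, so treat the exact signature as an assumption rather than part of this change:

import paddle.inference as paddle_infer

# Placeholder model files; any exported Paddle inference model containing one of
# the activations above will do.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool on device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 28,
    max_batch_size=4,
    min_subgraph_size=1,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
predictor = paddle_infer.create_predictor(config)
# From here the TensorRT subgraph pass consults the teller above to decide which
# ops, including the newly added activations, are placed inside the TRT engine.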
@@ -30,43 +30,61 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
         def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
             if dims == 1:
-                return np.ones([32]).astype(np.float32)
+                return np.random.random([32]).astype(np.float32)
             elif dims == 2:
-                return np.ones([3, 32]).astype(np.float32)
+                return np.random.random([3, 32]).astype(np.float32)
             elif dims == 3:
-                return np.ones([3, 32, 32]).astype(np.float32)
+                return np.random.random([3, 32, 32]).astype(np.float32)
             else:
-                return np.ones([batch, 3, 32, 32]).astype(np.float32)
+                return np.random.random([batch, 3, 32, 32]).astype(np.float32)
 
         for dims in [1, 2, 3, 4]:
             for batch in [1, 4]:
-                for op_type in ["relu", "sigmoid", "tanh", "relu6"]:
-                    self.dims = dims
-                    dics = [{}]
-
-                    ops_config = [{
-                        "op_type": op_type,
-                        "op_inputs": {
-                            "X": ["input_data"]
-                        },
-                        "op_outputs": {
-                            "Out": ["output_data"]
-                        },
-                        "op_attrs": dics[0]
-                    }]
-                    ops = self.generate_op_config(ops_config)
-
-                    program_config = ProgramConfig(
-                        ops=ops,
-                        weights={},
-                        inputs={
-                            "input_data":
-                            TensorConfig(data_gen=partial(
-                                generate_input1, dims, batch, dics))
-                        },
-                        outputs=["output_data"])
-
-                    yield program_config
+                for op_type in [
+                        "relu", "sigmoid", "tanh", "relu6", "elu", "selu",
+                        "softsign", "stanh", "thresholded_relu", "softplus"
+                ]:
+                    # few samples to reduce time
+                    # for beta in [-0.2, 0.5, 0.67, 3]:
+                    #     for alpha in [-0.2, 0.5, 0.67, 3]:
+                    for beta in [0.67]:
+                        for alpha in [0.67]:
+                            self.dims = dims
+                            dics = [{}]
+                            if op_type == "elu":
+                                dics = [{"alpha": alpha}]
+                            if op_type == "selu":
+                                dics = [{"alpha": beta, "scale": alpha}]
+                            if op_type == "stanh":
+                                dics = [{"scale_a": beta, "scale_b": alpha}]
+                            if op_type == "thresholded_relu":
+                                dics = [{"threshold": alpha}]
+                            if op_type == "softplus":
+                                dics = [{"beta": beta}]
+
+                            ops_config = [{
+                                "op_type": op_type,
+                                "op_inputs": {
+                                    "X": ["input_data"]
+                                },
+                                "op_outputs": {
+                                    "Out": ["output_data"]
+                                },
+                                "op_attrs": dics[0]
+                            }]
+                            ops = self.generate_op_config(ops_config)
+
+                            program_config = ProgramConfig(
+                                ops=ops,
+                                weights={},
+                                inputs={
+                                    "input_data":
+                                    TensorConfig(data_gen=partial(
+                                        generate_input1, dims, batch, dics))
+                                },
+                                outputs=["output_data"])
+
+                            yield program_config
 
     def sample_predictor_configs(
             self, program_config) -> (paddle_infer.Config, List[int], float):
......
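For readers skimming the test, the attribute selection inside the nested loops boils down to the mapping below. This is a standalone restatement for clarity, not part of the test file; "alpha" and "beta" are the two swept values, both fixed at 0.67 in the committed test to keep runtime down.

def attrs_for(op_type, alpha=0.67, beta=0.67):
    # Restates the per-op attribute dictionaries chosen by the loop above.
    if op_type == "elu":
        return {"alpha": alpha}
    if op_type == "selu":
        return {"alpha": beta, "scale": alpha}
    if op_type == "stanh":
        return {"scale_a": beta, "scale_b": alpha}
    if op_type == "thresholded_relu":
        return {"threshold": alpha}
    if op_type == "softplus":
        return {"beta": beta}
    return {}  # relu / sigmoid / tanh / relu6 take no attributes here

print(attrs_for("stanh"))  # {'scale_a': 0.67, 'scale_b': 0.67}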