Unverified · Commit 1a1d596b · Authored by: · Committed by: GitHub

[inference]add trt act layer convert (#43504)

* add activation layer
Parent ce704ee9
......@@ -49,14 +49,30 @@ class ActivationOpConverter : public OpConverter {
<< "convert a fluid Activation op to tensorrt activation layer whose "
"type is "
<< op_type_;
- const nvinfer1::ITensor* input_tensor =
- engine_->GetITensor(op_desc.Input("X")[0]);
auto* input_tensor = engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_);
- nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
- engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
- op_pair->second);
nvinfer1::IActivationLayer* layer = nullptr;
if (op_type_ == "softplus") {
const float beta = op_desc.HasAttr("beta")
? BOOST_GET_CONST(float, op_desc.GetAttr("beta"))
: 1.0f;
const float threshold =
op_desc.HasAttr("threshold")
? BOOST_GET_CONST(float, op_desc.GetAttr("threshold"))
: 20.0f;
auto* layer_clip = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *input_tensor, nvinfer1::ActivationType::kCLIP);
layer_clip->setAlpha(-3.40282e+038);
layer_clip->setBeta(threshold / beta);
layer = TRT_ENGINE_ADD_LAYER(engine_, Activation,
*layer_clip->getOutput(0), op_pair->second);
layer->setAlpha(1.0f / beta);
layer->setBeta(beta);
} else {
layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input_tensor,
op_pair->second);
}
#if IS_TRT_VERSION_GE(5130)
// max(alpha, min(beta, x))
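The softplus branch above composes two TensorRT layers: a kCLIP whose lower bound of -3.40282e+38 is effectively -FLT_MAX (so only the upper bound threshold / beta is active), followed by a kSOFTPLUS layer. A quick numpy sketch of the intended math, with the Paddle softplus and TensorRT kCLIP/kSOFTPLUS formulas assumed from their documented definitions and the helper names purely illustrative: Paddle computes (1 / beta) * log(1 + exp(beta * x)) and switches to x once beta * x exceeds threshold, while TensorRT kSOFTPLUS computes alpha * log(exp(beta * x) + 1), so alpha = 1 / beta reproduces Paddle's scaling in the unsaturated region and the clip keeps exp(beta * x) bounded.

import numpy as np

def paddle_softplus_ref(x, beta=1.0, threshold=20.0):
    # Reference softplus as documented for the Paddle op (assumed definition).
    out = np.log1p(np.exp(beta * x)) / beta
    return np.where(beta * x > threshold, x, out)

def converted_softplus(x, beta=1.0, threshold=20.0):
    # Mirrors the two layers added above: clip(x, max=threshold / beta),
    # then kSOFTPLUS with alpha = 1 / beta and beta = beta.
    clipped = np.minimum(x, threshold / beta)
    return (1.0 / beta) * np.log(np.exp(beta * clipped) + 1.0)

x = np.linspace(-5.0, 5.0, 11).astype(np.float32)
print(np.abs(paddle_softplus_ref(x, beta=2.0) - converted_softplus(x, beta=2.0)).max())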
......@@ -64,6 +80,41 @@ class ActivationOpConverter : public OpConverter {
layer->setAlpha(0.);
layer->setBeta(6.);
}
if (op_type_ == "elu") {
const float alpha = op_desc.HasAttr("alpha")
? BOOST_GET_CONST(float, op_desc.GetAttr("alpha"))
: 1.0f;
layer->setAlpha(alpha);
}
if (op_type_ == "selu") {
const float alpha = op_desc.HasAttr("alpha")
? BOOST_GET_CONST(float, op_desc.GetAttr("alpha"))
: 1.6732632423543772848170429916717;
const float scale = op_desc.HasAttr("scale")
? BOOST_GET_CONST(float, op_desc.GetAttr("scale"))
: 1.0507009873554804934193349852946;
layer->setAlpha(alpha);
layer->setBeta(scale);
}
if (op_type_ == "stanh") {
const float scale_a =
op_desc.HasAttr("scale_a")
? BOOST_GET_CONST(float, op_desc.GetAttr("scale_a"))
: 0.67f;
const float scale_b =
op_desc.HasAttr("scale_b")
? BOOST_GET_CONST(float, op_desc.GetAttr("scale_b"))
: 1.7159f;
layer->setAlpha(scale_b);
layer->setBeta(scale_a);
}
if (op_type_ == "thresholded_relu") {
const float threshold =
op_desc.HasAttr("threshold")
? BOOST_GET_CONST(float, op_desc.GetAttr("threshold"))
: 1.0f;
layer->setAlpha(threshold);
}
#endif
auto output_name = op_desc.Output("Out")[0];
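For the other newly supported activations, the converter only has to route Paddle attributes into the alpha/beta slots of the matching TensorRT activation type. The formulas below are assumed from the TensorRT ActivationType documentation and the helper names are illustrative: kELU uses alpha * (exp(x) - 1) for negative inputs, kSELU uses beta * x for positive inputs and beta * alpha * (exp(x) - 1) otherwise (so Paddle's scale lands in beta), kSCALED_TANH computes alpha * tanh(beta * x) (hence scale_b goes to alpha and scale_a to beta), and kTHRESHOLDED_RELU zeroes anything not greater than alpha. A small numpy check of the two less obvious mappings:

import numpy as np

def trt_scaled_tanh(x, alpha, beta):
    # TensorRT kSCALED_TANH: alpha * tanh(beta * x) (assumed formula).
    return alpha * np.tanh(beta * x)

def paddle_stanh_ref(x, scale_a=0.67, scale_b=1.7159):
    # Paddle stanh: scale_b * tanh(scale_a * x).
    return scale_b * np.tanh(scale_a * x)

def trt_thresholded_relu(x, alpha):
    # TensorRT kTHRESHOLDED_RELU: x if x > alpha else 0 (assumed formula).
    return np.where(x > alpha, x, 0.0)

x = np.linspace(-3.0, 3.0, 13).astype(np.float32)
# stanh: setAlpha(scale_b), setBeta(scale_a), exactly as in the converter above.
assert np.allclose(paddle_stanh_ref(x), trt_scaled_tanh(x, alpha=1.7159, beta=0.67))
# thresholded_relu: setAlpha(threshold).
assert np.allclose(np.where(x > 1.0, x, 0.0), trt_thresholded_relu(x, alpha=1.0))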
......@@ -83,8 +134,13 @@ const std::unordered_map<std::string, nvinfer1::ActivationType>
{"tanh", nvinfer1::ActivationType::kTANH},
#if IS_TRT_VERSION_GE(5130)
{"relu6", nvinfer1::ActivationType::kCLIP},
{"elu", nvinfer1::ActivationType::kELU},
{"selu", nvinfer1::ActivationType::kSELU},
{"softsign", nvinfer1::ActivationType::kSOFTSIGN},
{"softplus", nvinfer1::ActivationType::kSOFTPLUS},
{"stanh", nvinfer1::ActivationType::kSCALED_TANH},
{"thresholded_relu", nvinfer1::ActivationType::kTHRESHOLDED_RELU}};
#endif
};
class ReluOpConverter : public ActivationOpConverter {
public:
......@@ -101,11 +157,43 @@ class TanhOpConverter : public ActivationOpConverter {
TanhOpConverter() { op_type_ = "tanh"; }
};
#if IS_TRT_VERSION_GE(5130)
class Relu6OpConverter : public ActivationOpConverter {
public:
Relu6OpConverter() { op_type_ = "relu6"; }
};
class EluOpConverter : public ActivationOpConverter {
public:
EluOpConverter() { op_type_ = "elu"; }
};
class SeluOpConverter : public ActivationOpConverter {
public:
SeluOpConverter() { op_type_ = "selu"; }
};
class SoftsignOpConverter : public ActivationOpConverter {
public:
SoftsignOpConverter() { op_type_ = "softsign"; }
};
class SoftplusOpConverter : public ActivationOpConverter {
public:
SoftplusOpConverter() { op_type_ = "softplus"; }
};
class STanhOpConverter : public ActivationOpConverter {
public:
STanhOpConverter() { op_type_ = "stanh"; }
};
class ThresholdedReluOpConverter : public ActivationOpConverter {
public:
ThresholdedReluOpConverter() { op_type_ = "thresholded_relu"; }
};
#endif
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -113,4 +201,12 @@ class Relu6OpConverter : public ActivationOpConverter {
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
#if IS_TRT_VERSION_GE(5130)
REGISTER_TRT_OP_CONVERTER(relu6, Relu6OpConverter);
REGISTER_TRT_OP_CONVERTER(elu, EluOpConverter);
REGISTER_TRT_OP_CONVERTER(selu, SeluOpConverter);
REGISTER_TRT_OP_CONVERTER(softsign, SoftsignOpConverter);
REGISTER_TRT_OP_CONVERTER(softplus, SoftplusOpConverter);
REGISTER_TRT_OP_CONVERTER(stanh, STanhOpConverter);
REGISTER_TRT_OP_CONVERTER(thresholded_relu, ThresholdedReluOpConverter);
#endif
......@@ -73,6 +73,12 @@ struct SimpleOpTypeSetTeller : public Teller {
"conv2d_fusion",
"pool2d",
"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
......@@ -163,6 +169,12 @@ struct SimpleOpTypeSetTeller : public Teller {
"conv2d_fusion",
"pool2d",
"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
......@@ -261,30 +273,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
for (auto& teller : tellers_) {
- std::unordered_set<std::string> act_op_list = {"relu",
- "elu",
- "selu",
- "softsign",
- "softplus",
- "stanh",
- "thresholded_relu",
- "exp",
- "log",
- "sqrt",
- "abs",
- "sin",
- "cos",
- "tan",
- "sinh",
- "cosh",
- "asin",
- "acos",
- "atan",
- "asinh",
- "atanh",
- "ceil",
- "floor",
- "erf"};
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
"elu", "selu", "softsign",
"softplus", "stanh", "thresholded_relu",
"exp", "log", "sqrt",
"abs", "sin", "cos",
"tan", "tanh", "sinh",
"cosh", "asin", "acos",
"atan", "asinh", "atanh",
"ceil", "floor", "erf"};
if (act_op_list.find(op_type) != act_op_list.end()) {
auto* block = desc.Block();
if (block == nullptr) {
......
......@@ -30,43 +30,61 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
if dims == 1:
- return np.ones([32]).astype(np.float32)
return np.random.random([32]).astype(np.float32)
elif dims == 2:
- return np.ones([3, 32]).astype(np.float32)
return np.random.random([3, 32]).astype(np.float32)
elif dims == 3:
- return np.ones([3, 32, 32]).astype(np.float32)
return np.random.random([3, 32, 32]).astype(np.float32)
else:
- return np.ones([batch, 3, 32, 32]).astype(np.float32)
return np.random.random([batch, 3, 32, 32]).astype(np.float32)
for dims in [1, 2, 3, 4]:
for batch in [1, 4]:
for op_type in ["relu", "sigmoid", "tanh", "relu6"]:
self.dims = dims
dics = [{}]
ops_config = [{
"op_type": op_type,
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": {
"Out": ["output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data":
TensorConfig(data_gen=partial(
generate_input1, dims, batch, dics))
},
outputs=["output_data"])
yield program_config
for op_type in [
"relu", "sigmoid", "tanh", "relu6", "elu", "selu",
"softsign", "stanh", "thresholded_relu", "softplus"
]:
# few samples to reduce time
#for beta in [-0.2, 0.5, 0.67, 3]:
# for alpha in [-0.2, 0.5, 0.67, 3]:
for beta in [0.67]:
for alpha in [0.67]:
self.dims = dims
dics = [{}]
if op_type == "elu":
dics = [{"alpha": alpha}]
if op_type == "selu":
dics = [{"alpha": beta, "scale": alpha}]
if op_type == "stanh":
dics = [{"scale_a": beta, "scale_b": alpha}]
if op_type == "thresholded_relu":
dics = [{"threshold": alpha}]
if op_type == "softplus":
dics = [{"beta": beta}]
ops_config = [{
"op_type": op_type,
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": {
"Out": ["output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data":
TensorConfig(data_gen=partial(
generate_input1, dims, batch, dics))
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
......
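Each iteration of the loops above yields a single-op program for the auto-scan harness. As a hand-expanded illustration (not part of the test itself), the case dims == 2, op_type == "stanh", beta == 0.67, alpha == 0.67 is equivalent to:

ops_config = [{
    "op_type": "stanh",
    "op_inputs": {"X": ["input_data"]},
    "op_outputs": {"Out": ["output_data"]},
    "op_attrs": {"scale_a": 0.67, "scale_b": 0.67},
}]
# input_data is filled by generate_input1 with a random float32 tensor of shape [3, 32].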