未验证 提交 98ab2433 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] upgrade prelu op (#48528)

* add prelu
上级 c1cadcca
...@@ -31,8 +31,8 @@ class PReluOpConverter : public OpConverter { ...@@ -31,8 +31,8 @@ class PReluOpConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
size_t input_num = op_desc.Input("X").size();
auto* input = engine_->GetITensor(op_desc.Input("X")[0]); auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto input_dims = input->getDimensions();
// Get attrs // Get attrs
std::string mode = PADDLE_GET_CONST(std::string, op_desc.GetAttr("mode")); std::string mode = PADDLE_GET_CONST(std::string, op_desc.GetAttr("mode"));
std::string data_format = "NCHW"; std::string data_format = "NCHW";
...@@ -40,50 +40,87 @@ class PReluOpConverter : public OpConverter { ...@@ -40,50 +40,87 @@ class PReluOpConverter : public OpConverter {
data_format = data_format =
PADDLE_GET_CONST(std::string, op_desc.GetAttr("data_format")); PADDLE_GET_CONST(std::string, op_desc.GetAttr("data_format"));
} }
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
auto* alpha_tensor = alpha_var->GetMutable<phi::DenseTensor>();
auto alpha_weight = auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
engine_->GetFp32TrtWeight(op_desc.Input("Alpha")[0], *alpha_tensor); auto* alpha_weight = alpha_var->GetMutable<phi::DenseTensor>();
auto w_dims = alpha_weight->dims();
auto alpha_data =
engine_->GetFp32TrtWeight(op_desc.Input("Alpha")[0], *alpha_weight);
platform::CPUPlace cpu_place; nvinfer1::Dims trt_w_dims;
trt_w_dims.nbDims = w_dims.size();
for (int i = 0; i < trt_w_dims.nbDims; i++) {
trt_w_dims.d[i] = w_dims[i];
}
nvinfer1::ILayer* layer = nullptr; // The `element` or `channel` mode contains the batch using static shape.
if (engine_->with_dynamic_shape()) { if ((mode == "element" || mode == "channel") &&
plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic( !engine_->with_dynamic_shape() &&
static_cast<const float*>(alpha_weight.get().values), (trt_w_dims.nbDims - 1 == input_dims.nbDims)) {
alpha_tensor->numel(), trt_w_dims.nbDims--;
mode, for (int i = 0; i < trt_w_dims.nbDims; i++) {
data_format); trt_w_dims.d[i] = trt_w_dims.d[i + 1];
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
#if IS_TRT_VERSION_GE(7000)
nvinfer1::Dims dims;
dims.nbDims = 0;
// jump batch dim
for (int i = 1; i < alpha_tensor->dims().size(); i++) {
dims.d[dims.nbDims++] = alpha_tensor->dims()[i];
}
for (; dims.nbDims < input->getDimensions().nbDims; dims.nbDims++) {
dims.d[dims.nbDims] = 1;
} }
}
auto alpha_layer = nvinfer1::ITensor* alpha_tensor =
TRT_ENGINE_ADD_LAYER(engine_, Constant, dims, alpha_weight.get()); TRT_ENGINE_ADD_LAYER(engine_, Constant, trt_w_dims, alpha_data.get())
auto alpha_layer_output = alpha_layer->getOutput(0); ->getOutput(0);
layer = TRT_ENGINE_ADD_LAYER( auto alpha_dims = alpha_tensor->getDimensions();
engine_, ParametricReLU, *input, *alpha_layer_output); nvinfer1::ITensor* real_alpha_tensor = alpha_tensor;
#else if (alpha_dims.nbDims != input_dims.nbDims) {
plugin::PReluPlugin* plugin = new plugin::PReluPlugin( auto* reshape_layer =
static_cast<const float*>(alpha_weight.get().values), TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *alpha_tensor);
alpha_tensor->numel(), int c = alpha_dims.d[0];
mode, if (engine_->with_dynamic_shape()) {
data_format); std::vector<nvinfer1::ITensor*> itensors;
layer = engine_->AddPlugin(&input, input_num, plugin); auto* n_tensor = Add1DConstantLayer(1);
#endif auto* c_tensor = Add1DConstantLayer(c);
nvinfer1::ITensor* hw_tensor = nullptr;
nvinfer1::ITensor* shape_tensor = nullptr;
if (input_dims.nbDims - 2 > 0) {
hw_tensor = Add1DConstantLayer(
std::vector<int32_t>(input_dims.nbDims - 2, 1));
}
if (data_format == "NCHW") {
if (hw_tensor != nullptr) {
shape_tensor = Concat(
std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor, hw_tensor});
} else {
shape_tensor =
Concat(std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor});
}
} else {
if (hw_tensor != nullptr) {
shape_tensor = Concat(
std::vector<nvinfer1::ITensor*>{n_tensor, hw_tensor, c_tensor});
} else {
shape_tensor =
Concat(std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor});
}
}
reshape_layer->setInput(1, *shape_tensor);
} else {
nvinfer1::Dims reshape_dim;
reshape_dim.nbDims = input_dims.nbDims;
std::fill(reshape_dim.d, reshape_dim.d + input_dims.nbDims, 1);
if (data_format == "NCHW") {
reshape_dim.d[0] = c;
} else if (data_format == "NHWC") {
reshape_dim.d[input_dims.nbDims - 1] = c;
}
reshape_layer->setReshapeDimensions(reshape_dim);
}
real_alpha_tensor = reshape_layer->getOutput(0);
} }
nvinfer1::ILayer* layer = nullptr;
layer = TRT_ENGINE_ADD_LAYER(
engine_, ParametricReLU, *input, *real_alpha_tensor);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "prelu", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "prelu", {output_name}, test_mode);
} }
......
...@@ -49,22 +49,22 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): ...@@ -49,22 +49,22 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
if dim1 != 0: if dim1 != 0:
shape.append(dim1) shape.append(dim1)
if dim2 != 0: if dim2 != 0:
shape.append(1) shape.append(dim2)
if dim3 != 0: if dim3 != 0:
shape.append(1) shape.append(dim3)
return np.random.random(size=shape).astype(np.float32) return np.random.random(size=shape[1]).astype(np.float32)
elif ( elif (
attrs[0]["mode"] == "channel" attrs[0]["mode"] == "channel"
and attrs[0]["data_format"] == "NHWC" and attrs[0]["data_format"] == "NHWC"
): ):
shape = [1] shape = [1]
if dim1 != 0: if dim1 != 0:
shape.append(1) shape.append(dim1)
if dim2 != 0: if dim2 != 0:
shape.append(1) shape.append(dim2)
if dim3 != 0: if dim3 != 0:
shape.append(dim3) shape.append(dim3)
return np.random.random(size=shape).astype(np.float32) return np.random.random(size=shape[-1]).astype(np.float32)
elif attrs[0]["mode"] == "element": elif attrs[0]["mode"] == "element":
shape = [1] shape = [1]
if dim1 != 0: if dim1 != 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册