提交 de6da102 编写于 作者: N nhzlx

add support for global pooling for trt

上级 115a6e64
...@@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter { ...@@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
bool global_pooling = boost::get<bool>(op_desc.GetAttr("global_pooling"));
std::string pool_type = std::string pool_type =
boost::get<std::string>(op_desc.GetAttr("pooling_type")); boost::get<std::string>(op_desc.GetAttr("pooling_type"));
std::vector<int> ksize = std::vector<int> ksize =
...@@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter { ...@@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter {
std::vector<int> paddings = std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
if (global_pooling == true) {
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nv_ksize.d[0] = input_shape.d[nbDims - 2];
nv_ksize.d[1] = input_shape.d[nbDims - 1];
}
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
......
...@@ -40,11 +40,13 @@ TEST(Pool2dOpConverter, main) { ...@@ -40,11 +40,13 @@ TEST(Pool2dOpConverter, main) {
std::vector<int> strides({2, 2}); std::vector<int> strides({2, 2});
std::vector<int> paddings({0, 0}); std::vector<int> paddings({0, 0});
std::string pooling_t = "max"; std::string pooling_t = "max";
bool global_pooling = false;
desc.SetAttr("pooling_type", pooling_t); desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize); desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides); desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings); desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling);
LOG(INFO) << "set OP"; LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册