未验证 提交 15d5f6b9 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

reshape_opteller (#41090)

fix_reshape: for paddle-trt
上级 f3270fc8
...@@ -1479,8 +1479,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1479,8 +1479,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
std::vector<int> shape = std::vector<int> shape =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("shape")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("shape"));
if (shape.size() >= nvinfer1::Dims::MAX_DIMS) return false; if (shape.size() >= nvinfer1::Dims::MAX_DIMS) return false;
if (!with_dynamic_shape && (shape[0] == -1 || shape.size() == 1)) if (!with_dynamic_shape) {
if (shape.size() == 1) {
return false;
}
if (shape[0] == 0) {
return true;
} else {
auto* block = desc.Block();
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1,
std::multiplies<int>());
int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1,
std::multiplies<int>());
if (input_num == shape_num) {
return true;
}
}
return false; return false;
}
} }
if (op_type == "clip") { if (op_type == "clip") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册