提交 c366b879 编写于 作者: C cryoco

fix softmax converter diff when padding dim=1

上级 72746dc2
...@@ -33,26 +33,39 @@ class SoftMaxOpConverter : public OpConverter { ...@@ -33,26 +33,39 @@ class SoftMaxOpConverter : public OpConverter {
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::Dims input_shape = input1->getDimensions(); nvinfer1::Dims input_shape = input1->getDimensions();
int input_dims = input_shape.nbDims; int input_dims = input_shape.nbDims;
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis")); int axis = op_desc.HasAttr("axis")
? BOOST_GET_CONST(int, op_desc.GetAttr("axis"))
: -1;
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax, auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax,
*const_cast<nvinfer1::ITensor*>(input1)); *const_cast<nvinfer1::ITensor*>(input1));
uint32_t axes = std::max(0, input_dims - 3); uint32_t axes = std::max(0, input_dims - 3);
// TODO(cryoco): Poor workaround. Fix padded dims problem when TRT layers
// support Nd.
int padded_dims = 0;
int explicit_batch = 0;
if (engine_->with_dynamic_shape()) explicit_batch = 1;
for (int i = input_dims - 1; i > explicit_batch; i--) {
if (input_shape.d[i] == 1) {
padded_dims += 1;
} else {
break;
}
}
if (!engine_->with_dynamic_shape()) { if (!engine_->with_dynamic_shape()) {
if (axis == -1) { if (axis == -1) {
axes = input_dims - 1; axes = input_dims - 1 - padded_dims;
} else { } else {
axes = axis; axes = axis;
} }
layer->setAxes(1 << axes);
} else { } else {
if (axis == -1) { if (axis == -1) {
axes = input_dims - 1; axes = input_dims - 1 - padded_dims;
} else { } else {
axes = axis + 1; axes = axis + 1;
} }
layer->setAxes(1 << axes);
} }
layer->setAxes(1 << axes);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode);
......
...@@ -107,8 +107,11 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc, ...@@ -107,8 +107,11 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") { op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
std::vector<int> paddings = std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
std::string padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm")); std::string padding_algorithm = "EXPLICIT";
if (desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (paddings.size() > 2 || if (paddings.size() > 2 ||
(padding_algorithm == "SAME" && op_type != "pool2d")) (padding_algorithm == "SAME" && op_type != "pool2d"))
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册