Commit c366b879 authored by: cryoco

Fix softmax converter output diff when the input shape is padded with dims of size 1

Parent 72746dc2
@@ -33,26 +33,39 @@ class SoftMaxOpConverter : public OpConverter {
    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
    nvinfer1::Dims input_shape = input1->getDimensions();
    int input_dims = input_shape.nbDims;
    int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
    int axis = op_desc.HasAttr("axis")
                   ? BOOST_GET_CONST(int, op_desc.GetAttr("axis"))
                   : -1;
    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax,
                                       *const_cast<nvinfer1::ITensor*>(input1));
    uint32_t axes = std::max(0, input_dims - 3);
    // TODO(cryoco): Poor workaround. Fix padded dims problem when TRT layers
    // support Nd.
    int padded_dims = 0;
    int explicit_batch = 0;
    if (engine_->with_dynamic_shape()) explicit_batch = 1;
    for (int i = input_dims - 1; i > explicit_batch; i--) {
      if (input_shape.d[i] == 1) {
        padded_dims += 1;
      } else {
        break;
      }
    }
    if (!engine_->with_dynamic_shape()) {
      if (axis == -1) {
        axes = input_dims - 1;
        axes = input_dims - 1 - padded_dims;
      } else {
        axes = axis;
      }
      layer->setAxes(1 << axes);
    } else {
      if (axis == -1) {
        axes = input_dims - 1;
        axes = input_dims - 1 - padded_dims;
      } else {
        axes = axis + 1;
      }
      layer->setAxes(1 << axes);
    }
    layer->setAxes(1 << axes);
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode);
......
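
For reference, a minimal standalone sketch (not part of this commit) that replays the static-shape branch of the axis selection above. The shape {1000, 1, 1} is a hypothetical TensorRT tensor shape whose last two dimensions of size 1 exist only as padding; before this fix, axis -1 resolved to one of those size-1 dims, so the softmax output was identically 1 and diverged from the framework result.

// Standalone sketch, not Paddle code: mirrors the converter's padded-dims
// countback for an implicit-batch (static shape) engine.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {1000, 1, 1};  // hypothetical padded tensor shape
  int input_dims = static_cast<int>(shape.size());
  int axis = -1;           // softmax over the last "real" dimension
  int explicit_batch = 0;  // 0: implicit-batch (static shape) engine

  // Count trailing dimensions of size 1, mirroring the converter's loop.
  int padded_dims = 0;
  for (int i = input_dims - 1; i > explicit_batch; i--) {
    if (shape[i] == 1) {
      padded_dims += 1;
    } else {
      break;
    }
  }

  // Old behaviour: axis -1 resolved to a padded size-1 dim, so every softmax
  // output element was 1. New behaviour: skip the padded dims and land on the
  // 1000-element dimension.
  int old_axes = input_dims - 1;
  int new_axes = (axis == -1) ? input_dims - 1 - padded_dims : axis;
  std::printf("padded_dims=%d old_axes=%d new_axes=%d setAxes bitmask=%u\n",
              padded_dims, old_axes, new_axes, 1u << new_axes);
  return 0;
}
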
@@ -107,8 +107,11 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
        op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
      std::vector<int> paddings =
          BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
      std::string padding_algorithm =
          BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
      std::string padding_algorithm = "EXPLICIT";
      if (desc.HasAttr("padding_algorithm"))
        padding_algorithm =
            BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
      if (paddings.size() > 2 ||
          (padding_algorithm == "SAME" && op_type != "pool2d"))
        return false;
......
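
The op_teller change above reads padding_algorithm only when the attribute exists and otherwise falls back to "EXPLICIT". Below is a minimal sketch of the same default-when-absent pattern, with a plain std::map standing in for framework::OpDesc (the real code uses HasAttr/GetAttr with BOOST_GET_CONST); the helper name GetPaddingAlgorithm is made up for illustration.

// Toy illustration, not Paddle code: default a missing attribute to "EXPLICIT".
#include <iostream>
#include <map>
#include <string>

std::string GetPaddingAlgorithm(
    const std::map<std::string, std::string>& attrs) {
  std::string padding_algorithm = "EXPLICIT";  // default when attr is absent
  auto it = attrs.find("padding_algorithm");
  if (it != attrs.end()) padding_algorithm = it->second;
  return padding_algorithm;
}

int main() {
  std::map<std::string, std::string> with_attr = {
      {"padding_algorithm", "SAME"}};
  std::map<std::string, std::string> without_attr;
  std::cout << GetPaddingAlgorithm(with_attr) << "\n";     // SAME
  std::cout << GetPaddingAlgorithm(without_attr) << "\n";  // EXPLICIT
  return 0;
}
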