diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index a472ae07455dd1b10688a4b033358bba70d8f34f..94a621491fa868e9bb9f0ba3a3c8f72e96d24317 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -102,6 +102,19 @@ bool ConvTransposeOpLite::InferShape() const { paddings[i * 2 + 1], param_.strides[i])); } + if (!param_.output_size.empty()) { + for (size_t i = 0; i < param_.output_size.size(); ++i) { + CHECK_LT(param_.output_size[i], output_shape[i + 2] + param_.strides[i]) + << "set output_size error, the output_size should less than " + << output_shape[i + 2] + param_.strides[i] << ", but the value is " + << param_.output_size[i]; + CHECK_GE(param_.output_size[i], output_shape[i + 2]) + << "set output_size error, the output_size should greater than or " + << "equal to " << output_shape[i + 2] << ", but the value is " + << param_.output_size[i]; + output_shape[i + 2] = param_.output_size[i]; + } + } // Set output dims param_.output->Resize(lite::DDim(output_shape)); @@ -157,6 +170,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc, if (op_desc.HasAttr("fuse_relu")) { param_.fuse_relu = op_desc.GetAttr("fuse_relu"); } + if (op_desc.HasAttr("output_size")) { + param_.output_size = op_desc.GetAttr>("output_size"); + } return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 769c8329f460280303089458e29668c1afa4c5a4..9aba4a1f3e7b96abedb2f4d835f99072bf4b7f4e 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -288,6 +288,8 @@ struct ConvParam { ActivationParam activation_param; // support var_length or not bool var_length{false}; + // only used in conv_transpose. + std::vector output_size; // for int8 WITH_INT8_CONFIG };