From 0c0a8a94fd485cf5464268088705d6c91cd71225 Mon Sep 17 00:00:00 2001 From: Wilber Date: Fri, 10 Jan 2020 14:36:47 +0800 Subject: [PATCH] support conv_transpose output_size attr test=develop (#2749) * support conv_transpose output_size attr test=develop --- lite/operators/conv_transpose_op.cc | 16 ++++++++++++++++ lite/operators/op_params.h | 2 ++ 2 files changed, 18 insertions(+) diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index a472ae0745..94a621491f 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -102,6 +102,19 @@ bool ConvTransposeOpLite::InferShape() const { paddings[i * 2 + 1], param_.strides[i])); } + if (!param_.output_size.empty()) { + for (size_t i = 0; i < param_.output_size.size(); ++i) { + CHECK_LT(param_.output_size[i], output_shape[i + 2] + param_.strides[i]) + << "set output_size error, the output_size should less than " + << output_shape[i + 2] + param_.strides[i] << ", but the value is " + << param_.output_size[i]; + CHECK_GE(param_.output_size[i], output_shape[i + 2]) + << "set output_size error, the output_size should greater than or " + << "equal to " << output_shape[i + 2] << ", but the value is " + << param_.output_size[i]; + output_shape[i + 2] = param_.output_size[i]; + } + } // Set output dims param_.output->Resize(lite::DDim(output_shape)); @@ -157,6 +170,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc, if (op_desc.HasAttr("fuse_relu")) { param_.fuse_relu = op_desc.GetAttr("fuse_relu"); } + if (op_desc.HasAttr("output_size")) { + param_.output_size = op_desc.GetAttr>("output_size"); + } return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 769c8329f4..9aba4a1f3e 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -288,6 +288,8 @@ struct ConvParam { ActivationParam activation_param; // support var_length or not bool var_length{false}; + // only used in conv_transpose. + std::vector output_size; // for int8 WITH_INT8_CONFIG }; -- GitLab