提交 92662e16 编写于 作者: W Wilber 提交者: GitHub

support conv_transpose output_size attr test=develop (#2749)

* support conv_transpose output_size attr test=develop
上级 61697445
......@@ -102,6 +102,19 @@ bool ConvTransposeOpLite::InferShape() const {
paddings[i * 2 + 1],
param_.strides[i]));
}
if (!param_.output_size.empty()) {
for (size_t i = 0; i < param_.output_size.size(); ++i) {
CHECK_LT(param_.output_size[i], output_shape[i + 2] + param_.strides[i])
<< "set output_size error, the output_size should less than "
<< output_shape[i + 2] + param_.strides[i] << ", but the value is "
<< param_.output_size[i];
CHECK_GE(param_.output_size[i], output_shape[i + 2])
<< "set output_size error, the output_size should greater than or "
<< "equal to " << output_shape[i + 2] << ", but the value is "
<< param_.output_size[i];
output_shape[i + 2] = param_.output_size[i];
}
}
// Set output dims
param_.output->Resize(lite::DDim(output_shape));
......@@ -157,6 +170,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
}
if (op_desc.HasAttr("output_size")) {
param_.output_size = op_desc.GetAttr<std::vector<int>>("output_size");
}
return true;
}
......
......@@ -288,6 +288,8 @@ struct ConvParam {
ActivationParam activation_param;
// support var_length or not
bool var_length{false};
// only used in conv_transpose.
std::vector<int> output_size;
// for int8
WITH_INT8_CONFIG
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册