Commit a112fb87 authored by dingminghui, committed by jackzhang235

fix(interp): fix interp output shape error

add concat and scale to paddle_use_bridges.h
Parent df401f07
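
Note: the shape handling in the first hunk boils down to choosing out_h/out_w with the priority OutSize > scale > out_h/out_w attributes, and then resizing Out in NHWC order before the MLU output node is created. The standalone helper below is only a sketch of that logic; its name and signature are illustrative and not part of the patch.

// Sketch only: mirrors the out_h/out_w priority used in the diff below.
// x_dims is assumed to be 4-D NHWC: {n, h, w, c}.
#include <cstdint>
#include <vector>

std::vector<int64_t> InferInterpOutDims(const std::vector<int64_t>& x_dims,
                                        float scale,
                                        int out_h,
                                        int out_w,
                                        const int* out_size) {  // OutSize data, or nullptr
  if (scale > 0) {
    out_h = static_cast<int>(x_dims[1] * scale);
    out_w = static_cast<int>(x_dims[2] * scale);
    out_h = out_h > 0 ? out_h : -1;  // as in the patch: non-positive sizes stay -1
    out_w = out_w > 0 ? out_w : -1;
  }
  if (out_size != nullptr) {  // OutSize has the highest priority
    out_h = out_size[0];
    out_w = out_size[1];
  }
  // MLU tensors use the NHWC layout, so the output keeps that ordering.
  return {x_dims[0], out_h, out_w, x_dims[3]};
}
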
@@ -34,67 +34,49 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto out = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims();
CHECK_EQ(x_dims.size(), 4);
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
auto align_corners = op_info->GetAttr<bool>("align_corners");
/* int align_mode = */
/* op_info->HasAttr("align_mode") ? op_info->GetAttr<int>("align_mode") :
* 1; */
/* auto interp_method = op_info->GetAttr<std::string>("interp_method"); */
/* if (align_mode == 0 && !align_corners) { */
/* LOG(WARNING) << "[NPU] align_mode = 0 && " */
/* "align_corners = false isn't " */
/* "supported in CNML"; */
/* return FAILED; */
/* } */
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
auto out = scope->FindVar(out_var_name)->GetMutable<Tensor>();
/* int x_h, x_w; */
/* if (interp_method == "bilinear") { */
/* x_h = x_dims[1]; */
/* x_w = x_dims[2]; */
/* auto output_tensor = graph->AddNode( */
/* out_var_name, out->dims().Vectorize(), CNML_TENSOR, CNML_NHWC,
* graph->FPType()); */
/* } */
auto in_h = x_dims[1];
auto in_w = x_dims[2];
// Priority: SizeTensor > OutSize > Scale > scale > out_h/out_w
if (HasInputArg(op_info, scope, "SizeTensor")) {
LOG(ERROR) << "Not support SizeTensor input now";
CHECK(0);
} else {
if (HasInputArg(op_info, scope, "Scale")) {
LOG(ERROR) << "Not support Scale input now";
CHECK(0);
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
if (HasInputArg(op_info, scope, "OutSize")) {
LOG(ERROR) << "Not support OutSize input now";
CHECK(0);
}
}
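// Resize Out in NHWC order ({N, H, W, C}) before the output node is created below.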
out->Resize({x_dims[0], out_h, out_w, x_dims[3]});
auto x_h = x_dims[1];
auto x_w = x_dims[2];
auto output_tensor = graph->AddNode(out_var_name,
out->dims().Vectorize(),
CNML_TENSOR,
CNML_NHWC,
graph->FPType());
// Priority: OutSize > scale > out_h/out_w
if (scale > 0) {
out_h = static_cast<int>(x_h * scale);
out_w = static_cast<int>(x_w * scale);
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
// Update out_h and out_w and create out_size node if has OutSize
if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_name = op_info->Input("OutSize").front();
auto out_size = scope->FindVar(out_size_name)->GetMutable<Tensor>();
CHECK_EQ(out_size->numel(), 2);
CHECK(out_size->persistable());
auto out_size_data = out_size->mutable_data<int>();
// Update out_h and out_w if has OutSize
out_h = out_size_data[0];
out_w = out_size_data[1];
}
/* std::cout << "@@@scale: " << scale << "; in| w, h: " << x_w << ":" << x_h
* << "; out| w, h: " << out_w << ":" << out_h << std::endl; */
cnmlBaseOp_t interp_op;
/* if (interp_method == "bilinear") { */
/* cnmlInterpOpParam_t interp_param; */
......
@@ -26,3 +26,5 @@ USE_SUBGRAPH_BRIDGE(nearest_interp, kMLU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kMLU);
USE_SUBGRAPH_BRIDGE(transpose, kMLU);
USE_SUBGRAPH_BRIDGE(transpose2, kMLU);
USE_SUBGRAPH_BRIDGE(concat, kMLU);
USE_SUBGRAPH_BRIDGE(scale, kMLU);
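
For reference, each USE_SUBGRAPH_BRIDGE entry above assumes a matching REGISTER_SUBGRAPH_BRIDGE registration in the corresponding MLU bridge source; for the newly added entries that would look roughly like the line below (the converter symbol is an assumption, not taken from this patch).

REGISTER_SUBGRAPH_BRIDGE(concat, kMLU, paddle::lite::subgraph::mlu::ConcatConverter);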