Commit eb1d5c50 authored by zhaoying, committed by jackzhang235

(bugfix): add a cast op before the argmax op in the argmax converter, so the output dtype is int32

Parent 59bbf075
@@ -40,22 +40,55 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output_dims = output->dims().Vectorize();
   int axis = op_info->GetAttr<int64_t>("axis");
   if (axis < 0) {
     axis = axis + x_dims.size();
   }
   cnmlDimension_t argmax_mode = static_cast<cnmlDimension_t>(axis);
   auto mlu_output_dim = x->dims().Vectorize();
   // shape is NCHW, layout is NHWC
   mlu_output_dim[axis] = 1;
+  auto input_tensor = graph->GetNode(x_var_name);
+  // if use_fp16 and axis is not c, cast input datatype from fp16 to fp32, so
+  // output datatype is int32
+  bool cast_to_fp32 =
+      graph->FPType() == CNML_DATA_FLOAT16 && argmax_mode != CNML_DIM_C;
+  cnmlBaseOp_t cast_op{nullptr};
+  std::shared_ptr<MLUTensor> fp32_input_tensor;
+  if (cast_to_fp32) {
+    fp32_input_tensor = graph->AddNode(x_var_name + ".fp32",
+                                       x_dims,
+                                       CNML_TENSOR,
+                                       CNML_NCHW,
+                                       CNML_DATA_FLOAT32);
+    cnmlCreateCastOp(&cast_op,
+                     CNML_CAST_FLOAT16_TO_FLOAT32,
+                     input_tensor->mlu_tensor(),
+                     fp32_input_tensor->mlu_tensor());
+  }
   auto output_tensor = graph->AddNode(
-      out_var_name, mlu_output_dim, CNML_TENSOR, CNML_NCHW, graph->FPType());
+      out_var_name, mlu_output_dim, CNML_TENSOR, CNML_NCHW, CNML_DATA_INT32);
-  CHECK(graph->HasNode(x_var_name));
-  auto input_tensor = graph->GetNode(x_var_name);
   cnmlBaseOp_t argmax_op{nullptr};
+  // ======================= DEBUG INFO =====================
+  VLOG(6) << "x_var_name: " << x_var_name;
+  VLOG(6) << "out_var_name: " << out_var_name;
+  VLOG(6) << "x dims: " << x->dims();
+  VLOG(6) << "output dims: " << output->dims();
+  VLOG(6) << "axis: " << axis;
+  VLOG(6) << "cast_to_fp32: " << cast_to_fp32;
+  cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
+  cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
+  // ======================= DEBUG END =====================
   CNML_CALL(cnmlCreateArgmaxOp(&argmax_op,
                                argmax_mode,
-                               input_tensor->mlu_tensor(),
+                               cast_to_fp32 ? fp32_input_tensor->mlu_tensor()
+                                            : input_tensor->mlu_tensor(),
                                output_tensor->mlu_tensor()));
+  if (cast_to_fp32) {
+    graph->FuseOp(cast_op);
+  }
   graph->FuseOp(argmax_op);
   CNML_CALL(cnmlDestroyBaseOp(&argmax_op));
   return SUCCESS;
......
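For context, the rule this patch encodes (per the in-diff comment): when the graph runs in fp16 and the argmax axis is not C, the input is first cast to fp32 so that the argmax output comes out as int32, which is what Paddle expects. A minimal standalone sketch of that predicate follows; the enum stubs and the NeedsFp32Cast helper are hypothetical illustrations for this note, not part of the CNML API:

#include <cassert>

// Hypothetical stand-ins for the CNML enums referenced in the diff.
enum class FPType { kFloat16, kFloat32 };
enum class Dim { kN, kC, kH, kW };

// Mirrors the `cast_to_fp32` condition added by this commit: an fp16
// graph running argmax over any axis other than C needs its input cast
// to fp32 so that the argmax output is int32.
bool NeedsFp32Cast(FPType graph_fp_type, Dim argmax_mode) {
  return graph_fp_type == FPType::kFloat16 && argmax_mode != Dim::kC;
}

int main() {
  assert(NeedsFp32Cast(FPType::kFloat16, Dim::kH));   // cast inserted
  assert(!NeedsFp32Cast(FPType::kFloat16, Dim::kC));  // fp16 argmax over C needs no cast
  assert(!NeedsFp32Cast(FPType::kFloat32, Dim::kH));  // fp32 graph needs no cast
  return 0;
}

Note that the diff only creates and fuses the cast op when this condition holds, so fp32 graphs and C-axis argmax incur no extra op.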