提交 d3fecf01 编写于 作者: C chenjiaoAngel

fix ttfnet bug. test=develop

上级 7cc5a460
......@@ -169,5 +169,5 @@ REGISTER_LITE_KERNEL(where_index, kHost, kAny, kAny, whereindex, def)
.BindInput("Condition",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.Finalize();
......@@ -30,7 +30,14 @@ bool CompareOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims();
param_.Out->Resize(input_dims);
std::vector<int64_t> new_dims;
if (input_dims.size() == 2 && input_dims[1] == 1) {
new_dims.push_back(input_dims[0]);
param_.Out->Resize(new_dims);
} else {
param_.Out->Resize(input_dims);
}
// param_.Out->Resize(input_dims);
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册