提交 af3c5d92 编写于 作者: myq406450149's avatar myq406450149

resize assign_value output tensor. test=develop

上级 40006e19
......@@ -44,12 +44,15 @@ class Eliminator : public FuseBase {
auto* scope = assign_node->stmt()->op()->scope();
auto* op_info = assign_node->stmt()->op()->op_info();
auto shape = op_info->GetAttr<std::vector<int>>("shape");
std::vector<int64_t> out_shape;
for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]);
auto dtype = op_info->GetAttr<int>("dtype");
auto fp32_values = op_info->GetAttr<std::vector<float>>("fp32_values");
auto int32_values = op_info->GetAttr<std::vector<int>>("int32_values");
auto* out = matched.at("out");
auto* out_tensor = scope->FindVar(out->arg()->name)
->GetMutable<lite::Tensor>();
auto* out_tensor =
scope->FindVar(out->arg()->name)->GetMutable<lite::Tensor>();
out_tensor->Resize(out_shape);
if (dtype == static_cast<int>(lite::core::FluidType::INT32)) {
TensorFromVector(int32_values, out_tensor);
} else if (dtype == static_cast<int>(lite::core::FluidType::FP32)) {
......@@ -58,7 +61,6 @@ class Eliminator : public FuseBase {
LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype;
}
GraphSafeRemoveNodes(graph, {matched.at("assign_value")});
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册