Commit 9c15846a authored by chonwhite

attention works

Parent c6d82e0e
@@ -33,7 +33,7 @@ class Debugger {
   void registerOutput(std::string op_type, zynqmp::Tensor* tensor) {
     if (op_config[op_type]) {
-      tensor->saveToFile(op_type, true);
+      // tensor->saveToFile(op_type, true);
     }
   }
...
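Note: this hunk keeps the per-op gate but disables the actual tensor dump. A minimal sketch of the pattern, with a hypothetical Dumper standing in for the Debugger class (only op_config, registerOutput, and the commented-out saveToFile call come from the diff; everything else is assumed):

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for the Debugger above: dumping is still gated per
// op type by op_config, but the file write itself is disabled.
class Dumper {
 public:
  std::map<std::string, bool> op_config;

  void registerOutput(const std::string& op_type) {
    if (op_config[op_type]) {
      // tensor->saveToFile(op_type, true);  // disabled by this commit
      std::cout << "output of " << op_type << " intercepted, dump off\n";
    }
  }
};

int main() {
  Dumper d;
  d.op_config["attention"] = true;
  d.registerOutput("attention");  // gate fires, but nothing is written
}
```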
@@ -60,6 +60,7 @@ class ConvPE : public PE {
     if (param_.filter->shape().width() == 1 &&
         param_.filter->shape().height() == 1) {  // NOLINT
+      // use_cpu_ = true;
     }
     if (!use_cpu_) {  // NOLINT
       // param_.filter->releaseData();
...
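Note: leaving use_cpu_ = true commented out means 1x1 convolutions stay on the FPGA ConvPE instead of falling back to a CPU path. A hedged sketch of that dispatch decision (the Shape stand-in below is hypothetical; only the width/height test and the use_cpu_ flag come from the diff):

```cpp
#include <cstdio>

// Stand-in for the zynqmp filter shape; only width()/height() matter here.
struct Shape {
  int w, h;
  int width() const { return w; }
  int height() const { return h; }
};

bool UseCpuFor(const Shape& filter) {
  bool use_cpu = false;
  if (filter.width() == 1 && filter.height() == 1) {
    // use_cpu = true;  // kept disabled: 1x1 conv stays on the FPGA PE
  }
  return use_cpu;
}

int main() {
  std::printf("1x1 falls back to cpu? %d\n", UseCpuFor(Shape{1, 1}));  // 0
}
```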
@@ -140,10 +140,12 @@ void SSAGraph::Build(const Program &program,
       arg_node->AsArg(name, node_storage_.size() - 1);
       arg_update_node_map_[name] = arg_node;
     }
+    /*
     if (var_types.count(name) && !arg_node->arg()->type) {
       arg_node->arg()->type = LiteType::GetTensorTy(
           TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
     }
+    */
     if (is_weights(name)) arg_node->AsArg().is_weight = true;
     CHECK(arg_node->IsRoleSet());
     DirectedLink(arg_node, op_node);
@@ -153,10 +155,12 @@ void SSAGraph::Build(const Program &program,
       auto *arg_node = &node_storage_.back();
       arg_node->AsArg(name, node_storage_.size() - 1);
       arg_update_node_map_[name] = arg_node;
+      /*
       if (var_types.count(name) && !arg_node->arg()->type) {
         arg_node->arg()->type = LiteType::GetTensorTy(
             TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
       }
+      */
       if (is_weights(name)) arg_node->AsArg().is_weight = true;
       CHECK(arg_node->IsRoleSet());
...
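Note: both hunks comment out the block that stamped untyped argument nodes with a provisional LiteType of unknown target and layout; with it disabled, arg()->type stays null. Presumably later type-propagation passes fill the type in instead, though the commit does not say so. A miniature of the two behaviors, with a string standing in for LiteType (everything below is illustrative):

```cpp
#include <memory>
#include <string>

// Illustrative miniature: a string stands in for LiteType, and eager_typing
// models whether the now-commented-out block runs.
struct ArgNode {
  std::string name;
  std::shared_ptr<std::string> type;  // null = "not yet inferred"
};

void BuildArg(ArgNode* node, bool eager_typing) {
  if (eager_typing && !node->type) {
    // The disabled path: stamp a placeholder type with unknown target and
    // layout, mirroring TARGET(kUnk) / DATALAYOUT(kUnk) in the diff.
    node->type = std::make_shared<std::string>("tensor<kUnk,?,kUnk>");
  }
  // With the block commented out, node->type stays null; a later pass is
  // presumably expected to infer it.
}

int main() {
  ArgNode n{"attention_out", nullptr};
  BuildArg(&n, /*eager_typing=*/false);  // mirrors the commit
  return n.type ? 1 : 0;                 // returns 0: still untyped
}
```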
@@ -119,8 +119,10 @@ void TypeTargetTransformPass::AddIoCopyInst(
   // to.target()
   // The precision and layout should be equal to from.precision(),
   // from.layout()
+#ifndef LITE_WITH_FPGA
   io_copy_output_arg->AsArg().type =
       LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
+#endif
   auto* io_copy_inst = graph->NewInstructNode();
   bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
...
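Note: the new guard skips the "io_copy changes only the target" typing rule on FPGA builds. A plausible reading (an assumption, not stated in the commit) is that FPGA io_copy kernels may also change precision or layout, so pinning from.precision()/from.layout() would mis-describe their output. A compilable sketch of the gating pattern:

```cpp
#include <cstdio>

// Toy type triple; the real pass works on LiteType instances.
struct Type {
  const char* target;
  const char* precision;
  const char* layout;
};

Type IoCopyOutputType(const Type& from, const char* to_target) {
#ifndef LITE_WITH_FPGA
  // Rule from the diff: an io_copy changes only the target.
  return Type{to_target, from.precision, from.layout};
#else
  // FPGA build: leave precision/layout open for the FPGA kernel to decide
  // (an assumption about why the guard was added).
  return Type{to_target, "unset", "unset"};
#endif
}

int main() {
  Type t = IoCopyOutputType(Type{"kARM", "kFloat", "kNCHW"}, "kFPGA");
  std::printf("%s/%s/%s\n", t.target, t.precision, t.layout);
}
```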
@@ -107,7 +107,9 @@ class Optimizer {
          "runtime_context_assign_pass",
          "argument_type_display_pass",
+#ifndef LITE_WITH_FPGA
          "memory_optimize_pass",
+#endif
          "npu_subgraph_pass",
          "xpu_subgraph_pass"}};
     RunPasses(passes_local);
...
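Note: memory_optimize_pass is now excluded from FPGA builds. One plausible reason (an assumption) is that cross-op buffer reuse conflicts with tensors the FPGA kernels hold by address. A minimal sketch of the conditional pass list (pass names are from the diff; the helper is hypothetical):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper mirroring the gated pass list in this hunk.
std::vector<std::string> BuildPassList() {
  return {
      "runtime_context_assign_pass",
      "argument_type_display_pass",
#ifndef LITE_WITH_FPGA
      "memory_optimize_pass",  // skipped on FPGA builds
#endif
      "npu_subgraph_pass",
      "xpu_subgraph_pass"};
}

int main() {
  for (const auto& p : BuildPassList()) std::printf("%s\n", p.c_str());
}
```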
@@ -133,7 +133,7 @@ REGISTER_LITE_KERNEL(fill_constant,
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindInput("ShapeTensorList",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 REGISTER_LITE_KERNEL(fill_constant_batch_size_like,
...
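Note: dropping PRECISION(kAny) makes the Out binding fall back to GetTensorTy's default precision, which is kFloat in Paddle-Lite if memory serves (worth verifying against the Lite type-system headers), so the kernel now advertises a concrete rather than wildcard output precision. A toy mirror of that default-argument narrowing:

```cpp
#include <cstdio>

enum class Precision { kAny, kFloat, kInt32 };

// Hypothetical mirror of LiteType::GetTensorTy's defaulted precision
// parameter (assumed to default to kFloat; verify in the Lite sources).
Precision DeclaredPrecision(Precision p = Precision::kFloat) { return p; }

int main() {
  // Before: BindOutput("Out", ... PRECISION(kAny)) -> wildcard precision.
  auto before = DeclaredPrecision(Precision::kAny);
  // After:  BindOutput("Out", ...) -> the default applies, so the kernel
  // declares a concrete output precision the type system can check.
  auto after = DeclaredPrecision();
  std::printf("before=%d after=%d\n", static_cast<int>(before),
              static_cast<int>(after));
}
```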