未验证 提交 fa846da5 编写于 作者: C Chen Weihang 提交者: GitHub

Append scale for static runner outputs (#24627)

* add scale for static runner outputs, test=develop

* fix import relation, test=develop

* remove len limit, test=develop
上级 0ec3a42e
...@@ -24,6 +24,7 @@ from .. import core ...@@ -24,6 +24,7 @@ from .. import core
from .. import framework from .. import framework
from .. import backward from .. import backward
from ..layers import nn
from .base import switch_to_static_graph from .base import switch_to_static_graph
from ... import compat as cpt from ... import compat as cpt
...@@ -359,8 +360,27 @@ class StaticModelRunner(layers.Layer): ...@@ -359,8 +360,27 @@ class StaticModelRunner(layers.Layer):
# NOTE: reverse feed vars # NOTE: reverse feed vars
self._input_names.reverse() self._input_names.reverse()
# Step 4. add scale for outputs
tmp_program = self._build_program_by_desc(program_desc)
self._append_scale_to_output(tmp_program)
return program_desc return program_desc
@switch_to_static_graph
def _append_scale_to_output(self, program):
# 1. append scale & save var
scale_output_vars = []
with framework.program_guard(program):
for i, out in enumerate(self._output_descs):
var = program.global_block().var(out.name())
var = nn.scale(
var, 1., name="static_model_runner/scale_{}".format(i))
scale_output_vars.append(var)
# 2. update output names & descs
for i, var in enumerate(scale_output_vars):
self._output_names[i] = var.name
self._output_descs[i] = var.desc
@switch_to_static_graph @switch_to_static_graph
def _append_backward_desc(self): def _append_backward_desc(self):
assert self._infer_program_desc is not None, "The StaticModelRunner not initialized properly." assert self._infer_program_desc is not None, "The StaticModelRunner not initialized properly."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册