Unverified commit d84a30b2, authored by Chen Weihang, committed by GitHub

[cherry-pick] append scale to static runner and remove loader place (#24649)

* Append scale for static runner outputs (#24627)

* add scale for static runner outputs, test=develop

* fix import relation, test=develop

* remove len limit, test=develop

* remove imperative data loader place limit, test=develop (#24641)
Parent 62047d30
@@ -24,6 +24,7 @@ from .. import core
 from .. import framework
 from .. import backward
+from ..layers import nn
 from .base import switch_to_static_graph
 from ... import compat as cpt

@@ -359,8 +360,27 @@ class StaticModelRunner(layers.Layer):
         # NOTE: reverse feed vars
         self._input_names.reverse()

+        # Step 4. add scale for outputs
+        tmp_program = self._build_program_by_desc(program_desc)
+        self._append_scale_to_output(tmp_program)
+
         return program_desc

+    @switch_to_static_graph
+    def _append_scale_to_output(self, program):
+        # 1. append scale & save var
+        scale_output_vars = []
+        with framework.program_guard(program):
+            for i, out in enumerate(self._output_descs):
+                var = program.global_block().var(out.name())
+                var = nn.scale(
+                    var, 1., name="static_model_runner/scale_{}".format(i))
+                scale_output_vars.append(var)
+        # 2. update output names & descs
+        for i, var in enumerate(scale_output_vars):
+            self._output_names[i] = var.name
+            self._output_descs[i] = var.desc
+
     @switch_to_static_graph
     def _append_backward_desc(self):
         assert self._infer_program_desc is not None, "The StaticModelRunner not initialized properly."
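Note on the hunk above: the appended scale op uses a factor of 1.0, so it does not change the output values; it acts as an identity that materializes one new, explicitly named output variable per original output, and the runner then points its output names and descs at those new variables. Below is a minimal sketch (not part of this commit, written against the Paddle 1.x fluid API used here) showing that a scale-by-1.0 layer produces a fresh variable in the program; the program and names are purely illustrative.

import paddle.fluid as fluid

# Build a tiny static program and append an identity scale, analogous to what
# _append_scale_to_output does for each StaticModelRunner output.
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[3], dtype='float32')
    # scale=1.0 keeps the values unchanged but creates a new output variable
    y = fluid.layers.scale(x, scale=1., name='demo/scale_0')

print(x.name, y.name)  # two distinct variables in the same block
print([op.type for op in main_prog.global_block().ops])  # expect 'scale' among the ops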
@@ -18,7 +18,7 @@ import six
 import numpy as np
 import threading
 import paddle
-from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
+from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places, _current_expected_place
 from .executor import global_scope
 from .data_feeder import DataFeeder, BatchedTensorProvider
 from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler

@@ -671,12 +671,12 @@ class DygraphGeneratorLoader(DataLoaderBase):
         if not iterable:
             logging.warning(
-                "Please NOTE: dygraph can support iterable mode only. Change to iterable mode."
+                "Please NOTE: imperative mode can support iterable mode only. Change to iterable mode."
             )
             self._iterable = True
         if not return_list:
             logging.warning(
-                "Please NOTE: dygraph can support return as list only. Change to return as list."
+                "Please NOTE: imperative mode can support return as list only. Change to return as list."
             )
             self._return_list = True

@@ -941,10 +941,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
     def set_batch_generator(self, reader, places=None):
         self._batch_reader = reader
-        assert places is not None, "Places cannot be None when DataLoader is iterable"
+        if places is None:
+            places = _current_expected_place()
         self._places = _convert_places(places)
         assert len(self._places) == 1, \
-            "Number of places must be 1 in dygraph mode"
+            "Number of places must be 1 in imperative mode"
         return self
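With the set_batch_generator change above, an imperative-mode DataLoader no longer requires an explicit places argument: when places is None it falls back to _current_expected_place(), i.e. the device of the current dygraph context, and still enforces a single place. A minimal usage sketch follows (not part of this commit; the random batch generator is a stand-in for a real reader, and it assumes the Paddle 1.x fluid API used in this PR).

import numpy as np
import paddle.fluid as fluid

def random_batch_generator(batch_num, batch_size):
    # Stand-in generator yielding (image, label) numpy batches.
    def __reader__():
        for _ in range(batch_num):
            image = np.random.random([batch_size, 784]).astype('float32')
            label = np.random.randint(0, 10, [batch_size, 1]).astype('int64')
            yield image, label
    return __reader__

with fluid.dygraph.guard():
    loader = fluid.io.DataLoader.from_generator(capacity=5)
    # No 'places' argument: the loader now defaults to the current expected place.
    loader.set_batch_generator(random_batch_generator(batch_num=4, batch_size=32))
    for image, label in loader():
        print(image.shape, label.shape)  # [32, 784] and [32, 1]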
@@ -41,6 +41,14 @@ class TestDygraphDataLoader(unittest.TestCase):
         self.epoch_num = 1
         self.capacity = 5

+    def iter_loader_data(self, loader):
+        for _ in range(self.epoch_num):
+            for image, label in loader():
+                relu = fluid.layers.relu(image)
+                self.assertEqual(image.shape, [self.batch_size, 784])
+                self.assertEqual(label.shape, [self.batch_size, 1])
+                self.assertEqual(relu.shape, [self.batch_size, 784])
+
     def test_single_process_loader(self):
         with fluid.dygraph.guard():
             loader = fluid.io.DataLoader.from_generator(

@@ -49,12 +57,7 @@
                 sample_generator_creator(self.batch_size, self.batch_num),
                 batch_size=self.batch_size,
                 places=fluid.CPUPlace())
-            for _ in range(self.epoch_num):
-                for image, label in loader():
-                    relu = fluid.layers.relu(image)
-                    self.assertEqual(image.shape, [self.batch_size, 784])
-                    self.assertEqual(label.shape, [self.batch_size, 1])
-                    self.assertEqual(relu.shape, [self.batch_size, 784])
+            self.iter_loader_data(loader)

     def test_multi_process_loader(self):
         with fluid.dygraph.guard():

@@ -64,12 +67,15 @@
                 sample_generator_creator(self.batch_size, self.batch_num),
                 batch_size=self.batch_size,
                 places=fluid.CPUPlace())
-            for _ in range(self.epoch_num):
-                for image, label in loader():
-                    relu = fluid.layers.relu(image)
-                    self.assertEqual(image.shape, [self.batch_size, 784])
-                    self.assertEqual(label.shape, [self.batch_size, 1])
-                    self.assertEqual(relu.shape, [self.batch_size, 784])
+            self.iter_loader_data(loader)
+
+    def test_generator_no_places(self):
+        with fluid.dygraph.guard():
+            loader = fluid.io.DataLoader.from_generator(capacity=self.capacity)
+            loader.set_sample_generator(
+                sample_generator_creator(self.batch_size, self.batch_num),
+                batch_size=self.batch_size)
+            self.iter_loader_data(loader)

 if __name__ == '__main__':