From 1ba0d70b6d7b0f7c68c1a02b321472525b427dd7 Mon Sep 17 00:00:00 2001 From: Liufang Sang Date: Tue, 15 Oct 2019 13:02:20 +0800 Subject: [PATCH] [PaddleSlim] [Cherry pick ] make slim reader to support dataloader (#20625) * cherry-pick refine slim reader to support dataloader test=release/1.6 * cherry-pick add test for reader test=release/1.6 * cherry-pick rm checkpoint path test=release/1.6 * cherry-pick fix details test=release/1.6 * cherry-pick fix details test=release/1.6 --- .../fluid/contrib/slim/core/compressor.py | 10 +- .../fluid/contrib/slim/graph/executor.py | 5 +- .../fluid/contrib/slim/tests/test_reader.py | 126 ++++++++++++++++++ 3 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/tests/test_reader.py diff --git a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py index 5eb8d970add..21737c76251 100644 --- a/python/paddle/fluid/contrib/slim/core/compressor.py +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -20,7 +20,7 @@ from .... import profiler from .... import scope_guard from ....data_feeder import DataFeeder from ....log_helper import get_logger -from ....reader import PyReader +from ....reader import DataLoaderBase from ..graph import * from .config import ConfigFactory import numpy as np @@ -194,8 +194,8 @@ class Context(object): reader = cached_reader(reader, sampled_rate, self.cache_path, cached_id) - if isinstance(reader, Variable) or (isinstance(reader, PyReader) and - (not reader.iterable)): + if isinstance(reader, Variable) or ( + isinstance(reader, DataLoaderBase) and (not reader.iterable)): reader.start() try: while True: @@ -488,8 +488,8 @@ class Compressor(object): build_strategy=build_strategy) if isinstance(context.train_reader, Variable) or ( - isinstance(context.train_reader, - PyReader) and (not context.train_reader.iterable)): + isinstance(context.train_reader, DataLoaderBase) and + (not context.train_reader.iterable)): context.train_reader.start() try: while True: diff --git a/python/paddle/fluid/contrib/slim/graph/executor.py b/python/paddle/fluid/contrib/slim/graph/executor.py index 74de141b06b..1573d3aa1ce 100644 --- a/python/paddle/fluid/contrib/slim/graph/executor.py +++ b/python/paddle/fluid/contrib/slim/graph/executor.py @@ -42,7 +42,10 @@ class SlimGraphExecutor(object): """ assert isinstance(graph, GraphWrapper) feed = None - if data is not None: + if data is not None and isinstance(data[0], dict): + # return list = False + feed = data + elif data is not None: feeder = DataFeeder( feed_list=list(graph.in_nodes.values()), place=self.place, diff --git a/python/paddle/fluid/contrib/slim/tests/test_reader.py b/python/paddle/fluid/contrib/slim/tests/test_reader.py new file mode 100644 index 00000000000..8c054493d75 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_reader.py @@ -0,0 +1,126 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import os +import shutil +import paddle +import unittest +import paddle.fluid as fluid +from mobilenet import MobileNet +from paddle.fluid.contrib.slim.core import Compressor +from paddle.fluid.contrib.slim.graph import GraphWrapper + + +class TestReader(unittest.TestCase): + """ + Test API of quantization strategy. + """ + + def set_train_reader(self, image, label, place): + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + return train_reader + + def set_val_reader(self, image, label, place): + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + return val_reader + + def set_feed_list(self, image, label): + return [('img', image.name), ('label', label.name)] + + def quan(self, config_file): + if os.path.exists('./checkpoints_quan'): + shutil.rmtree('./checkpoints_quan') + + if not fluid.core.is_compiled_with_cuda(): + return + class_dim = 10 + image_shape = [1, 28, 28] + + train_program = fluid.Program() + startup_program = fluid.Program() + val_program = fluid.Program() + + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + out = MobileNet(name='quan').net(input=image, + class_dim=class_dim) + print("out: {}".format(out.name)) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + val_program = train_program.clone(for_test=False) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + val_reader = self.set_val_reader(image, label, place) + + val_feed_list = self.set_feed_list(image, label) + val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', + acc_top5.name)] + + train_reader = self.set_train_reader(image, label, place) + train_feed_list = self.set_feed_list(image, label) + train_fetch_list = [('loss', avg_cost.name)] + + com_pass = Compressor( + place, + fluid.global_scope(), + train_program, + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_reader=val_reader, + eval_feed_list=val_feed_list, + eval_fetch_list=val_fetch_list, + train_optimizer=optimizer) + com_pass.config(config_file) + eval_graph = com_pass.run() + + +class TestReader1(TestReader): + def set_train_reader(self, image, label, place): + loader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], capacity=16, iterable=True) + loader.set_sample_generator( + paddle.dataset.mnist.train(), batch_size=128, places=place) + return loader + + def set_val_reader(self, image, label, place): + loader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], capacity=16, iterable=True) + loader.set_sample_generator( + paddle.dataset.mnist.test(), batch_size=128, places=place) + return loader + + def test_compression(self): + self.quan("./quantization/compress_1.yaml") + + +if __name__ == '__main__': + unittest.main() -- GitLab