From 8f2f977345da8bc8f12f32c8d64aa1a0b41fda66 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 26 Mar 2020 18:31:14 +0800 Subject: [PATCH] support feeding scalar when runing program , test=develop (#23214) * support feed_python_builtin, test=develop * add test, test=develop * support CompiledProgram, test=develop * support fluid.data, test=develop * fix ci problems, test=develop * follow comments, test=develop --- python/paddle/fluid/executor.py | 39 ++--- .../unittests/test_executor_feed_scalar.py | 134 ++++++++++++++++++ 2 files changed, 157 insertions(+), 16 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_executor_feed_scalar.py diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index dd22000e4a..1ada5f74f8 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -366,7 +366,7 @@ def _get_program_cache_key(feed, fetch_list): return str(feed_var_names + fetch_var_names) -def _as_lodtensor(data, place): +def _as_lodtensor(data, place, dtype=None): """ Convert numpy.ndarray to Tensor, its only support Tensor without LoD information. For higher dimensional sequence data, please use LoDTensor directly. @@ -381,6 +381,8 @@ def _as_lodtensor(data, place): Args: data(numpy.ndarray): a instance of array + data(core.Place): the place of created tensor + dtype(core.VarDesc.VarType): the expected data type of created tensor Returns: LoDTensor @@ -391,6 +393,15 @@ def _as_lodtensor(data, place): ndarray to LoDTensor. Please convert data to LoDTensor \ directly before feeding the data.\ ") + + #NOTE(zhiqiu): convert python builtin ,like float and int, to numpy array + if not isinstance(data, np.ndarray): + if np.isscalar(data): + assert dtype is not None, 'dtype should be given when casting python scalar to tensor' + dtype = convert_dtype(dtype) if isinstance( + dtype, core.VarDesc.VarType) else dtype + data = np.array([data]).astype(dtype) + # single tensor case tensor = core.LoDTensor() tensor.set(data, place) @@ -568,9 +579,9 @@ class Executor(object): if op.desc.type() == 'feed': feed_target_name = op.desc.output('Out')[0] cur_feed = feed[feed_target_name] - if not isinstance(cur_feed, core.LoDTensor): - cur_feed = _as_lodtensor(cur_feed, self.place) var = global_block.var(feed_target_name) + if not isinstance(cur_feed, core.LoDTensor): + cur_feed = _as_lodtensor(cur_feed, self.place, var.dtype) check_feed_shape_type(var, cur_feed) idx = op.desc.attr('col') core.set_feed_variable(scope, cur_feed, feed_var_name, idx) @@ -625,16 +636,14 @@ class Executor(object): feed_tensor_dict = dict() for feed_name in feed: feed_tensor = feed[feed_name] + var = global_block.var(feed_name) if need_check_feed else None if not isinstance(feed_tensor, core.LoDTensor): - feed_tensor = core.LoDTensor() # always set to CPU place, since the tensor need to be split # it is fast in CPU - assert isinstance( feed[feed_name], np.ndarray ), \ - "The input({}) should be numpy.array, but not {}.".format( - feed_name, type(feed[feed_name])) - feed_tensor.set(feed[feed_name], core.CPUPlace()) + feed_tensor = _as_lodtensor(feed[feed_name], + core.CPUPlace(), var.dtype + if var else None) if need_check_feed: - var = global_block.var(feed_name) check_feed_shape_type(var, feed_tensor, exe.device_count()) feed_tensor_dict[feed_name] = feed_tensor @@ -648,15 +657,13 @@ class Executor(object): res_dict = dict() for feed_name in each: tensor = each[feed_name] + var = global_block.var( + feed_name) if need_check_feed else None if not isinstance(tensor, core.LoDTensor): - tmp = core.LoDTensor() - assert isinstance(each[feed_name], np.ndarray), \ - "The input({}) should be numpy.array, but not {}.".format( - feed_name, type(each[feed_name])) - tmp.set(tensor, program._places[i]) - tensor = tmp + tensor = _as_lodtensor(each[feed_name], + program._places[i], var.dtype + if var else None) if need_check_feed: - var = global_block.var(feed_name) check_feed_shape_type(var, tensor) res_dict[feed_name] = tensor res.append(res_dict) diff --git a/python/paddle/fluid/tests/unittests/test_executor_feed_scalar.py b/python/paddle/fluid/tests/unittests/test_executor_feed_scalar.py new file mode 100644 index 0000000000..562f3066ef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_executor_feed_scalar.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy +import paddle.fluid.core as core +import paddle.fluid as fluid + + +class TestExecutor(unittest.TestCase): + def net(self): + lr = fluid.data(name="lr", shape=[1], dtype='float32') + x = fluid.data(name="x", shape=[None, 1], dtype='float32') + y = fluid.data(name="y", shape=[None, 1], dtype='float32') + y_predict = fluid.layers.fc(input=x, size=1, act=None) + + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + + opt = fluid.optimizer.Adam(learning_rate=lr) + opt.minimize(avg_cost) + + return lr, avg_cost + + def test_program_feed_float(self): + main_program = fluid.Program() + startup_program = fluid.Program() + scope = fluid.Scope() + with fluid.program_guard(main_program, startup_program): + with fluid.scope_guard(scope): + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + lr, cost = self.net() + exe.run(startup_program) + train_data = numpy.array( + [[1.0], [2.0], [3.0], [4.0]]).astype('float32') + y_true = numpy.array( + [[2.0], [4.0], [6.0], [8.0]]).astype('float32') + a = 0.01 + _lr, _ = exe.run(feed={'x': train_data, + 'y': y_true, + 'lr': a}, + fetch_list=[lr, cost], + return_numpy=False) + self.assertEqual(_lr._dtype(), lr.dtype) + self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32) + self.assertEqual(type(a), float) + + def test_program_feed_int(self): + main_program = fluid.Program() + startup_program = fluid.Program() + scope = fluid.Scope() + with fluid.program_guard(main_program, startup_program): + with fluid.scope_guard(scope): + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + lr, cost = self.net() + exe.run(startup_program) + train_data = numpy.array( + [[1.0], [2.0], [3.0], [4.0]]).astype('float32') + y_true = numpy.array( + [[2.0], [4.0], [6.0], [8.0]]).astype('float32') + a = 0 + _lr, _ = exe.run(feed={'x': train_data, + 'y': y_true, + 'lr': a}, + fetch_list=[lr, cost], + return_numpy=False) + self.assertEqual(_lr._dtype(), lr.dtype) + self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32) + self.assertEqual(type(a), int) + + def test_compiled_program_feed_scalar(self): + main_program = fluid.Program() + startup_program = fluid.Program() + scope = fluid.Scope() + with fluid.program_guard(main_program, startup_program): + with fluid.scope_guard(scope): + lr, cost = self.net() + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + exe.run(startup_program) + compiled_prog = fluid.CompiledProgram( + main_program).with_data_parallel(loss_name=cost.name) + train_data = numpy.array( + [[1.0], [2.0], [3.0], [4.0]]).astype('float32') + y_true = numpy.array( + [[2.0], [4.0], [6.0], [8.0]]).astype('float32') + a = 0.01 + _lr, _ = exe.run(compiled_prog, + feed={'x': train_data, + 'y': y_true, + 'lr': a}, + fetch_list=[lr, cost], + return_numpy=False) + self.assertEqual(_lr._dtype(), lr.dtype) + self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32) + self.assertEqual(type(a), float) + + +class TestAsLodTensor(unittest.TestCase): + def test_as_lodtensor_int32(self): + cpu = fluid.CPUPlace() + tensor = fluid.executor._as_lodtensor(1.0, cpu, + fluid.core.VarDesc.VarType.INT32) + self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.INT32) + + def test_as_lodtensor_fp64(self): + cpu = fluid.CPUPlace() + tensor = fluid.executor._as_lodtensor(1, cpu, + fluid.core.VarDesc.VarType.FP64) + self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64) + + def test_as_lodtensor_error(self): + cpu = fluid.CPUPlace() + self.assertRaises(AssertionError, fluid.executor._as_lodtensor, 1, cpu) + + +if __name__ == '__main__': + unittest.main() -- GitLab