diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 5ce2aa1fc4d0b275b502af0f97e4a0f83e85de5b..8d9f8c34899d70871c71fac2af2ca5d612ec1d08 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -61,8 +61,8 @@ class ParallelExecutor(object): main_program=test_program, share_vars_from=train_exe) - train_loss, = train_exe.run([loss.name], feed_dict=feed_dict) - test_loss, = test_exe.run([loss.name], feed_dict=feed_dict) + train_loss, = train_exe.run([loss.name], feed=feed_dict) + test_loss, = test_exe.run([loss.name], feed=feed_dict) """ self._places = [] @@ -123,22 +123,23 @@ class ParallelExecutor(object): allow_op_delay) self.scope = scope - def run(self, fetch_list, feed_dict={}): + def run(self, fetch_list, feed={}, feed_dict={}): """ :param fetch_list: A list of variable names that will be fetched. - :param feed_dict: A dict mapping for feed variable name to LoDTensor + :param feed: A dict mapping for feed variable name to LoDTensor or numpy array. :return: fetched value list. """ - if not isinstance(feed_dict, dict): - raise TypeError("feed_dict should be a dict") + feed = feed_dict + if not isinstance(feed, dict): + raise TypeError("feed should be a dict") feed_tensor_dict = {} - for i, feed_name in enumerate(feed_dict): - feed_tensor = feed_dict[feed_name] + for i, feed_name in enumerate(feed): + feed_tensor = feed[feed_name] if not isinstance(feed_tensor, core.LoDTensor): feed_tensor = core.LoDTensor() - feed_tensor.set(feed_dict[feed_name], self._act_places[0]) + feed_tensor.set(feed[feed_name], self._act_places[0]) feed_tensor_dict[feed_name] = feed_tensor fetch_var_name = '@FETCHED_VAR_NAME@'