From 42f2dd4041c4bc584194cb55470190f8233be70f Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 19 Oct 2017 17:03:49 -0700 Subject: [PATCH] Unify `set_feed_variable` to one method (#4949) --- paddle/framework/feed_fetch_method.h | 1 - paddle/pybind/pybind.cc | 5 +---- python/paddle/v2/framework/executor.py | 3 +-- python/paddle/v2/framework/tests/test_feed_fetch_method.py | 2 +- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h index d58736dcb1c..3ef70043d6a 100644 --- a/paddle/framework/feed_fetch_method.h +++ b/paddle/framework/feed_fetch_method.h @@ -21,7 +21,6 @@ limitations under the License. */ namespace paddle { namespace framework { -template void SetFeedVariable(const LoDTensor& input, const std::string& var_name, size_t index) { // If var_name Variable is not found in GlobalScope, a new variable will diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3455c82e675..94c9706f794 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -460,10 +460,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); - //! FIXME: it is no need to `set_xxx_float/double/int` - m.def("set_feed_variable_float", framework::SetFeedVariable); - m.def("set_feed_variable_double", framework::SetFeedVariable); - m.def("set_feed_variable_int", framework::SetFeedVariable); + m.def("set_feed_variable", framework::SetFeedVariable); m.def("get_fetch_variable", framework::GetFetchVariable); BindProgramDesc(m); diff --git a/python/paddle/v2/framework/executor.py b/python/paddle/v2/framework/executor.py index 8da5daad993..1adc10c233e 100644 --- a/python/paddle/v2/framework/executor.py +++ b/python/paddle/v2/framework/executor.py @@ -38,8 +38,7 @@ class Executor(object): inputs={'X': [feed_var]}, outputs={'Out': [out]}, attrs={'col': i}) - # FIXME - core.set_feed_variable_float(feed[name], feed_var.name, i) + core.set_feed_variable(feed[name], feed_var.name, i) fetch_var = global_block.create_var( name=fetch_var_name, diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/framework/tests/test_feed_fetch_method.py index 47eedddcb6f..8b9b44440d6 100644 --- a/python/paddle/v2/framework/tests/test_feed_fetch_method.py +++ b/python/paddle/v2/framework/tests/test_feed_fetch_method.py @@ -12,7 +12,7 @@ class TestFeedFetch(unittest.TestCase): input_tensor = core.LoDTensor([[0, 2, 4]]) input_tensor.set(input_array, place) - core.set_feed_variable_float(input_tensor, "feed", 0) + core.set_feed_variable(input_tensor, "feed", 0) output_tensor = core.get_fetch_variable("feed", 0) -- GitLab