diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h index d58736dcb1c64a81e84c570df755f86b611b4015..3ef70043d6a1520a0d583dabb7290259417706ce 100644 --- a/paddle/framework/feed_fetch_method.h +++ b/paddle/framework/feed_fetch_method.h @@ -21,7 +21,6 @@ limitations under the License. */ namespace paddle { namespace framework { -template void SetFeedVariable(const LoDTensor& input, const std::string& var_name, size_t index) { // If var_name Variable is not found in GlobalScope, a new variable will diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3455c82e6757b71a4c1b292da8fb170e1ace3074..94c9706f794794c9ba91cd71da660999e9cb93b9 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -460,10 +460,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); - //! FIXME: it is no need to `set_xxx_float/double/int` - m.def("set_feed_variable_float", framework::SetFeedVariable); - m.def("set_feed_variable_double", framework::SetFeedVariable); - m.def("set_feed_variable_int", framework::SetFeedVariable); + m.def("set_feed_variable", framework::SetFeedVariable); m.def("get_fetch_variable", framework::GetFetchVariable); BindProgramDesc(m); diff --git a/python/paddle/v2/framework/executor.py b/python/paddle/v2/framework/executor.py index 8da5daad993e9ceaff93b5271c30a3b260b7abcc..1adc10c233e465fef9f92551af354a92ab5069dd 100644 --- a/python/paddle/v2/framework/executor.py +++ b/python/paddle/v2/framework/executor.py @@ -38,8 +38,7 @@ class Executor(object): inputs={'X': [feed_var]}, outputs={'Out': [out]}, attrs={'col': i}) - # FIXME - core.set_feed_variable_float(feed[name], feed_var.name, i) + core.set_feed_variable(feed[name], feed_var.name, i) fetch_var = global_block.create_var( name=fetch_var_name, diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/framework/tests/test_feed_fetch_method.py index 47eedddcb6f47927ea3918d7f6c379c5710592c6..8b9b44440d6247c273240cf772682667dd911c8f 100644 --- a/python/paddle/v2/framework/tests/test_feed_fetch_method.py +++ b/python/paddle/v2/framework/tests/test_feed_fetch_method.py @@ -12,7 +12,7 @@ class TestFeedFetch(unittest.TestCase): input_tensor = core.LoDTensor([[0, 2, 4]]) input_tensor.set(input_array, place) - core.set_feed_variable_float(input_tensor, "feed", 0) + core.set_feed_variable(input_tensor, "feed", 0) output_tensor = core.get_fetch_variable("feed", 0)