提交 42f2dd40 编写于 作者: Y Yu Yang 提交者: GitHub

Unify `set_feed_variable` to one method (#4949)

上级 c532b967
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
void SetFeedVariable(const LoDTensor& input, const std::string& var_name, void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
size_t index) { size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will // If var_name Variable is not found in GlobalScope, a new variable will
......
...@@ -460,10 +460,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -460,10 +460,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU); m.def("is_compile_gpu", IsCompileGPU);
//! FIXME: it is no need to `set_xxx_float/double/int` m.def("set_feed_variable", framework::SetFeedVariable);
m.def("set_feed_variable_float", framework::SetFeedVariable<float>);
m.def("set_feed_variable_double", framework::SetFeedVariable<double>);
m.def("set_feed_variable_int", framework::SetFeedVariable<int>);
m.def("get_fetch_variable", framework::GetFetchVariable); m.def("get_fetch_variable", framework::GetFetchVariable);
BindProgramDesc(m); BindProgramDesc(m);
......
...@@ -38,8 +38,7 @@ class Executor(object): ...@@ -38,8 +38,7 @@ class Executor(object):
inputs={'X': [feed_var]}, inputs={'X': [feed_var]},
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs={'col': i}) attrs={'col': i})
# FIXME core.set_feed_variable(feed[name], feed_var.name, i)
core.set_feed_variable_float(feed[name], feed_var.name, i)
fetch_var = global_block.create_var( fetch_var = global_block.create_var(
name=fetch_var_name, name=fetch_var_name,
......
...@@ -12,7 +12,7 @@ class TestFeedFetch(unittest.TestCase): ...@@ -12,7 +12,7 @@ class TestFeedFetch(unittest.TestCase):
input_tensor = core.LoDTensor([[0, 2, 4]]) input_tensor = core.LoDTensor([[0, 2, 4]])
input_tensor.set(input_array, place) input_tensor.set(input_array, place)
core.set_feed_variable_float(input_tensor, "feed", 0) core.set_feed_variable(input_tensor, "feed", 0)
output_tensor = core.get_fetch_variable("feed", 0) output_tensor = core.get_fetch_variable("feed", 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册