diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h new file mode 100644 index 0000000000000000000000000000000000000000..826d180bfc5445224a8d9292f06eeb58d9a46b29 --- /dev/null +++ b/paddle/framework/feed_fetch_method.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/scope.h" +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +template +void SetFeedVariable(const LoDTensor& input, const std::string& var_name, + size_t index) { + // If var_name Variable is not found in GlobalScope, a new variable will + // be created. + Variable* g_feed_value = GetGlobalScope().Var(var_name); + auto& feed_inputs = + *(g_feed_value->GetMutable>()); + if (index >= feed_inputs.size()) { + feed_inputs.resize(index + 1); + } + // shared data with input tensor + feed_inputs[index].ShareDataWith(input); + // set lod + feed_inputs[index].set_lod(input.lod()); +} + +LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) { + // Since we want to fetch LodTensor from a variable, the variable must + // be created alreadly. + Variable* g_fetch_value = GetGlobalScope().FindVar(var_name); + auto& fetch_outputs = + *(g_fetch_value->GetMutable>()); + PADDLE_ENFORCE_LT(index, fetch_outputs.size()); + return fetch_outputs[index]; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index afc80b25b185509739edcd0d06817aa1d507a8ec..008a9441db408e2db99ec86af645344f376d0b8f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/executor.h" +#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/tensor_array.h" #include "paddle/operators/cond_op.h" @@ -403,6 +404,10 @@ All parameter, weight, gradient are variables in Paddle. m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); + m.def("set_feed_variable_float", framework::SetFeedVariable); + m.def("set_feed_variable_double", framework::SetFeedVariable); + m.def("set_feed_variable_int", framework::SetFeedVariable); + m.def("get_fetch_variable", framework::GetFetchVariable); BindProgramDesc(m); BindBlockDesc(m); diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/framework/tests/test_feed_fetch_method.py new file mode 100644 index 0000000000000000000000000000000000000000..47eedddcb6f47927ea3918d7f6c379c5710592c6 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_feed_fetch_method.py @@ -0,0 +1,30 @@ +import paddle.v2.framework.core as core +import unittest +import numpy as np + + +class TestFeedFetch(unittest.TestCase): + def test_feed_fetch(self): + place = core.CPUPlace() + input_array = np.ones((4, 4, 6)).astype("float32") + input_array[0, 0, 0] = 3 + input_array[3, 3, 5] = 10 + input_tensor = core.LoDTensor([[0, 2, 4]]) + input_tensor.set(input_array, place) + + core.set_feed_variable_float(input_tensor, "feed", 0) + + output_tensor = core.get_fetch_variable("feed", 0) + + output_lod = output_tensor.lod() + self.assertEqual(0, output_lod[0][0]) + self.assertEqual(2, output_lod[0][1]) + self.assertEqual(4, output_lod[0][2]) + + output_array = np.array(output_tensor) + self.assertEqual(3, output_array[0, 0, 0]) + self.assertEqual(10, output_array[3, 3, 5]) + + +if __name__ == "__main__": + unittest.main()