From 7c0facd1958faa2e48d5764b3274459124e51c53 Mon Sep 17 00:00:00 2001
From: qijun
Date: Sun, 15 Oct 2017 22:01:51 -0700
Subject: [PATCH] init

---
 paddle/framework/scope.h                       | 62 +++++++++++++++++++
 paddle/pybind/pybind.cc                        |  6 ++
 .../framework/tests/test_feed_fetch_method.py  | 50 +++++++++++++++
 3 files changed, 118 insertions(+)
 create mode 100644 python/paddle/v2/framework/tests/test_feed_fetch_method.py

diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h
index a7fce3514b..7376245b53 100644
--- a/paddle/framework/scope.h
+++ b/paddle/framework/scope.h
@@ -18,8 +18,10 @@ limitations under the License. */
 #include
 #include
 
+#include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/variable.h"
 #include "paddle/platform/macros.h"
+#include "paddle/platform/place.h"
 
 namespace paddle {
 namespace framework {
@@ -75,5 +77,65 @@ class Scope {
 
 framework::Scope& GetGlobalScope();
 
+// template <typename T>
+// void SetFeedVariable(const std::vector<T>& input, const Lod& lod,
+//                      const std::vector<int64_t>& dims,
+//                      const std::string& var_name, size_t index) {
+//   Variable* g_feed_value = GetGlobalScope().Var("var_name");
+//   // feed variable holds vector<LoDTensor>
+//   auto& feed_inputs =
+//       *(g_feed_value->GetMutable<
+//           std::vector<paddle::framework::LoDTensor>>());
+//   if (index >= feed_inputs.size()) {
+//     feed_inputs.resize(index);
+//   }
+//   // copy tensor
+//   T* dst = feed_inputs[index].mutable_data(make_ddim(dims),
+//                                            platform::CPUPlace());
+//   memcpy(dst, inputs[i].data(), inputs[i].size() * sizeof(T));
+//   // copy lod
+//   feed_inputs[index].set_lod(lod);
+// }
+
+template <typename T>
+void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
+                     size_t index) {
+  std::cout << "into SetFeedVariable" << std::endl;
+  std::cout << var_name << std::endl;
+  std::cout << index << std::endl;
+  Variable* g_feed_value = GetGlobalScope().Var(var_name);
+  auto& feed_inputs =
+      *(g_feed_value->GetMutable<std::vector<LoDTensor>>());
+  if (index >= feed_inputs.size()) {
+    feed_inputs.resize(index + 1);
+  }
+  // shared data with input tensor
+  feed_inputs[index].ShareDataWith(input);
+  // set lod
+  feed_inputs[index].set_lod(input.lod());
+}
+
+// template <typename T>
+// std::vector<T> GetFetchVariable(const std::string& var_name, size_t index) {
+//   Variable* g_fetch_value = GetGlobalScope().Var(var_name);
+//   auto& fetch_outputs =
+//       *(g_fetch_value->GetMutable<
+//           std::vector<paddle::framework::LoDTensor>>());
+//   std::vector<T> result;
+//   result.resize(fetch_outputs[index].numel());
+//   memcpy(result.data(), fetch_outputs[i].data(),
+//          fetch_outputs[i].numel() * sizeof(T));
+// }
+
+template <typename T>
+LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) {
+  Variable* g_fetch_value = GetGlobalScope().Var(var_name);
+  auto& fetch_outputs =
+      *(g_fetch_value->GetMutable<std::vector<LoDTensor>>());
+  std::cout << "into GetFetchVariable" << std::endl;
+  PADDLE_ENFORCE_LT(index, fetch_outputs.size());
+  return fetch_outputs[index];
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index b143cb9f59..a4d14c9303 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -394,6 +394,12 @@ All parameter, weight, gradient are variables in Paddle.
m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); + m.def("set_feed_variable", SetFeedVariable); + // m.def("set_feed_variable", SetFeedVariable); + // m.def("set_feed_variable", SetFeedVariable); + m.def("get_fetch_variable", GetFetchVariable); + // m.def("get_fetch_variable", GetFetchVariable); + // m.def("get_fetch_variable", GetFetchVariable); BindProgramDesc(m); BindBlockDesc(m); diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/framework/tests/test_feed_fetch_method.py new file mode 100644 index 0000000000..cd8eeee68f --- /dev/null +++ b/python/paddle/v2/framework/tests/test_feed_fetch_method.py @@ -0,0 +1,50 @@ +import paddle.v2.framework.core as core +import unittest +import numpy as np + +# class TestFeedFetch(unittest.TestCase): +# def test_feed_fetch(self): +# place = core.CPUPlace() +# input_tensor = core.LoDTensor([[0, 2, 4]]) +# input_tensor.set_dims([4, 4, 6]) +# input_tensor.alloc_int(place) +# input_array = np.array(input_tensor) +# input_array[0, 0, 0] = 3 +# input_array[3, 3, 5] = 10 +# input_tensor.set(input_array, place) + +# core.set_feed_variable(input_tensor, "feed", 0) + +# output_tensor = core.get_fetch_variable("feed", 0) +# print type(output_tensor) + +# output_lod = output_tensor.lod() +# print type(output_lod) +# print output_lod[0] +# print output_lod[0][0] +# print output_lod[0][1] +# print output_lod[0][2] +# # self.assertEqual(0, output_lod[0][0]) +# # self.assertEqual(0, output_lod[0][0]) +# # self.assertEqual(2, output_lod[0][1]) +# # self.assertEqual(4, output_lod[0][2]) + +# # output_array = np.array(output_tensor) +# # self.assertEqual(3, output_array[0, 0, 0]) +# # self.assertEqual(10, output_array[3, 3, 5]); + + +class TestFeedFetch(unittest.TestCase): + def test_feed_fetch(self): + place = core.CPUPlace() + input_tensor = core.LoDTensor([[0, 2, 4]]) + input_tensor.set_dims([4, 4, 6]) + input_tensor.alloc_float(place) + input_array = np.array(input_tensor) + input_array[0, 0, 0] = 3 + input_array[3, 3, 5] = 10 + input_tensor.set(input_array, place) + + +if __name__ == "__main__": + unittest.main() -- GitLab