diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h new file mode 100644 index 0000000000000000000000000000000000000000..be96dc32670c97e3d1a8fa47c44a50965ffa528c --- /dev/null +++ b/paddle/framework/feed_fetch_method.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/scope.h" +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +template +void SetFeedVariable(const LoDTensor& input, const std::string& var_name, + size_t index) { + // If var_name Variable is not found in GlobalScope, a new variable will + // be created. + Variable* g_feed_value = GetGlobalScope().Var(var_name); + auto& feed_inputs = + *(g_feed_value->GetMutable>()); + if (index >= feed_inputs.size()) { + feed_inputs.resize(index + 1); + } + // shared data with input tensor + feed_inputs[index].ShareDataWith(input); + // set lod + feed_inputs[index].set_lod(input.lod()); +} + +LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) { + // If var_name Variable is not found in GlobalScope, a new variable will + // be created. + Variable* g_fetch_value = GetGlobalScope().Var(var_name); + auto& fetch_outputs = + *(g_fetch_value->GetMutable>()); + PADDLE_ENFORCE_LT(index, fetch_outputs.size()); + return fetch_outputs[index]; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 7376245b536c90378b746fe94a023b4b7696e785..78eca28009e29ee8df6b5203ac36fdb020b5f273 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -77,65 +77,5 @@ class Scope { framework::Scope& GetGlobalScope(); -// template -// void SetFeedVariable(const std::vector& input, const Lod& lod, -// const std::vector& dims, -// const std::string& var_name, size_t index) { -// Variable* g_feed_value = GetGlobalScope().Var("var_name"); -// // feed variable holds vector -// auto& feed_inputs = -// *(g_feed_value->GetMutable< -// std::vector>()); -// if (index >= feed_inputs.size()) { -// feed_inputs.resize(index); -// } -// // copy tensor -// T* dst = feed_inputs[index].mutable_data(make_ddim(dims), -// platform::CPUPlace()); -// memcpy(dst, inputs[i].data(), inputs[i].size() * sizeof(T)); -// // copy lod -// feed_inputs[index].set_lod(lod); -// } - -template -void SetFeedVariable(const LoDTensor& input, const std::string& var_name, - size_t index) { - std::cout << "into SetFeedVariable" << std::endl; - std::cout << var_name << std::endl; - std::cout << index << std::endl; - Variable* g_feed_value = GetGlobalScope().Var(var_name); - auto& feed_inputs = - *(g_feed_value->GetMutable>()); - if (index >= feed_inputs.size()) { - feed_inputs.resize(index + 1); - } - // shared data with input tensor - feed_inputs[index].ShareDataWith(input); - // set lod - feed_inputs[index].set_lod(input.lod()); -} - -// template -// std::vector GetFetchVariable(const std::string& var_name, size_t index) { -// Variable* g_fetch_value = GetGlobalScope().Var(var_name); -// auto& fetch_outputs = -// *(g_fetch_value->GetMutable< -// std::vector>()); -// std::vector result; -// result.resize(fetch_outputs[index].numel()); -// memcpy(result.data(), fetch_outputs[i].data(), -// fetch_outputs[i].numel() * sizeof(T)); -// } - -template -LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) { - Variable* g_fetch_value = GetGlobalScope().Var(var_name); - auto& fetch_outputs = - *(g_fetch_value->GetMutable>()); - std::cout << "into GetFetchVariable" << std::endl; - PADDLE_ENFORCE_LT(index, fetch_outputs.size()); - return fetch_outputs[index]; -} - } // namespace framework } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 1c8d2cfd6129dd91006e8d24501def0ab3364dbd..983f19b30d7c452f7451fc89da1628fbde8dff1d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/pybind/protobuf.h" - +#include "paddle/pybind/pybind.h" #include "paddle/framework/backward.h" #include "paddle/framework/executor.h" +#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/tensor_array.h" #include "paddle/operators/cond_op.h" @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" #include "paddle/pybind/exception.h" -#include "paddle/pybind/pybind.h" +#include "paddle/pybind/protobuf.h" #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" @@ -403,12 +403,10 @@ All parameter, weight, gradient are variables in Paddle. m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); - m.def("set_feed_variable", SetFeedVariable); - // m.def("set_feed_variable", SetFeedVariable); - // m.def("set_feed_variable", SetFeedVariable); - m.def("get_fetch_variable", GetFetchVariable); - // m.def("get_fetch_variable", GetFetchVariable); - // m.def("get_fetch_variable", GetFetchVariable); + m.def("set_feed_variable_float", framework::SetFeedVariable); + m.def("set_feed_variable_double", framework::SetFeedVariable); + m.def("set_feed_variable_int", framework::SetFeedVariable); + m.def("get_fetch_variable", framework::GetFetchVariable); BindProgramDesc(m); BindBlockDesc(m); diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/framework/tests/test_feed_fetch_method.py index cd8eeee68f0c071baa39769bf1493db82b715809..47eedddcb6f47927ea3918d7f6c379c5710592c6 100644 --- a/python/paddle/v2/framework/tests/test_feed_fetch_method.py +++ b/python/paddle/v2/framework/tests/test_feed_fetch_method.py @@ -2,49 +2,29 @@ import paddle.v2.framework.core as core import unittest import numpy as np -# class TestFeedFetch(unittest.TestCase): -# def test_feed_fetch(self): -# place = core.CPUPlace() -# input_tensor = core.LoDTensor([[0, 2, 4]]) -# input_tensor.set_dims([4, 4, 6]) -# input_tensor.alloc_int(place) -# input_array = np.array(input_tensor) -# input_array[0, 0, 0] = 3 -# input_array[3, 3, 5] = 10 -# input_tensor.set(input_array, place) - -# core.set_feed_variable(input_tensor, "feed", 0) - -# output_tensor = core.get_fetch_variable("feed", 0) -# print type(output_tensor) - -# output_lod = output_tensor.lod() -# print type(output_lod) -# print output_lod[0] -# print output_lod[0][0] -# print output_lod[0][1] -# print output_lod[0][2] -# # self.assertEqual(0, output_lod[0][0]) -# # self.assertEqual(0, output_lod[0][0]) -# # self.assertEqual(2, output_lod[0][1]) -# # self.assertEqual(4, output_lod[0][2]) - -# # output_array = np.array(output_tensor) -# # self.assertEqual(3, output_array[0, 0, 0]) -# # self.assertEqual(10, output_array[3, 3, 5]); - class TestFeedFetch(unittest.TestCase): def test_feed_fetch(self): place = core.CPUPlace() - input_tensor = core.LoDTensor([[0, 2, 4]]) - input_tensor.set_dims([4, 4, 6]) - input_tensor.alloc_float(place) - input_array = np.array(input_tensor) + input_array = np.ones((4, 4, 6)).astype("float32") input_array[0, 0, 0] = 3 input_array[3, 3, 5] = 10 + input_tensor = core.LoDTensor([[0, 2, 4]]) input_tensor.set(input_array, place) + core.set_feed_variable_float(input_tensor, "feed", 0) + + output_tensor = core.get_fetch_variable("feed", 0) + + output_lod = output_tensor.lod() + self.assertEqual(0, output_lod[0][0]) + self.assertEqual(2, output_lod[0][1]) + self.assertEqual(4, output_lod[0][2]) + + output_array = np.array(output_tensor) + self.assertEqual(3, output_array[0, 0, 0]) + self.assertEqual(10, output_array[3, 3, 5]) + if __name__ == "__main__": unittest.main()