From 1f5192a27b968a7980c2eead7b6885e66f09575a Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 6 Oct 2017 11:06:59 -0700 Subject: [PATCH] fix executor gpu unittest --- paddle/framework/executor.cc | 2 +- paddle/framework/executor_test.cc | 20 +++++++++++++++----- paddle/operators/fetch_op.cu | 2 +- paddle/platform/gpu_info.cc | 3 ++- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index ee0df039ac..c18ba049c8 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -30,7 +30,7 @@ Executor::Executor(const std::vector& places) { device_contexts_[i] = new platform::CPUDeviceContext( boost::get(places[i])); } else if (platform::is_gpu_place(places[i])) { -#ifdef PADDLE_WITH_GPU +#ifdef PADDLE_WITH_CUDA device_contexts_[i] = new platform::CUDADeviceContext( boost::get(places[i])); #else diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 5e327cc893..55e209628b 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -293,7 +293,7 @@ TEST_F(ExecutorTesterFeed, CPU) { delete executor; } -#ifdef PADDLE_WITH_GPU +#ifdef PADDLE_WITH_CUDA TEST_F(ExecutorTesterRandom, GPU) { std::vector places; GPUPlace gpu_place(0); @@ -315,10 +315,20 @@ TEST_F(ExecutorTesterFeed, GPU) { Executor* executor = new Executor(places); - // need to set feed variable before Executor::Run - set_feed_variable(inputs_); - executor->Run(pdesc_, GetScope()); - + // 3 mini-batch + for (int i = 0; i < 3; i++) { + // need to set feed variable before Executor::Run + std::cout << "start mini-batch " << i << std::endl; + set_feed_variable(inputs_); + executor->Run(pdesc_, GetScope()); + std::vector> result = get_fetch_variable(); + for (auto& vec : result) { + for (auto& num : vec) { + std::cout << num << " "; + } + std::cout << std::endl; + } + } delete executor; } #endif diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu index 2e24d3a8ad..ca39d24c79 100644 --- a/paddle/operators/fetch_op.cu +++ b/paddle/operators/fetch_op.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/feed_op.h" +#include "paddle/operators/fetch_op.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index 486dcd623a..aa76bb209d 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -43,7 +43,8 @@ int GetCurrentDeviceId() { } void SetDeviceId(int id) { - PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count"); + // TODO(qijun): find a better way to cache the cuda device count + PADDLE_ENFORCE(id < GetCUDADeviceCount(), "id must less than GPU count"); PADDLE_ENFORCE(cudaSetDevice(id), "cudaSetDevice failed in paddle::platform::SetDeviceId"); } -- GitLab