未验证 提交 1238706d 编写于 作者: Q QI JUN 提交者: GitHub

Refine unittest with setting gflags (#5476)

* add gflags for C++ unittest
上级 605b3e44
...@@ -227,8 +227,8 @@ function(cc_test TARGET_NAME) ...@@ -227,8 +227,8 @@ function(cc_test TARGET_NAME)
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS}) add_executable(${TARGET_NAME} ${cc_test_SRCS})
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main) target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif() endif()
endfunction(cc_test) endfunction(cc_test)
...@@ -288,8 +288,8 @@ function(nv_test TARGET_NAME) ...@@ -288,8 +288,8 @@ function(nv_test TARGET_NAME)
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} gtest gtest_main) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} gtest gtest_main) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_test(${TARGET_NAME} ${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME})
endif() endif()
endfunction(nv_test) endfunction(nv_test)
......
...@@ -81,18 +81,33 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { ...@@ -81,18 +81,33 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
} }
template <> template <>
void* Alloc<platform::GPUPlace>(platform::GPUPlace place, size_t size) { size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return GetGPUBuddyAllocator(place.device)->Alloc(size); return GetGPUBuddyAllocator(place.device)->Used();
} }
template <> template <>
void Free<platform::GPUPlace>(platform::GPUPlace place, void* p) { void* Alloc<platform::GPUPlace>(platform::GPUPlace place, size_t size) {
GetGPUBuddyAllocator(place.device)->Free(p); auto* buddy_allocator = GetGPUBuddyAllocator(place.device);
auto* ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) {
int cur_dev = platform::GetCurrentDeviceId();
platform::SetDeviceId(place.device);
size_t avail, total;
platform::GpuMemoryUsage(avail, total);
LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU "
<< place.device << ", available " << avail << " bytes";
LOG(WARNING) << "total " << total;
LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize();
LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize();
LOG(WARNING) << "GPU memory used: " << Used<platform::GPUPlace>(place);
platform::SetDeviceId(cur_dev);
}
return ptr;
} }
template <> template <>
size_t Used<platform::GPUPlace>(platform::GPUPlace place) { void Free<platform::GPUPlace>(platform::GPUPlace place, void* p) {
return GetGPUBuddyAllocator(place.device)->Used(); GetGPUBuddyAllocator(place.device)->Free(p);
} }
#endif #endif
......
...@@ -127,8 +127,3 @@ TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); } ...@@ -127,8 +127,3 @@ TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }
TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); } TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); } TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
...@@ -46,8 +46,3 @@ TEST(TensorToProto, Case2) { ...@@ -46,8 +46,3 @@ TEST(TensorToProto, Case2) {
EXPECT_EQ(t1[i], t[i]); EXPECT_EQ(t1[i], t[i]);
} }
} }
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
...@@ -5,4 +5,6 @@ if(WITH_TESTING) ...@@ -5,4 +5,6 @@ if(WITH_TESTING)
add_dependencies(paddle_test_main paddle_proto ${external_project_dependencies}) add_dependencies(paddle_test_main paddle_proto ${external_project_dependencies})
add_library(paddle_test_util STATIC TestUtil.cpp) add_library(paddle_test_util STATIC TestUtil.cpp)
add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies}) add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies})
add_library(paddle_gtest_main STATIC paddle_gtest_main.cc)
add_dependencies(paddle_gtest_main paddle_memory gtest gflags)
endif() endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstring>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/memory/memory.h"
int main(int argc, char** argv) {
std::vector<char*> new_argv;
std::string gflags_env;
new_argv.push_back(argv[0]);
#ifdef PADDLE_WITH_CUDA
new_argv.push_back(
strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"));
#else
new_argv.push_back(strdup("--tryfromenv=use_pinned_memory"));
#endif
int new_argc = static_cast<int>(new_argv.size());
char** new_argv_address = new_argv.data();
google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
testing::InitGoogleTest(&argc, argv);
paddle::memory::Used(paddle::platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
paddle::memory::Used(paddle::platform::GPUPlace(0));
#endif
return RUN_ALL_TESTS();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册