From 5fddd99e18f3920ff0d8158fd4a9800d5566943e Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Dec 2016 17:20:22 +0800 Subject: [PATCH] move TEST from test_matrixCompare.cpp to cross_map_normal_op_test.cpp --- cmake/util.cmake | 1 + paddle/function/CMakeLists.txt | 35 +++-- paddle/function/FunctionTest.h | 102 +++++++++++++ paddle/function/TestMain.cpp | 22 +++ paddle/function/cross_map_normal_op_test.cpp | 71 +++++++++ paddle/math/tests/test_matrixCompare.cpp | 144 ------------------- 6 files changed, 221 insertions(+), 154 deletions(-) create mode 100644 paddle/function/FunctionTest.h create mode 100644 paddle/function/TestMain.cpp create mode 100644 paddle/function/cross_map_normal_op_test.cpp diff --git a/cmake/util.cmake b/cmake/util.cmake index 03734e7839..8a71b23c62 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -107,6 +107,7 @@ function(link_paddle_exe TARGET_NAME) paddle_parameter paddle_proto paddle_cuda + paddle_test_main ${METRIC_LIBS} ${PROTOBUF_LIBRARY} ${LIBGLOG_LIBRARY} diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 8fad0e3ebd..0697842bbe 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -1,12 +1,27 @@ -file(GLOB FUNCTION_HEADERS . *.h) - -if(NOT WITH_GPU) - file(GLOB FUNCTION_SOURCES . *.cpp) - add_library(paddle_function STATIC ${FUNCTION_SOURCES}) -else() - file(GLOB FUNCTION_SOURCES . *.cpp *.cu) - cuda_add_library(paddle_function ${FUNCTION_SOURCES}) +file(GLOB h_files . *_op.h) +file(GLOB cpp_files . *_op.cpp) + +list(APPEND h_files Function.h) +list(APPEND cpp_files Function.cpp) + +if(WITH_GPU) + file(GLOB cu_files . *_op_gpu.cu) + cuda_compile(cu_objs ${cu_files}) endif() -add_style_check_target(paddle_function ${FUNCTION_SOURCES}) -add_style_check_target(paddle_function ${FUNCTION_HEADERS}) +add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) + +add_library(paddle_test_main STATIC TestMain.cpp) + +if(WITH_GPU) + # TODO: + # file(GLOB test_files . *_op_test.cpp) + # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files}) + add_simple_unittest(cross_map_normal_op_test) +endif() + +add_style_check_target(paddle_function ${h_files}) +add_style_check_target(paddle_function ${cpp_files}) +if(WITH_GPU) + add_style_check_target(paddle_function ${cu_files}) +endif() diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h new file mode 100644 index 0000000000..a8c5e412bd --- /dev/null +++ b/paddle/function/FunctionTest.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" +#include "paddle/math/Vector.h" +#include "paddle/math/tests/TensorCheck.h" + +namespace paddle { + +class FunctionCompare { +public: + FunctionCompare(const std::string& name, const FuncConfig& config) + : cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")), + gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) { + cpu->init(config); + gpu->init(config); + } + + void cmpWithArg(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) { + // init cpu and gpu arguments + auto initArgs = [=]( + Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) { + for (auto arg : inArgs) { + size_t size = sizeof(real); + for (auto dim : arg.dims_) { + size *= dim; + } + cpuMemory.emplace_back(std::make_shared(size)); + gpuMemory.emplace_back(std::make_shared(size)); + cpuArgs.emplace_back( + Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_)); + gpuArgs.emplace_back( + Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_)); + + // will use an api to refactor this code. + CpuVector cpuVector(size / sizeof(real), + (real*)cpuArgs.back().getData()); + GpuVector gpuVector(size / sizeof(real), + (real*)gpuArgs.back().getData()); + cpuVector.uniform(0.001, 1); + gpuVector.copyFrom(cpuVector); + } + }; + initArgs(cpuInputs, gpuInputs, inputs); + initArgs(cpuOutputs, gpuOutputs, outputs); + initArgs(cpuInouts, gpuInouts, inouts); + + // function calculate + cpu->calc(cpuInputs, cpuOutputs, cpuInouts); + gpu->calc(gpuInputs, gpuOutputs, gpuInouts); + + // check outputs and inouts + auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) { + for (size_t i = 0; i < cpuArgs.size(); i++) { + auto cpu = cpuArgs[i]; + auto gpu = gpuArgs[i]; + size_t size = 1; + for (auto dim : cpu.dims_) { + size *= dim; + } + CpuVector cpuVector(size, (real*)cpu.getData()); + GpuVector gpuVector(size, (real*)gpu.getData()); + + autotest::TensorCheckErr(cpuVector, gpuVector); + } + }; + checkArgs(cpuOutputs, gpuOutputs); + checkArgs(cpuInouts, gpuInouts); + } + +protected: + std::shared_ptr cpu; + std::shared_ptr gpu; + std::vector cpuMemory; + std::vector gpuMemory; + Arguments cpuInputs; + Arguments cpuOutputs; + Arguments cpuInouts; + Arguments gpuInputs; + Arguments gpuOutputs; + Arguments gpuInouts; +}; + +} // namespace paddle + +using paddle::FunctionCompare; +using paddle::FuncConfig; +using paddle::Dims; +using paddle::Tensor; diff --git a/paddle/function/TestMain.cpp b/paddle/function/TestMain.cpp new file mode 100644 index 0000000000..3e14532d18 --- /dev/null +++ b/paddle/function/TestMain.cpp @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/utils/Util.h" + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/function/cross_map_normal_op_test.cpp b/paddle/function/cross_map_normal_op_test.cpp new file mode 100644 index 0000000000..22692691bd --- /dev/null +++ b/paddle/function/cross_map_normal_op_test.cpp @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +TEST(CrossMapNormal, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + for (size_t size : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " size=" << size; + + FunctionCompare compare("CrossMapNormal", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); + Dims dims{numSamples, channels, imgSizeH, imgSizeW}; + compare.cmpWithArg({Tensor(nullptr, dims)}, + {Tensor(nullptr, dims), Tensor(nullptr, dims)}, + {}); + } + } + } + } + } +} + +TEST(CrossMapNormalGrad, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + for (size_t size : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " size=" << size; + + FunctionCompare compare("CrossMapNormalGrad", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); + Dims dims{numSamples, channels, imgSizeH, imgSizeW}; + compare.cmpWithArg({Tensor(nullptr, dims), + Tensor(nullptr, dims), + Tensor(nullptr, dims), + Tensor(nullptr, dims)}, + {Tensor(nullptr, dims)}, + {}); + } + } + } + } + } +} diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index c89b7ff490..440534e722 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1263,150 +1263,6 @@ TEST(Matrix, MaxOutFwdBwd) { } } -void testCrossMapNormalFwd( - int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { - float scale = 1.5; - float pow = 0.5; - int width = imgSizeH * imgSizeW * channels; - CpuMatrix inputs(numSamples, width); - CpuMatrix denoms(numSamples, width); - CpuMatrix outputs(numSamples, width); - GpuMatrix inputsGpu(numSamples, width); - GpuMatrix denomsGpu(numSamples, width); - GpuMatrix outputsGpu(numSamples, width); - - inputs.randomizeUniform(); - outputs.randomizeUniform(); - inputsGpu.copyFrom(inputs); - outputsGpu.copyFrom(outputs); - - FunctionBase* cpu = - FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); - FunctionBase* gpu = - FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU)); - cpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - gpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - - Dims dims{ - (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; - cpu->calc({Tensor(inputs.getData(), dims)}, - {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, - {}); - - gpu->calc( - {Tensor(inputsGpu.getData(), dims)}, - {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, - {}); - - TensorCheckErr(outputs, outputsGpu); - TensorCheckErr(denoms, denomsGpu); -} - -TEST(Matrix, crossMapNormalFwd) { - for (auto numSamples : {5, 32}) { - for (auto channels : {1, 5, 32}) { - for (auto imgSizeH : {5, 33, 100}) { - for (auto imgSizeW : {5, 32, 96}) { - for (auto sizeX : {1, 2, 3, 5, 7}) { - VLOG(3) << " numSamples=" << numSamples << " channels=" << channels - << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW - << " sizeX=" << sizeX; - testCrossMapNormalFwd( - numSamples, channels, imgSizeH, imgSizeW, sizeX); - } - } - } - } - } -} - -void testCrossMapNormalBwd( - int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { - float scale = 1.5; - float pow = 0.5; - size_t width = imgSizeH * imgSizeW * channels; - - CpuMatrix inputsGrad(numSamples, width); - CpuMatrix inputsValue(numSamples, width); - CpuMatrix outputsGrad(numSamples, width); - CpuMatrix outputsValue(numSamples, width); - CpuMatrix denoms(numSamples, width); - - outputsGrad.randomizeUniform(); - denoms.randomizeUniform(); - inputsValue.randomizeUniform(); - outputsValue.randomizeUniform(); - inputsGrad.randomizeUniform(); - denoms.add(0.01); - - GpuMatrix inputsGradGpu(numSamples, width); - GpuMatrix inputsValueGpu(numSamples, width); - GpuMatrix outputsGradGpu(numSamples, width); - GpuMatrix outputsValueGpu(numSamples, width); - GpuMatrix denomsGpu(numSamples, width); - - outputsGradGpu.copyFrom(outputsGrad); - denomsGpu.copyFrom(denoms); - inputsValueGpu.copyFrom(inputsValue); - outputsValueGpu.copyFrom(outputsValue); - inputsGradGpu.copyFrom(inputsGrad); - - FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - cpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - gpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - - Dims dims{ - (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; - cpu->calc({Tensor(inputsValue.getData(), dims), - Tensor(outputsValue.getData(), dims), - Tensor(outputsGrad.getData(), dims), - Tensor(denoms.getData(), dims)}, - {Tensor(inputsGrad.getData(), dims)}, - {}); - - gpu->calc({Tensor(inputsValueGpu.getData(), dims), - Tensor(outputsValueGpu.getData(), dims), - Tensor(outputsGradGpu.getData(), dims), - Tensor(denomsGpu.getData(), dims)}, - {Tensor(inputsGradGpu.getData(), dims)}, - {}); - - TensorCheckErr(inputsGrad, inputsGradGpu); -} - -TEST(Matrix, crossMapNormalBwd) { - for (auto numSamples : {5, 32}) { - for (auto channels : {1, 5, 32}) { - for (auto imgSizeH : {5, 33, 100}) { - for (auto imgSizeW : {5, 32, 96}) { - for (auto sizeX : {1, 2, 3, 5, 7}) { - VLOG(3) << " numSamples=" << numSamples << " channels=" << channels - << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW - << " sizeX=" << sizeX; - testCrossMapNormalBwd( - numSamples, channels, imgSizeH, imgSizeW, sizeX); - } - } - } - } - } -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); -- GitLab