From eb2f7ed21bae0020e5ca36c80701f0337e4028be Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Fri, 2 Nov 2018 15:19:12 +0800 Subject: [PATCH] refine tests. test=develop --- cmake/external/threadpool.cmake | 1 - paddle/fluid/framework/data_type.h | 41 ---- paddle/fluid/framework/executor.cc | 1 - .../framework/ir/attention_lstm_fuse_pass.cc | 21 +- paddle/fluid/inference/api/api_impl.cc | 2 - .../inference/api/demo_ci/CMakeLists.txt | 67 +++--- .../inference/api/demo_ci/inference_icnet.cc | 219 +++++++----------- .../inference/api/demo_ci/inference_icnet.h | 21 -- .../api/demo_ci/real_data_icnet_tester.cc | 125 ---------- .../api/demo_ci/thread_icnet_test.cc | 146 ------------ paddle/fluid/operators/conv_op.cc | 47 +--- 11 files changed, 128 insertions(+), 563 deletions(-) delete mode 100644 paddle/fluid/inference/api/demo_ci/inference_icnet.h delete mode 100644 paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc delete mode 100644 paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc diff --git a/cmake/external/threadpool.cmake b/cmake/external/threadpool.cmake index 21527fe538..0159815fed 100644 --- a/cmake/external/threadpool.cmake +++ b/cmake/external/threadpool.cmake @@ -3,7 +3,6 @@ INCLUDE(ExternalProject) SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool) SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR}/src/extern_threadpool) INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR}) -message("Debug" ${THREADPOOL_INCLUDE_DIR}) ExternalProject_Add( extern_threadpool diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index f3f8d6cce6..d5be43b33e 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -25,7 +25,6 @@ namespace framework { extern proto::VarType::Type ToDataType(std::type_index type); extern std::type_index ToTypeIndex(proto::VarType::Type type); -#if !defined(_MSC_VER) template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { switch (type) { @@ -60,46 +59,6 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { PADDLE_THROW("Not supported %d", type); } } -#else -// the msvc compiler do not implement two-stage name lookup correctly. -template -inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { - switch (type) { - case proto::VarType::FP16: - visitor.template apply(); - break; - case proto::VarType::FP32: - visitor.template apply(); - break; - case proto::VarType::FP64: - visitor.template apply(); - break; - case proto::VarType::INT32: - visitor.template apply(); - break; - case proto::VarType::INT64: - visitor.template apply(); - break; - case proto::VarType::BOOL: - visitor.template apply(); - break; - case proto::VarType::UINT8: - visitor.template apply(); - break; - case proto::VarType::INT16: - visitor.template apply(); - break; - default: - PADDLE_THROW("Not supported %d", type); - } -} - -template -void* AnyCast(const InT* t) { - return static_cast(const_cast(t)); -} - -#endif // _WIN32 extern std::string DataTypeToString(const proto::VarType::Type type); extern size_t SizeOfType(std::type_index type); diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 9ab1d1fa28..7d5551c7e6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -337,7 +337,6 @@ std::unique_ptr Executor::Prepare( new ExecutorPrepareContext(program, block_id)); PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); auto& block = program.Block(block_id); - int counter = 0; for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); } diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc index b5aa9c8ccc..6090f1fe76 100644 --- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc @@ -11,10 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h" +#include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -212,12 +211,12 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0, VLOG(3) << "LSTMWeight resized to " << out->dims(); float* out_data = out->mutable_data(platform::CPUPlace()); - std::array tensors = { - W_forget_w0.data(), W_input_w0.data(), - W_output_w0.data(), W_cell_w0.data()}; - std::array tensors1 = { - W_forget_w1.data(), W_input_w1.data(), - W_output_w1.data(), W_cell_w1.data()}; + std::array tensors( + {{W_forget_w0.data(), W_input_w0.data(), + W_output_w0.data(), W_cell_w0.data()}}); + std::array tensors1( + {{W_forget_w1.data(), W_input_w1.data(), + W_output_w1.data(), W_cell_w1.data()}}); for (int row = 0; row < D; row++) { for (int col = 0; col < 4; col++) { @@ -239,9 +238,9 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0, void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, const LoDTensor& B_output, const LoDTensor& B_cell, LoDTensor* out) { - std::array tensors = { - B_forget.data(), B_input.data(), B_output.data(), - B_cell.data()}; + std::array tensors( + {{B_forget.data(), B_input.data(), B_output.data(), + B_cell.data()}}); PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1); int D = B_forget.dims()[0]; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index eea5689da6..27f272f2d8 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -94,8 +94,6 @@ bool NativePaddlePredictor::Init( // All parameters are saved in a single file. // The file names should be consistent with that used // in Python API `fluid.io.save_inference_model`. - auto exe = executor_.get(); - auto sc = scope_.get(); inference_program_ = paddle::inference::Load( executor_.get(), scope_.get(), config_.prog_file, config_.param_file); diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 7aa95291b3..a742ba71ee 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -6,13 +6,13 @@ option(WITH_STATIC_LIB "Compile demo with static/shared library, default use sta option(USE_TENSORRT "Compile demo with TensorRT." OFF) macro(safe_set_static_flag) - foreach(flag_var - CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE - CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) - if(${flag_var} MATCHES "/MD") - string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") - endif(${flag_var} MATCHES "/MD") - endforeach(flag_var) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) endmacro() if (WIN32) @@ -42,7 +42,7 @@ if(WITH_GPU) # default gpu path set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library") else() if(CUDA_LIB STREQUAL "") - set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") + set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") endif() endif(NOT WIN32) endif() @@ -53,9 +53,9 @@ include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") if (NOT WIN32) -include_directories("${PADDLE_LIB}/third_party/install/snappy/include") -include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") -include_directories("${PADDLE_LIB}/third_party/install/zlib/include") + include_directories("${PADDLE_LIB}/third_party/install/snappy/include") + include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") + include_directories("${PADDLE_LIB}/third_party/install/zlib/include") endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") @@ -63,15 +63,15 @@ include_directories("${PADDLE_LIB}/third_party/eigen3") if (NOT WIN32) if (USE_TENSORRT AND WITH_GPU) - include_directories("${TENSORRT_INCLUDE_DIR}") - link_directories("${TENSORRT_LIB_DIR}") + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") endif() endif(NOT WIN32) if (NOT WIN32) -link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") -link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") -link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") + link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") + link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") + link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") endif(NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") @@ -80,18 +80,12 @@ link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/paddle/lib") -# add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) - # add_library(${DEMO_NAME} ${DEMO_NAME}.cc) -add_executable(real_data_icnet_tester real_data_icnet_tester.cc) - -# add_library(${DEMO_NAME} SHARED ${DEMO_NAME}.cc) -# add_executable(test test.cc) -add_executable(thread_icnet_test thread_icnet_test.cc) +add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) include_directories("${PADDLE_LIB}/third_party/install/mklml/include") set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} - ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") if(EXISTS ${MKLDNN_PATH}) include_directories("${MKLDNN_PATH}/include") @@ -104,25 +98,25 @@ endif() # Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a if(WITH_STATIC_LIB) set(DEPS - ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) + ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) else() set(DEPS - ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) + ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() if (NOT WIN32) -set(EXTERNAL_LIB "-lrt -ldl -lpthread") -set(DEPS ${DEPS} + set(EXTERNAL_LIB "-lrt -ldl -lpthread") + set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) else() -set(DEPS ${DEPS} + set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} ${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf ${EXTERNAL_LIB}) -# NOTE(dzhwinter) shlwapi is deprecated. -set(DEPS ${DEPS} libcmt shlwapi) + # NOTE(dzhwinter) shlwapi will be deprecated. + set(DEPS ${DEPS} libcmt shlwapi) endif(NOT WIN32) if(WITH_GPU) @@ -134,14 +128,9 @@ if(WITH_GPU) set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) - set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) - set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) endif() endif() -target_link_libraries(real_data_icnet_tester ${DEPS}) - -# target_link_libraries(${DEMO_NAME} ${DEPS}) -# target_link_libraries(test ${DEMO_NAME} ) -target_link_libraries(thread_icnet_test ${DEPS}) -# target_compile_definitions(${DEMO_NAME} PRIVATE "API_DEFINITION") +target_link_libraries(${DEMO_NAME} ${DEPS}) diff --git a/paddle/fluid/inference/api/demo_ci/inference_icnet.cc b/paddle/fluid/inference/api/demo_ci/inference_icnet.cc index 8b16351604..88e220c0b6 100644 --- a/paddle/fluid/inference/api/demo_ci/inference_icnet.cc +++ b/paddle/fluid/inference/api/demo_ci/inference_icnet.cc @@ -11,152 +11,89 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#include -#include -#include +#define GOOGLE_GLOG_DLL_DECL +#include +#include +#include // NOLINT #include -#include -#include -#include -#include - -#include "paddle/fluid/inference/api/paddle_inference_api.h" -#include "inference_icnet.h" - -// 数据格式 -// "\t predictor; - struct Record - { - std::vector data; - std::vector shape; - }; - - const int C = 3; // image channel - const int H = 449; // image height - const int W = 581; // image width - - using Time = decltype(std::chrono::high_resolution_clock::now()); - - Time time() { return std::chrono::high_resolution_clock::now(); }; - - double time_diff(Time t1, Time t2) { - typedef std::chrono::microseconds ms; - auto diff = t2 - t1; - ms counter = std::chrono::duration_cast(diff); - return counter.count() / 1000.0; - } - - static void split(const std::string& str, char sep, - std::vector* pieces) { - pieces->clear(); - if (str.empty()) { - return; - } - size_t pos = 0; - size_t next = str.find(sep, pos); - while (next != std::string::npos) { - pieces->push_back(str.substr(pos, next - pos)); - pos = next + 1; - next = str.find(sep, pos); - } - if (!str.substr(pos).empty()) { - pieces->push_back(str.substr(pos)); - } - } - - Record ProcessALine(const std::string& line) { - std::vector columns; - split(line, '\t', &columns); - - Record record; - std::vector data_strs; - split(columns[0], ' ', &data_strs); - for (auto& d : data_strs) { - record.data.push_back(std::stof(d)); - } - - std::vector shape_strs; - split(columns[1], ' ', &shape_strs); - for (auto& s : shape_strs) { - record.shape.push_back(std::stoi(s)); - } - return record; - } - -public: - Predictor (const char* prog_file, - const char* param_file, const float fraction_of_gpu_memory, - const bool use_gpu, const int device) { - - NativeConfig config; - config.prog_file = prog_file; - config.param_file = param_file; - config.fraction_of_gpu_memory = fraction_of_gpu_memory; - config.use_gpu = use_gpu; - config.device = device; - - predictor = CreatePaddlePredictor(config); - } - - void predict(float* input, const int channel, const int height, const int width, - int64_t** output, int* output_length, int batch_size) { - std::vector data; - int intput_length = channel * height * width * batch_size; - for (int i = 0; i < intput_length; i++) { - data.push_back(*((float*)input + i)); - } - - // initialize the input data - PaddleTensor tensor; - tensor.shape = std::vector({ batch_size, channel, height, width }); - tensor.data.Resize(sizeof(float) * batch_size * channel * height * width); - std::copy(data.begin(), data.end(), static_cast(tensor.data.data())); - - tensor.dtype = PaddleDType::FLOAT32; - std::vector paddle_tensor_feeds(1, tensor); - - // initialize the output data - PaddleTensor tensor_out; - std::vector outputs(1, tensor_out); - predictor->Run(paddle_tensor_feeds, &outputs, batch_size); - *output_length = (int)outputs[0].data.length(); - std::memcpy(static_cast(*output), outputs[0].data.data(), outputs[0].data.length()); - int64_t sum_out = 0; - for(int i=0; i < outputs[0].data.length()/sizeof(int64_t); ++i) { - int64_t item = static_cast(outputs[0].data.data())[i]; - sum_out += item; - if (item != 0) { - std::cout << item << std::endl; - } - } +#include +#include // NOLINT +#include +#include "paddle/fluid/inference/paddle_inference_api.h" + +namespace paddle { + +NativeConfig GetConfig() { + NativeConfig config; + config.prog_file = "hs_lb_without_bn_cudnn/__model__"; + config.param_file = "hs_lb_without_bn_cudnn/__params__"; + config.fraction_of_gpu_memory = 0.0; + config.use_gpu = true; + config.device = 0; + return config; +} - std::cout << "sum_out" << sum_out << std::endl; - } -}; +using Time = decltype(std::chrono::high_resolution_clock::now()); +Time TimeNow() { return std::chrono::high_resolution_clock::now(); } +double TimeDiff(Time t1, Time t2) { + typedef std::chrono::microseconds ms; + auto diff = t2 - t1; + ms counter = std::chrono::duration_cast(diff); + return counter.count() / 1000.0; +} -API_REFERENCE void * init_predictor(const char* prog_file, - const char* param_file, const float fraction_of_gpu_memory, - const bool use_gpu, const int device) { - return new Predictor(prog_file, param_file, fraction_of_gpu_memory, use_gpu, device); +std::vector PrepareData() { + int height = 449; + int width = 581; + std::vector data; + for (int i = 0; i < 3 * height * width; ++i) { + data.push_back(0.0); + } + PaddleTensor tensor; + tensor.shape = std::vector({batch_size, 3, height, width}); + tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width); + std::copy(data.begin(), data.end(), static_cast(tensor.data.data())); + tensor.dtype = PaddleDType::FLOAT32; + std::vector paddle_tensor_feeds(1, tensor); + return std::move(paddle_tensor_feeds); } -API_REFERENCE void predict(void* handle, float* input, const int channel, const int height, const int width, - int64_t** output, int* output_length, int batch_size) { - assert(handle != nullptr); - ((Predictor*)handle)->predict(input, channel, height, width, output, output_length, batch_size); +void TestNaive(int batch_size, int thread_num) { + NativeConfig config = GetConfig(); + + int num_jobs = thread_num; // parallel jobs. + constexpr int epoches = 10; // each job run epoches. + std::vector threads; + std::vector> predictors; + for (int tid = 0; tid < num_jobs; ++tid) { + auto& pred = CreatePaddlePredictor(config); + predictors.emplace_back(std::move(pred)); + } + + auto time1 = TimeNow(); + for (int tid = 0; tid < num_jobs; ++tid) { + threads.emplace_back([&, tid]() { + auto& predictor = predictors[tid]; + PaddleTensor tensor_out; + std::vector outputs(1, tensor_out); + for (size_t i = 0; i < epoches; i++) { + ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); + VLOG(3) << "tid : " << tid << " run: " << i << "finished"; + ASSERT_EQ(outputs.size(), 1UL); + } + }); + } + for (int i = 0; i < num_jobs; ++i) { + threads[i].join(); + } + auto time2 = TimeNow(); + VLOG(3) << "Thread num " << thread_num << "total time cost" + << (time2 - time1); } +} // namespace paddle -API_REFERENCE void destory_predictor(void *handle) { - if (handle) { - delete handle; - handle = nullptr; - } +int main(int argc, char** argv) { + paddle::TestNaive(1, 1); // single thread. + paddle::TestNaive(1, 5); // 5 threads. + return 0; } diff --git a/paddle/fluid/inference/api/demo_ci/inference_icnet.h b/paddle/fluid/inference/api/demo_ci/inference_icnet.h deleted file mode 100644 index b2657e7988..0000000000 --- a/paddle/fluid/inference/api/demo_ci/inference_icnet.h +++ /dev/null @@ -1,21 +0,0 @@ - -#ifdef _WIN32 -#ifdef inference_icnet_EXPORTS -#define API_REFERENCE extern "C" __declspec(dllexport) -#else -#define API_REFERENCE extern "C" __declspec(dllimport) -#endif -#else -#define API_REFERENCE -#endif - -//API_REFERENCE void * init_predictor(); -//API_REFERENCE void destory_predictor(void *handle); -//API_REFERENCE void predict(void *handle, int n); - -API_REFERENCE void * init_predictor(const char* prog_file, - const char* param_file, const float fraction_of_gpu_memory, - const bool use_gpu, const int device); -API_REFERENCE void predict(void* handle, float* input, const int channel, const int height, - const int width, int64_t** output, int* output_length, int batch_size); -API_REFERENCE void destory_predictor(void *handle); diff --git a/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc b/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc deleted file mode 100644 index 5553d37355..0000000000 --- a/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#define GOOGLE_GLOG_DLL_DECL -#include -#include -#include -#include -#include -#include "paddle/fluid/inference/paddle_inference_api.h" - -namespace paddle { -NativeConfig GetConfig() { - NativeConfig config; - - // config.model_dir = FLAGS_dirname; - config.prog_file = "hs_lb_without_bn_cudnn/__model__"; - config.param_file = "hs_lb_without_bn_cudnn/__params__"; - // config.prog_file = "hs_lb_without_bn_cuda/__model__"; - // config.param_file = "hs_lb_without_bn_cuda/__params__"; - config.fraction_of_gpu_memory = 0.0; - config.use_gpu = true; - config.device = 0; - return config; -} - -using Time = decltype(std::chrono::high_resolution_clock::now()); -Time time() { return std::chrono::high_resolution_clock::now(); }; -double time_diff(Time t1, Time t2) { - typedef std::chrono::microseconds ms; - auto diff = t2 - t1; - ms counter = std::chrono::duration_cast(diff); - return counter.count() / 1000.0; -} - -void test_naive(int batch_size) { - NativeConfig config = GetConfig(); - auto predictor = CreatePaddlePredictor(config); - int height = 449; - int width = 581; - - // =============read file list ============= - std::ifstream infile("new_file.list"); - std::string temp_s; - std::vector all_files; - while (!infile.eof()) { - infile >> temp_s; - all_files.push_back(temp_s); - } - - // size_t file_num = all_files.size(); - infile.close(); - // =============read file list ============= - for (size_t f_k = 0; f_k < 1; f_k++) { - std::ifstream in_img(all_files[f_k]); - std::cout << all_files[f_k] << std::endl; - float temp_v; - - float sum_n = 0.0; - std::vector data; - while (!in_img.eof()) { - in_img >> temp_v; - data.push_back(float(temp_v)); - // std::cout << temp_v << " "; - sum_n += temp_v; - } - - in_img.close(); - std::cout << "sum: " << sum_n << std::endl; - - PaddleTensor tensor; - tensor.shape = std::vector({batch_size, 3, height, width}); - tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width); - std::copy(data.begin(), data.end(), - static_cast(tensor.data.data())); - tensor.dtype = PaddleDType::FLOAT32; - std::vector paddle_tensor_feeds(1, tensor); - PaddleTensor tensor_out; - - std::vector outputs(1, tensor_out); - // predictor->Run(paddle_tensor_feeds, &outputs, batch_size); - std::cout << "start predict123:" << std::endl; - auto time1 = time(); - int steps = 100; - for (size_t i = 0; i < steps; i++) { - if (i == 5) time1 = time(); - predictor->Run(paddle_tensor_feeds, &outputs, batch_size); - } - - auto time2 = time(); - std::ofstream ofresult("naive_test_result.txt", std::ios::app); - - std::cout << "batch: " << batch_size - << " predict cost: " << time_diff(time1, time2) / steps << "ms" - << std::endl; - std::cout << outputs.size() << std::endl; - int64_t* data_o = static_cast(outputs[0].data.data()); - int64_t sum_out = 0; - for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); ++j) { - ofresult << std::to_string(data_o[j]) << " "; - sum_out += data_o[j]; - } - std::cout << "sum_out " << sum_out << std::endl; - ofresult << std::endl; - ofresult.close(); - } -} - -} // namespace paddle - -int main(int argc, char** argv) { - // google::ParseCommandLineFlags(&argc, &argv, true); - paddle::test_naive(1 << 0); - return 0; -} diff --git a/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc b/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc deleted file mode 100644 index e1ce46b3bb..0000000000 --- a/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#define GOOGLE_GLOG_DLL_DECL - -#include -#include -//#include -#include -#include -#include -#include // NOLINT -#include -#include "paddle/fluid/inference/api/paddle_inference_api.h" - -#define ASSERT_TRUE(x) x -#define ASSERT_EQ(x, y) assert(x == y) - -// DEFINE_string(dirname, "./LB_icnet_model", -// "Directory of the inference model."); -namespace paddle { -NativeConfig GetConfig() { - NativeConfig config; - config.prog_file = "./hs_lb_without_bn_cuda/__model__"; - config.param_file = "./hs_lb_without_bn_cuda/__params__"; - config.fraction_of_gpu_memory = 0.0; - config.use_gpu = true; - config.device = 0; - return config; -} - -using Time = decltype(std::chrono::high_resolution_clock::now()); -Time time() { return std::chrono::high_resolution_clock::now(); }; -double time_diff(Time t1, Time t2) { - typedef std::chrono::microseconds ms; - auto diff = t2 - t1; - ms counter = std::chrono::duration_cast(diff); - return counter.count() / 1000.0; -} - -void test_naive(int batch_size, std::string model_path) { - NativeConfig config = GetConfig(); - int height = 449; - int width = 581; - std::vector data; - for (int i = 0; i < 3 * height * width; ++i) { - data.push_back(0.0); - } - - // read data - // std::ifstream infile("new_file.list"); - // std::string temp_s; - // std::vector all_files; - // while (!infile.eof()) { - // infile >> temp_s; - // all_files.push_back(temp_s); - // } - - // // size_t file_num = all_files.size(); - // infile.close(); - // // =============read file list ============= - // for (size_t f_k = 0; f_k < 1; f_k++) { - // std::ifstream in_img(all_files[f_k]); - // std::cout << all_files[f_k] << std::endl; - // float temp_v; - - // float sum_n = 0.0; - // std::vector data; - // while (!in_img.eof()) { - // in_img >> temp_v; - // data.push_back(float(temp_v)); - - // sum_n += temp_v; - // } - // in_img.close(); - // std::cout << "sum: " << sum_n << std::endl; - - PaddleTensor tensor; - tensor.shape = std::vector({batch_size, 3, height, width}); - tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width); - std::copy(data.begin(), data.end(), static_cast(tensor.data.data())); - tensor.dtype = PaddleDType::FLOAT32; - std::vector paddle_tensor_feeds(1, tensor); - - constexpr int num_jobs = 5; // each job run 1 batch - std::vector threads; - // using PtrPred = std::vector>; - std::vector> predictors; - for (int tid = 0; tid < num_jobs; ++tid) { - auto& pred = CreatePaddlePredictor(config); - predictors.emplace_back(std::move(pred)); - } - - using namespace std::chrono_literals; - // std::this_thread::sleep_for(std::chrono::seconds(20)); - std::cout << "before start predict"; - - int epoches = 100000; - for (int tid = 0; tid < num_jobs; ++tid) { - threads.emplace_back([&, tid]() { - // auto predictor = CreatePaddlePredictor(config); - auto& predictor = predictors[tid]; - // auto& predictor = predictors[tid]; - // auto predictor = preds[tid]; - // std::this_thread::sleep_for(std::chrono::seconds(20)); - PaddleTensor tensor_out; - std::vector outputs(1, tensor_out); - for (size_t i = 0; i < epoches; i++) { - ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); - VLOG(0) << "tid : " << tid << " run: " << i << "finished"; - // std::cout <<"tid : " << tid << " run: " << i << "finished" << - // std::endl; - ASSERT_EQ(outputs.size(), 1UL); - // int64_t* data_o = static_cast(outputs[0].data.data()); - // int64_t sum_out = 0; - // for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); - // ++j) { - // sum_out += data_o[j]; - // } - // std::cout << "tid : " << tid << "pass : " << i << " " << sum_out - // << std::endl; - } - }); - } - for (int i = 0; i < num_jobs; ++i) { - threads[i].join(); - } -} -// } -} // namespace paddle - -int main(int argc, char** argv) { - paddle::test_naive(1 << 0, ""); - return 0; -} diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index b1f97ddda5..2cd9979bd3 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define GLOG_NO_ABBREVIATED_SEVERITIES -#define GOOGLE_GLOG_DLL_DECL -#include #include "paddle/fluid/operators/conv_op.h" @@ -38,7 +35,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output(Output) of ConvOp should not be null."); - VLOG(3) << "Conv op infershape"; auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -46,51 +42,32 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); std::vector dilations = ctx->Attrs().Get>("dilations"); - VLOG(3) << "Conv op Before check"; - in_dims.size() == 4 || in_dims.size() == 5; - // PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, - // "Conv intput should be 4-D or 5-D tensor."); - VLOG(3) << "check0"; - - // PADDLE_ENFORCE_EQ( - // in_dims.size(), filter_dims.size(), - // "Conv input dimension and filter dimension should be the same."); - in_dims.size() == filter_dims.size(); - VLOG(3) << "enforce check0"; + + PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, + "Conv intput should be 4-D or 5-D tensor."); + PADDLE_ENFORCE_EQ( + in_dims.size(), filter_dims.size(), + "Conv input dimension and filter dimension should be the same."); PADDLE_ENFORCE( in_dims.size() - strides.size() == 2U, "Conv input dimension and strides dimension should be consistent."); - VLOG(3) << "check1"; PADDLE_ENFORCE_EQ( paddings.size(), strides.size(), "Conv paddings dimension and Conv strides dimension should be the same."); - VLOG(3) << "check2"; - // in_dims[1] == filter_dims[1] * groups; - // PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, - // "The number of input channels should be equal to filter " - // "channels * groups."); - VLOG(3) << "check3"; - // filter_dims[0] % groups == 0 ; - // PADDLE_ENFORCE_EQ( - // filter_dims[0] % groups, 0, - // "The number of output channels should be divided by groups."); - VLOG(3) << "filter" << filter_dims.size(); - VLOG(3) << "filter" << filter_dims[0]; - VLOG(3) << "check4"; - VLOG(3) << "filter" << filter_dims[1]; - VLOG(3) << "dims" << in_dims[0]; + PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + filter_dims[0] % groups, 0, + "The number of output channels should be divided by groups."); std::vector output_shape({in_dims[0], filter_dims[0]}); - VLOG(3) << "output shape"; for (size_t i = 0; i < strides.size(); ++i) { - VLOG(3) << "check5"; output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], paddings[i], strides[i])); - VLOG(3) << "check pass"; } - VLOG(3) << "Conv InferShape Pass"; ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); } -- GitLab