diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 53e71998f1517a8f2cfa4509426231fb0f0177e8..786f224608f7d41c438411de0e09fedbcf2264b8 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -1,11 +1,11 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -138,13 +138,14 @@ def main(): avg_cost = fluid.layers.mean(x=cost) # Evaluator - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + batch_size = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size) # inference program inference_program = fluid.default_main_program().clone() with fluid.program_guard(inference_program): - test_target = accuracy.metrics + accuracy.states - inference_program = fluid.io.get_inference_program(test_target) + inference_program = fluid.io.get_inference_program(batch_acc) # Optimization optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) @@ -157,27 +158,30 @@ def main(): # test def test(exe): - accuracy.reset(exe) + test_pass_acc = fluid.average.WeightedAverage() for batch_id, data in enumerate(test_reader()): img_data = np.array(map(lambda x: x[0].reshape(data_shape), data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - exe.run(inference_program, - feed={"pixel": img_data, - "label": y_data}) + outs = exe.run(inference_program, + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[batch_acc, batch_size]) + test_pass_acc.add(value=np.array(outs[0]), weight=np.array(outs[1])) - return accuracy.eval(exe) + return test_pass_acc.eval() def train_loop(exe, trainer_prog): iters = 0 ts = time.time() + train_pass_acc = fluid.average.WeightedAverage() for pass_id in range(args.num_passes): # train start_time = time.time() num_samples = 0 - accuracy.reset(exe) + train_pass_acc.reset() with profiler.profiler("CPU", 'total') as prof: for batch_id, data in enumerate(train_reader()): ts = time.time() @@ -187,13 +191,14 @@ def main(): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - loss, acc = exe.run( + loss, acc, b_size = exe.run( trainer_prog, feed={"pixel": img_data, "label": y_data}, - fetch_list=[avg_cost] + accuracy.metrics) + fetch_list=[avg_cost, batch_acc, batch_size]) iters += 1 num_samples += len(data) + train_pass_acc.add(value=acc, weight=b_size) print( "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" % (pass_id, iters, loss, acc, @@ -201,7 +206,7 @@ def main(): ) # The accuracy is the accumulation of batches, but not the current batch. 
pass_elapsed = time.time() - start_time - pass_train_acc = accuracy.eval(exe) + pass_train_acc = train_pass_acc.eval() pass_test_acc = test(exe) print( "Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n" diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index e2b7ef8d54748de46370e808e28f9b9fcc988a84..8af2765f58717408e3a1ef6b500bb01511bfd8d3 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -77,7 +77,8 @@ IF(NOT ${CBLAS_FOUND}) INSTALL_DIR ${CBLAS_INSTALL_DIR} BUILD_IN_SOURCE 1 BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS} - INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX= + INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX= + && rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig UPDATE_COMMAND "" CONFIGURE_COMMAND "" ) @@ -100,11 +101,6 @@ IF(NOT ${CBLAS_FOUND}) \"${CBLAS_INSTALL_DIR}/lib -> ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}\" )" ) - INSTALL(CODE "execute_process( - COMMAND rm -r ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}/cmake - ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}/pkgconfig - )" - ) ENDIF() ENDIF(NOT ${CBLAS_FOUND}) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 356da582d1f8b6a8858af90ccdf5af2100e5db87..d0b5eaec2e2a50acf17e5dd1d1aeb0ec3e614fbf 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -186,7 +186,9 @@ function(cc_library TARGET_NAME) add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) else() add_library(${TARGET_NAME} STATIC ${cc_library_SRCS}) + find_fluid_modules(${TARGET_NAME}) endif() + if(cc_library_DEPS) # Don't need link libwarpctc.so if("${cc_library_DEPS};" MATCHES "warpctc;") @@ -263,7 +265,8 @@ function(nv_library TARGET_NAME) if (nv_library_SHARED OR nv_library_shared) # build *.so cuda_add_library(${TARGET_NAME} SHARED ${nv_library_SRCS}) else() - cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) + cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) + find_fluid_modules(${TARGET_NAME}) endif() if (nv_library_DEPS) add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 4471df36b0717171da4dff92ca0ec98b4f981028..6b2237b858380f384be0aa3c6ae24a4c83ad646d 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -1,9 +1,22 @@ +set_property(GLOBAL PROPERTY FLUID_MODULES "") +# find all fluid modules is used for paddle fluid static library +function(find_fluid_modules TARGET_NAME) + get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE) + string(FIND "${__target_path}" "fluid" pos) + if(pos GREATER 1) + get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) + set(fluid_modules ${fluid_modules} ${TARGET_NAME}) + set_property(GLOBAL PROPERTY FLUID_MODULES "${fluid_modules}") + endif() +endfunction(find_fluid_modules) + # make package for paddle fluid shared and static library function(copy TARGET) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DSTS DEPS) cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(inference_lib_dist_dep ${TARGET} ${inference_lib_dist_dep} PARENT_SCOPE) list(LENGTH copy_lib_SRCS copy_lib_SRCS_len) list(LENGTH copy_lib_DSTS copy_lib_DSTS_len) @@ -42,13 +55,21 @@ copy(glog_lib DSTS ${dst_dir} ${dst_dir}/lib ) -IF(NOT PROTOBUF_FOUND) +if(NOT PROTOBUF_FOUND) set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/install/protobuf") copy(protobuf_lib - SRCS 
${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LITE_LIBRARY} + SRCS ${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LIBRARY} DSTS ${dst_dir} ${dst_dir}/lib ) -ENDIF(NOT PROTOBUF_FOUND) +endif() + +if(NOT CBLAS_FOUND) + set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/install/openblas") + copy(openblas_lib + SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include + DSTS ${dst_dir} ${dst_dir} + ) +endif() # paddle fluid module set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") @@ -66,8 +87,8 @@ copy(memory_lib ) set(module "inference") -copy(inference_lib DEPENDS paddle_fluid_shared - SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.so +copy(inference_lib DEPS paddle_fluid_shared paddle_fluid + SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.* DSTS ${dst_dir}/${module} ${dst_dir}/${module} ) @@ -83,6 +104,4 @@ copy(string_lib DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat ) -add_custom_target(inference_lib_dist DEPENDS - inference_lib framework_lib memory_lib platform_lib string_lib - gflags_lib glog_lib protobuf_lib eigen3_lib) +add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep}) diff --git a/doc/fluid/howto/optimization/timeline.jpeg b/doc/fluid/howto/optimization/timeline.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..38ec3f80c982857531f30a8bb0fa26ea5bf05385 Binary files /dev/null and b/doc/fluid/howto/optimization/timeline.jpeg differ diff --git a/doc/fluid/howto/optimization/timeline.md b/doc/fluid/howto/optimization/timeline.md new file mode 100644 index 0000000000000000000000000000000000000000..9d9565a3e698a83ca465c5da83ff892360c33b8f --- /dev/null +++ b/doc/fluid/howto/optimization/timeline.md @@ -0,0 +1,27 @@ +## How to use the timeline tool to profile + +1. Add `with profiler.profiler(...)` to the main training loop. After running, the code will generate a profile record file `/tmp/profile`. **Warning**: Please do not run too many batches when using the profiler to record timeline information, because the profile record grows with the number of batches. + + ```python + with profiler.profiler('All', 'total', '/tmp/profile') as prof: + for pass_id in range(pass_num): + for batch_id, data in enumerate(train_reader()): + exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[], + use_program_cache=True) + ... + ``` + +1. Run `python paddle/tools/timeline.py` to process `/tmp/profile`; it will generate another +file, `/tmp/timeline`, by default. You can change the output path via a command-line parameter; please take a look at +[timeline.py](https://github.com/PaddlePaddle/Paddle/blob/develop/tools/timeline.py) for details. + +1. Open Chrome and visit `chrome://tracing/`, then use the `load` button to load the generated `timeline` file. + + ![chrome tracing](./tracing.jpeg) + +1.
The resulting timeline should be like: + + + ![chrome timeline](./timeline.jpeg) diff --git a/doc/fluid/howto/optimization/tracing.jpeg b/doc/fluid/howto/optimization/tracing.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..3a49fc4f8a401a9463b0157e2f38c164ca02dcc5 Binary files /dev/null and b/doc/fluid/howto/optimization/tracing.jpeg differ diff --git a/doc/v2/build_and_install/pip_install_cn.rst b/doc/v2/build_and_install/pip_install_cn.rst index 8e4165da6b8135d083766c650f1092158f9d01c2..ddcd42a0c6554469d702d3a9bbecd16643d6b7ed 100644 --- a/doc/v2/build_and_install/pip_install_cn.rst +++ b/doc/v2/build_and_install/pip_install_cn.rst @@ -39,7 +39,7 @@ PaddlePaddle可以使用常用的Python包管理工具 "cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "暂无" - "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "暂无" + "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" diff --git a/doc/v2/build_and_install/pip_install_en.rst b/doc/v2/build_and_install/pip_install_en.rst index 0d4c925b6e2731bdfd76582309aa7e8adbafa6ae..e08c84703bfa89352a79acbddd5d7f1bc88ce82e 100644 --- a/doc/v2/build_and_install/pip_install_en.rst +++ b/doc/v2/build_and_install/pip_install_en.rst @@ -42,7 +42,7 @@ If the links below shows up the login form, just click "Log in as guest" to star "cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "Not Available" - "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "Not Available" + "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" "cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_" diff --git a/doc/v2/howto/index_cn.rst b/doc/v2/howto/index_cn.rst index 0c534f107b6e047035c424ed2ea59f3982799b63..b0268907bceb11cd53a4630c3f8b8e0424abe247 100644 --- a/doc/v2/howto/index_cn.rst +++ b/doc/v2/howto/index_cn.rst @@ -1,11 +1,37 @@ 进阶使用 ======== 
+PaddlePaddle支持用户灵活地设置各种命令行参数,以实现对模型训练或预测流程的控制。使用方式请参考: + .. toctree:: :maxdepth: 1 cmd_parameter/index_cn.rst + +PaddlePaddle支持在fabric集群、MPI集群、kubernetes集群上分布式训练任务,具体环境配置和使用说明请参考: + +.. toctree:: + :maxdepth: 1 + cluster/index_cn.rst + +PaddlePaddle提供了用于预测的C-API,关于C-API的使用,我们提供了如下指南: + +.. toctree:: + :maxdepth: 1 + capi/index_cn.rst + +PaddlePaddle支持多种灵活和高效的循环神经网络,具体配置使用方式请参考: + +.. toctree:: + :maxdepth: 1 + rnn/index_cn.rst + +关于如何使用内置的定时工具、nvprof 或 nvvp 来运行性能分析和调优,请参考: + +.. toctree:: + :maxdepth: 1 + optimization/gpu_profiling_cn.rst diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 554cd58916c5a1ba09a411b4dc0b3a834ccc486a..c0523f3c795b103c0c27081ec5dc717f6a0f11e0 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -53,6 +53,7 @@ struct CastDataType { auto* context = static_cast(ctx_); trans(*context, in_begin, in_end, out_begin, CastDataTypeFunctor()); + context->Wait(); #endif } else { PADDLE_THROW("Unsupported place!"); diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc index c992cba9a3611d50839a8ec056ee6ab954cd88b6..6b9a8f5e28b372c45abfaa2c20575a55d9a9dd03 100644 --- a/paddle/fluid/framework/data_type_transform_test.cc +++ b/paddle/fluid/framework/data_type_transform_test.cc @@ -50,13 +50,13 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp32, kernel_fp64, in, &out); double* out_data_double = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast(i / 3)); + EXPECT_EQ(out_data_double[i], static_cast(i / 3)); } TransDataType(kernel_fp32, kernel_int32, in, &out); int* out_data_int = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast(i / 3)); + EXPECT_EQ(out_data_int[i], static_cast(i / 3)); } } @@ -76,31 +76,31 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp16, kernel_fp32, in, &out); float* out_data_float = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_float[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_float[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_fp64, in, &out); double* out_data_double = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_double[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_int32, in, &out); int* out_data_int = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_int[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_int64, in, &out); int64_t* out_data_int64 = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int64[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_int64[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_bool, in, &out); bool* out_data_bool = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_bool[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_bool[i], static_cast(ptr[i])); } // transform float to float16 @@ -112,7 +112,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp32, kernel_fp16, in, &out); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_float[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_float[i]).x); } // transform double to float16 @@ -124,7 +124,7 @@ TEST(DataTypeTransform, CPUTransform) { 
TransDataType(kernel_fp64, kernel_fp16, in, &out); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_double[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_double[i]).x); } // transform int to float16 @@ -136,7 +136,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_int32, kernel_fp16, in, &out); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_int[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_int[i]).x); } // transform int64 to float16 @@ -148,7 +148,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_int64, kernel_fp16, in, &out); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_int64[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_int64[i]).x); } // transform bool to float16 @@ -160,7 +160,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_bool, kernel_fp16, in, &out); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_bool[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_bool[i]).x); } } } diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/paddle/fluid/framework/data_type_transform_test.cu index 3939bc5e754cd3c2829cfcb3353f83969af055a9..de389ddabcb86de0155757406a406e44086c5474 100644 --- a/paddle/fluid/framework/data_type_transform_test.cu +++ b/paddle/fluid/framework/data_type_transform_test.cu @@ -49,15 +49,16 @@ TEST(DataTypeTransform, GPUTransform) { float arr[6] = {0, 1, 2, 3, 4, 5}; int data_number = sizeof(arr) / sizeof(arr[0]); memcpy(in_ptr, arr, sizeof(arr)); - TensorCopy(in, gpu_place, context, &in_gpu); + TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_fp32, kernel_fp64, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); double* out_data_double = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast(arr[i])); + EXPECT_EQ(out_data_double[i], static_cast(arr[i])); } TransDataType(kernel_fp32, kernel_int32, in_gpu, &out_gpu); @@ -66,7 +67,7 @@ TEST(DataTypeTransform, GPUTransform) { int* out_data_int = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast(arr[i])); + EXPECT_EQ(out_data_int[i], static_cast(arr[i])); } } @@ -83,6 +84,7 @@ TEST(DataTypeTransform, GPUTransform) { int data_number = sizeof(arr) / sizeof(arr[0]); memcpy(ptr, arr, sizeof(arr)); TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); // transform from float16 to other data types TransDataType(kernel_fp16, kernel_fp32, in_gpu, &out_gpu); @@ -91,7 +93,7 @@ TEST(DataTypeTransform, GPUTransform) { float* out_data_float = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_float[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_float[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_fp64, in_gpu, &out_gpu); @@ -100,7 +102,7 @@ TEST(DataTypeTransform, GPUTransform) { double* out_data_double = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_double[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_int32, in_gpu, &out_gpu); @@ -109,7 +111,7 @@ TEST(DataTypeTransform, GPUTransform) { int* out_data_int = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_int[i], static_cast(ptr[i])); } 
TransDataType(kernel_fp16, kernel_int64, in_gpu, &out_gpu); @@ -118,7 +120,7 @@ TEST(DataTypeTransform, GPUTransform) { int64_t* out_data_int64 = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int64[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_int64[i], static_cast(ptr[i])); } TransDataType(kernel_fp16, kernel_bool, in_gpu, &out_gpu); @@ -127,7 +129,7 @@ TEST(DataTypeTransform, GPUTransform) { bool* out_data_bool = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_bool[i], static_cast(ptr[i])); + EXPECT_EQ(out_data_bool[i], static_cast(ptr[i])); } // transform float to float16 @@ -137,13 +139,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_fp32, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_float[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_float[i]).x); } // transform double to float16 @@ -154,13 +157,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_fp64, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_double[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_double[i]).x); } // transform int to float16 @@ -170,13 +174,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_int32, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_int[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_int[i]).x); } // transform int64 to float16 @@ -187,13 +192,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_int64, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_int64[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_int64[i]).x); } // transform bool to float16 @@ -203,13 +209,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_bool, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast(in_data_bool[i]).x); + EXPECT_EQ(ptr[i].x, static_cast(in_data_bool[i]).x); } } } diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index bdb147955ca0700dc0854b54c38d961caf8845f3..17ccca8cdcbcaabaddbbc0ca1d3ca4fdf054b0fb 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -5,7 +5,8 @@ cc_library(paddle_fluid_api DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) # Create static library -cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) +get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) +cc_library(paddle_fluid DEPS ${fluid_modules}) # Create shared library cc_library(paddle_fluid_shared SHARED diff --git a/paddle/fluid/inference/io.cc 
b/paddle/fluid/inference/io.cc index 80eb9889670744ae527ea29609b33631a021bfa8..52e9c0baa64508f82d0a86a88c8c5f8d23f9f7f2 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -22,14 +22,14 @@ namespace paddle { namespace inference { void ReadBinaryFile(const std::string& filename, std::string& contents) { - VLOG(3) << "loading model from " << filename; - std::ifstream inputfs(filename, std::ios::in | std::ios::binary); - inputfs.seekg(0, std::ios::end); + std::ifstream fin(filename, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", filename); + fin.seekg(0, std::ios::end); contents.clear(); - contents.resize(inputfs.tellg()); - inputfs.seekg(0, std::ios::beg); - inputfs.read(&contents[0], contents.size()); - inputfs.close(); + contents.resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&contents[0], contents.size()); + fin.close(); } bool IsPersistable(const framework::VarDesc* var) { @@ -97,6 +97,7 @@ std::unique_ptr Load(framework::Executor& executor, const std::string& dirname) { std::string model_filename = dirname + "/__model__"; std::string program_desc_str; + VLOG(3) << "loading model from " << model_filename; ReadBinaryFile(model_filename, program_desc_str); std::unique_ptr main_program( diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index d6fc51301bc45d3ce01142c0f015e4bb69e519bf..e9a27171f1cd68e7b10c860fb4a1417b930ed565 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -17,10 +17,13 @@ limitations under the License. */ #include "paddle/fluid/inference/tests/test_helper.h" DEFINE_string(dirname, "", "Directory of the inference model."); +DEFINE_int32(batch_size, 1, "Batch size of input data"); +DEFINE_int32(repeat, 1, "Running the inference program repeat times"); TEST(inference, image_classification) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " + "--batch_size=1 --repeat=1"; } LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; @@ -29,13 +32,11 @@ TEST(inference, image_classification) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - int64_t batch_size = 1; - paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [0.0, 1.0]. 
SetupTensor(input, - {batch_size, 3, 32, 32}, + {FLAGS_batch_size, 3, 32, 32}, static_cast(0), static_cast(1)); std::vector cpu_feeds; @@ -46,7 +47,9 @@ TEST(inference, image_classification) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference(dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << "--- CPU Runs: ---"; + TestInference( + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -55,7 +58,9 @@ TEST(inference, image_classification) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference(dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << "--- GPU Runs: ---"; + TestInference( + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); CheckError(output1, output2); diff --git a/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc b/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc index 99bee94cb82633219df7a4a5c5bada15d2d3ce64..1fb0f9e77797cf6e61e918700763ee33a495cb96 100644 --- a/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc @@ -17,10 +17,13 @@ limitations under the License. */ #include "paddle/fluid/inference/tests/test_helper.h" DEFINE_string(dirname, "", "Directory of the inference model."); +DEFINE_int32(batch_size, 1, "Batch size of input data"); +DEFINE_int32(repeat, 1, "Running the inference program repeat times"); TEST(inference, recognize_digits) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " + "--batch_size=1 --repeat=1"; } LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; @@ -29,77 +32,39 @@ TEST(inference, recognize_digits) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - int64_t batch_size = 1; - paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [-1.0, 1.0]. 
SetupTensor(input, - {batch_size, 1, 28, 28}, + {FLAGS_batch_size, 1, 28, 28}, static_cast(-1), static_cast(1)); std::vector cpu_feeds; cpu_feeds.push_back(&input); - paddle::framework::LoDTensor output1; - std::vector cpu_fetchs1; - cpu_fetchs1.push_back(&output1); + for (auto is_combined : {false, true}) { + paddle::framework::LoDTensor output1; + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); - // Run inference on CPU - TestInference(dirname, cpu_feeds, cpu_fetchs1); - LOG(INFO) << output1.dims(); + // Run inference on CPU + LOG(INFO) << "--- CPU Runs: is_combined=" << is_combined << " ---"; + TestInference( + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined); + LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA - paddle::framework::LoDTensor output2; - std::vector cpu_fetchs2; - cpu_fetchs2.push_back(&output2); + paddle::framework::LoDTensor output2; + std::vector cpu_fetchs2; + cpu_fetchs2.push_back(&output2); - // Run inference on CUDA GPU - TestInference(dirname, cpu_feeds, cpu_fetchs2); - LOG(INFO) << output2.dims(); + // Run inference on CUDA GPU + LOG(INFO) << "--- GPU Runs: is_combined=" << is_combined << " ---"; + TestInference( + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined); + LOG(INFO) << output2.dims(); - CheckError(output1, output2); + CheckError(output1, output2); #endif -} - -TEST(inference, recognize_digits_combine) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; } - - LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - std::string dirname = FLAGS_dirname; - - // 0. Call `paddle::framework::InitDevices()` initialize all the devices - // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - - paddle::framework::LoDTensor input; - // Use normilized image pixels as input data, - // which should be in the range [-1.0, 1.0]. - SetupTensor( - input, {1, 1, 28, 28}, static_cast(-1), static_cast(1)); - std::vector cpu_feeds; - cpu_feeds.push_back(&input); - - paddle::framework::LoDTensor output1; - std::vector cpu_fetchs1; - cpu_fetchs1.push_back(&output1); - - // Run inference on CPU - TestInference( - dirname, cpu_feeds, cpu_fetchs1); - LOG(INFO) << output1.dims(); - -#ifdef PADDLE_WITH_CUDA - paddle::framework::LoDTensor output2; - std::vector cpu_fetchs2; - cpu_fetchs2.push_back(&output2); - - // Run inference on CUDA GPU - TestInference( - dirname, cpu_feeds, cpu_fetchs2); - LOG(INFO) << output2.dims(); - - CheckError(output1, output2); -#endif } diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 49518e50d8541477234f17ac5b8709aeb57662ff..0f5fe6d0aa9a5522c67a3c06f8677f1f2f259eb3 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/profiler.h" template void SetupTensor(paddle::framework::LoDTensor& input, @@ -87,31 +88,60 @@ void CheckError(paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } -template +template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, - std::vector& cpu_fetchs) { + std::vector& cpu_fetchs, + const int repeat = 1, + const bool is_combined = false) { // 1. 
Define place, executor, scope auto place = Place(); auto executor = paddle::framework::Executor(place); auto* scope = new paddle::framework::Scope(); + // Profile the performance + paddle::platform::ProfilerState state; + if (paddle::platform::is_cpu_place(place)) { + state = paddle::platform::ProfilerState::kCPU; + } else { +#ifdef PADDLE_WITH_CUDA + state = paddle::platform::ProfilerState::kCUDA; + // The default device_id of paddle::platform::CUDAPlace is 0. + // Users can get the device_id using: + // int device_id = place.GetDeviceId(); + paddle::platform::SetDeviceId(0); +#else + PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); +#endif + } + + // Enable the profiler + paddle::platform::EnableProfiler(state); + // 2. Initialize the inference_program and load parameters std::unique_ptr inference_program; - if (IsCombined) { - // All parameters are saved in a single file. - // Hard-coding the file names of program and parameters in unittest. - // The file names should be consistent with that used in Python API - // `fluid.io.save_inference_model`. - std::string prog_filename = "__model_combined__"; - std::string param_filename = "__params_combined__"; - inference_program = paddle::inference::Load(executor, - *scope, - dirname + "/" + prog_filename, - dirname + "/" + param_filename); - } else { - // Parameters are saved in separate files sited in the specified `dirname`. - inference_program = paddle::inference::Load(executor, *scope, dirname); + { + paddle::platform::RecordEvent record_event( + "init_program", + paddle::platform::DeviceContextPool::Instance().Get(place)); + + if (is_combined) { + // All parameters are saved in a single file. + // Hard-coding the file names of program and parameters in unittest. + // The file names should be consistent with that used in Python API + // `fluid.io.save_inference_model`. + std::string prog_filename = "__model_combined__"; + std::string param_filename = "__params_combined__"; + inference_program = + paddle::inference::Load(executor, + *scope, + dirname + "/" + prog_filename, + dirname + "/" + param_filename); + } else { + // Parameters are saved in separate files sited in the specified + // `dirname`. + inference_program = paddle::inference::Load(executor, *scope, dirname); + } } // 3. Get the feed_target_names and fetch_target_names @@ -134,7 +164,21 @@ void TestInference(const std::string& dirname, } // 6. 
Run the inference program - executor.Run(*inference_program, scope, feed_targets, fetch_targets); + { + // Run repeat times to profile the performance + for (int i = 0; i < repeat; ++i) { + paddle::platform::RecordEvent record_event( + "run_inference", + paddle::platform::DeviceContextPool::Instance().Get(place)); + + executor.Run(*inference_program, scope, feed_targets, fetch_targets); + } + } + + // Disable the profiler and print the timing information + paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, + "profiler.txt"); + paddle::platform::ResetProfiler(); delete scope; } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 62f00ab612a7409094fec443410de6f598840318..5d436a7e0c3752c889c19820507589f34d3bee94 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -1,5 +1,7 @@ file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") +string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}") string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") +list(REMOVE_DUPLICATES GENERAL_OPS) set(DEPS_OPS "") set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/pybind.h) file(WRITE ${pybind_file} "// Generated by the paddle/operator/CMakeLists.txt. DO NOT EDIT!\n\n") @@ -13,6 +15,8 @@ function(op_library TARGET) set(cu_cc_srcs) set(cudnn_cu_cc_srcs) set(CUDNN_FILE) + set(mkldnn_cc_srcs) + set(MKLDNN_FILE) set(op_common_deps operator op_registry math_function) set(options "") set(oneValueArgs "") @@ -36,12 +40,20 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) endif() + if(WITH_MKLDNN) + string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc) + list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc) + endif() + endif() else() foreach(src ${op_library_SRCS}) if (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") list(APPEND cudnn_cu_cc_srcs ${src}) + elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") + list(APPEND mkldnn_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cu.cc$") list(APPEND cu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") @@ -62,11 +74,11 @@ function(op_library TARGET) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) endif() if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) else() - cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} - ${op_common_deps}) + cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} + ${op_common_deps}) endif() # Define operators that don't need pybind here. 
@@ -101,7 +113,8 @@ function(op_library TARGET) # pybind USE_CPU_ONLY_OP list(LENGTH cu_srcs cu_srcs_len) list(LENGTH cu_cc_srcs cu_cc_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) + list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) + if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) endif() @@ -112,6 +125,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") endif() + # pybind USE_OP_DEVICE_KERNEL for MKLDNN + if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") + endif() + # pybind USE_OP if (${pybind_flag} EQUAL 0) file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") @@ -172,17 +190,18 @@ op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) if (WITH_GPU) - op_library(conv_op DEPS vol2col depthwise_conv) + op_library(conv_op DEPS vol2col depthwise_conv im2col) else() - op_library(conv_op DEPS vol2col) + op_library(conv_op DEPS vol2col im2col) endif() -op_library(conv_transpose_op DEPS vol2col) +op_library(conv_transpose_op DEPS vol2col im2col) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) +op_library(concat_op DEPS concat) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index bdce8f0a6f5f817d2010e8b4de28a7d1be4cc08b..0eedd8ee51ebfff6f553d8e19e97c3a45a95fa6a 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad, ops::ConcatOpGrad, false) -REGISTER_OP_CPU_KERNEL(concat, - ops::ConcatKernel) -REGISTER_OP_CPU_KERNEL(concat_grad, - ops::ConcatGradKernel) +REGISTER_OP_CPU_KERNEL( + concat, ops::ConcatKernel) +REGISTER_OP_CPU_KERNEL( + concat_grad, + ops::ConcatGradKernel) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 208a4481c6afe1b8f62e8f675c951c3349639f46..92c8ab6d9ff11ec6acd46a39877eb67d624748a9 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { @@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + framework::Tensor* out = ctx.Output("Out"); int64_t axis = static_cast(ctx.Attr("axis")); auto place = ctx.GetPlace(); out->mutable_data(place); - auto out_stride = framework::stride_numel(out->dims()); - - size_t output_offset = 0; - - // If axis >=1, copy to out immediately need to call many times - // of cuda memcpy. Copy the input to cpu and do the stride copy, - // then copy to gpu output. 
- - if (platform::is_gpu_place(place) && axis >= 1) { - platform::CPUPlace copy_place; - auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place); - framework::Tensor cpu_out; - cpu_out.Resize(out->dims()); - cpu_out.mutable_data(copy_place); - auto& dev_ctx = ctx.device_context(); - std::vector> cpu_ins; - for (auto* in : ins) { - std::unique_ptr cpu_in(new framework::Tensor); - framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get()); - cpu_ins.emplace_back(std::move(cpu_in)); - } - // TODO(dzhwinter): overlap copy and compute stream - // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/ - dev_ctx.Wait(); - - for (auto& in : cpu_ins) { - auto& cpu_in = *in.get(); - auto in_stride = framework::stride_numel(cpu_in.dims()); - - StridedNumelCopyWithAxis( - cpu_ctx, axis, cpu_out.data() + output_offset, out_stride, - cpu_in.data(), in_stride, in_stride[axis]); - output_offset += in_stride[axis]; - } - framework::TensorCopy(cpu_out, place, dev_ctx, out); - } else { + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && ins.size() < 10) { + size_t output_offset = 0; for (auto* in : ins) { auto in_stride = framework::stride_numel(in->dims()); + auto out_stride = framework::stride_numel(out->dims()); StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data() + output_offset, out_stride, in->data(), in_stride, in_stride[axis]); output_offset += in_stride[axis]; } + } else { + std::vector inputs(ins.size()); + for (size_t j = 0; j < ins.size(); ++j) { + inputs[j] = *ins[j]; + } + auto& dev_ctx = ctx.template device_context(); + paddle::operators::math::ConcatFunctor concat_functor; + concat_functor(dev_ctx, inputs, static_cast(axis), out); } } }; @@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel { auto* in = ctx.Input(framework::GradVarName("Out")); auto outs = ctx.MultiOutput(framework::GradVarName("X")); int64_t axis = static_cast(ctx.Attr("axis")); - size_t input_offset = 0; - auto in_stride = framework::stride_numel(in->dims()); - for (auto& out : outs) { - out->mutable_data(ctx.GetPlace()); - auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), - out_stride, in->data() + input_offset, - in_stride, out_stride[axis]); - input_offset += out_stride[axis]; + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + size_t input_offset = 0; + auto in_stride = framework::stride_numel(in->dims()); + + for (auto& out : outs) { + out->mutable_data(ctx.GetPlace()); + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride, out_stride[axis]); + input_offset += out_stride[axis]; + } + } else { + std::vector outputs(outs.size()); + for (size_t j = 0; j < outs.size(); ++j) { + outs[j]->mutable_data(ctx.GetPlace()); + outputs[j] = *outs[j]; + } + + auto& dev_ctx = ctx.template device_context(); + paddle::operators::math::ConcatGradFunctor + concat_grad_functor; + concat_grad_functor(dev_ctx, *in, static_cast(axis), outputs); } } }; diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d59cc2c9d424f067ca638cb76e52c2e95ae75182 --- /dev/null +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -0,0 +1,313 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. 
All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/conv_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNMemDesc; + +using mkldnn::memory; // Note: paddle has also "memory" namespace +using mkldnn::primitive; +using mkldnn::convolution_forward; +using mkldnn::convolution_backward_weights; +using mkldnn::convolution_backward_data; +using mkldnn::convolution_direct; +using mkldnn::prop_kind; +using mkldnn::padding_kind; +using mkldnn::stream; + +namespace { +std::unique_ptr +ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, + const memory::desc& dst, const std::vector& strides, + const std::vector& paddings, + const mkldnn::engine& engine); + +convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc( + const memory::desc& src, const memory::desc& diff_weights, + const memory::desc& diff_dst, const std::vector& strides, + const std::vector& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine); + +convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc( + const memory::desc& diff_src, const memory::desc& weights, + const memory::desc& diff_dst, const std::vector& strides, + const std::vector& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine); +} // anonymous namespace + +template +class ConvOpMkldnnKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + + // Get an unique name from "argument" name of "Output" variable + // This name will be used as key when saving info into device context + const std::string key = ctx.op().Output("Output"); + const std::string key_conv_pd = key + "@conv_pd"; + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + + // TODO(pzelazko-intel) add support for group convolution and dilation + PADDLE_ENFORCE(groups == 1, "group convolution is not implemented yet"); + PADDLE_ENFORCE( + dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, + "dilation in convolution is not implemented yet"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + // allocate memory for output + T* output_data = output->mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE(input->dims().size() == 4, + "Input must be with 4 dimensions, 
i.e. NCHW"); + PADDLE_ENFORCE(filter->dims().size() == 4, + "Filter must be with 4 dimensions, i.e. OIHW"); + + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector weights_tz = + paddle::framework::vectorize2int(filter->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + + // TODO(pzelazko-intel): support more formats + // memory descriptors for convolution src/weight/dst + auto conv_src_md = + MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw); + auto conv_weights_md = + MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw); + auto conv_dst_md = + MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw); + + // create memory primitives + auto conv_src_memory = + memory({conv_src_md, mkldnn_engine}, (void*)input_data); + auto conv_weights_memory = + memory({conv_weights_md, mkldnn_engine}, (void*)filter_data); + auto conv_dst_memory = memory({conv_dst_md, mkldnn_engine}, output_data); + + std::unique_ptr conv_pd = + ConvFwdPrimitiveDesc(conv_src_md, conv_weights_md, conv_dst_md, strides, + paddings, mkldnn_engine); + + // save p_conv_pd into dev_ctx to be referred in backward path + auto p_conv_pd = conv_pd.get(); + std::shared_ptr conv_pd_value = std::move(conv_pd); + dev_ctx.SetBlob(key_conv_pd, conv_pd_value); + + // create convolution op primitive + auto conv_prim = convolution_forward(*p_conv_pd, conv_src_memory, + conv_weights_memory, conv_dst_memory); + + // push op to stream and wait MKLDNN until it's executed + std::vector pipeline{conv_prim}; + stream(stream::kind::eager).submit(pipeline).wait(); + } +}; + +template +class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const Tensor* input = ctx.Input("Input"); + const Tensor* filter = ctx.Input("Filter"); + const Tensor* output = ctx.Input("Output"); + const Tensor* output_grad = + ctx.Input(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); + + if (!input_grad && !filter_grad) return; + + // Get an unique name from "argument" name of "Output" variable + // This name will be used as key when saving info into device context + const std::string key = ctx.op().Input("Output"); + const std::string key_conv_pd = key + "@conv_pd"; + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + const T* output_grad_data = output_grad->data(); + T* input_grad_data = nullptr; + T* filter_grad_data = nullptr; + + // allocate memory for gradient of input/filter + if (input_grad) { + input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + } + if (filter_grad) { + filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + } + + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector weights_tz = + paddle::framework::vectorize2int(filter->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + + // TODO(pzelazko-intel): support more formats + auto conv_src_md = + MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw); + auto 
conv_diff_src_md = + MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw); + auto conv_weights_md = + MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw); + auto conv_diff_weights_md = + MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw); + auto conv_diff_dst_md = + MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw); + + // create memory + auto conv_diff_dst_memory = + memory({conv_diff_weights_md, mkldnn_engine}, (void*)output_grad_data); + // Retrieve conv_pd from device context + std::shared_ptr conv_pd; + convolution_forward::primitive_desc* p_conv_pd; + + conv_pd = dev_ctx.GetBlob(key_conv_pd); + PADDLE_ENFORCE(conv_pd != nullptr, + "Fail to find conv_pd in device context"); + p_conv_pd = + static_cast(conv_pd.get()); + + // create backward conv primitive for weights + if (filter_grad) { + // create primitive descriptor + convolution_backward_weights::primitive_desc conv_bwd_weights_pd = + ConvBwdWeightsPrimitiveDesc(conv_src_md, conv_diff_weights_md, + conv_diff_dst_md, strides, paddings, + *p_conv_pd, mkldnn_engine); + + // create memory + auto conv_diff_weights_memory = memory( + {conv_diff_weights_md, mkldnn_engine}, (void*)filter_grad_data); + auto conv_src_memory = + memory({conv_src_md, mkldnn_engine}, (void*)input_data); + + // create backward conv primitive for weights + auto conv_bwd_weights_prim = convolution_backward_weights( + conv_bwd_weights_pd, conv_src_memory, conv_diff_dst_memory, + conv_diff_weights_memory); + + // push primitive and execute it + std::vector pipeline{conv_bwd_weights_prim}; + stream(stream::kind::eager).submit(pipeline).wait(); + } + + if (input_grad) { + // create primitive descriptor + convolution_backward_data::primitive_desc conv_bwd_data_pd = + ConvBwdDataPrimitiveDesc(conv_diff_src_md, conv_weights_md, + conv_diff_dst_md, strides, paddings, + *p_conv_pd, mkldnn_engine); + + // create memory + auto conv_diff_src_memory = + memory({conv_diff_src_md, mkldnn_engine}, (void*)input_grad_data); + auto conv_weights_memory = + memory({conv_weights_md, mkldnn_engine}, (void*)filter_data); + + // create backward conv primitive for data + auto conv_bwd_data_prim = + convolution_backward_data(conv_bwd_data_pd, conv_diff_dst_memory, + conv_weights_memory, conv_diff_src_memory); + + // push primitive and execute it + std::vector pipeline{conv_bwd_data_prim}; + stream(stream::kind::eager).submit(pipeline).wait(); + } + } // Compute() +}; + +namespace { +std::unique_ptr ConvFwdPrimitiveDesc( + const memory::desc& src, const memory::desc& weights, + const memory::desc& dst, const std::vector& strides, + const std::vector& paddings, const mkldnn::engine& engine) { + mkldnn::memory::dims stride_dims = {strides[0], strides[1]}; + mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]}; + + auto conv_desc = mkldnn::convolution_forward::desc( + mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, dst, + stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); + + auto p_conv_pd = new convolution_forward::primitive_desc(conv_desc, engine); + + return std::unique_ptr( + p_conv_pd); +} + +convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc( + const memory::desc& src, const memory::desc& diff_weights, + const memory::desc& diff_dst, const std::vector& strides, + const std::vector& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine) { + auto conv_bwd_weights_desc = 
convolution_backward_weights::desc( + convolution_direct, src, diff_weights, diff_dst, strides, paddings, + paddings, padding_kind::zero); + return convolution_backward_weights::primitive_desc(conv_bwd_weights_desc, + engine, conv_pd); +} + +convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc( + const memory::desc& diff_src, const memory::desc& weights, + const memory::desc& diff_dst, const std::vector& strides, + const std::vector& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine) { + auto conv_bwd_data_desc = convolution_backward_data::desc( + convolution_direct, diff_src, weights, diff_dst, strides, paddings, + paddings, padding_kind::zero); + return convolution_backward_data::primitive_desc(conv_bwd_data_desc, engine, + conv_pd); +} +} // anonymous namespace +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConvOpMkldnnKernel); + +REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConvGradOpMkldnnKernel); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 83b7708bf337b70f97c5e9126efd142b9b957b00..4b02b80d7772fa15d2333692551da5e59d93765f 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/conv_op.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -64,22 +70,21 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - bool use_cudnn = ctx.Attr("use_cudnn"); - use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(ctx.GetPlace())) { - auto& dev_ctx = ctx.template device_context(); - use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; } #endif - framework::LibraryType library_; - if (use_cudnn) { - library_ = framework::LibraryType::kCUDNN; - } else { - library_ = framework::LibraryType::kPlain; +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; } +#endif std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::DataLayout layout_ = framework::StringToDataLayout(data_format); return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), @@ -131,6 +136,9 @@ Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) "use_cudnn", "(bool, default false) Only used in cudnn kernel, need install cudnn") .SetDefault(false); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddAttr( "data_format", "(string, default NCHW) Only used in " @@ -224,6 +232,9 @@ Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker) "use_cudnn", "(bool, default false) Only used in cudnn kernel, need install cudnn") .SetDefault(false); + 
AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddAttr( "data_format", "(string, default NCHW) Only used in " @@ -284,23 +295,21 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - bool use_cudnn = ctx.Attr("use_cudnn"); - use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(ctx.GetPlace())) { - auto& dev_ctx = ctx.template device_context(); - use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; } #endif - - framework::LibraryType library_; - if (use_cudnn) { - library_ = framework::LibraryType::kCUDNN; - } else { - library_ = framework::LibraryType::kPlain; +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; } +#endif std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::DataLayout layout_ = framework::StringToDataLayout(data_format); return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index 9b8ca925373eb99df3fcb14ec7962768dd2eab04..73c84c2fe0155d21d7059938330e44fa3668c6df 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -71,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { return framework::OpKernelType( framework::ToDataType( ctx.Input("DetectRes")->type()), - ctx.device_context()); + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise_div_op.h index 6bcc577456b13f7930e07c24564953e93c5339ed..95649ac46e6bd41b9e1a865794cdec3ae1e6e247 100644 --- a/paddle/fluid/operators/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise_div_op.h @@ -41,77 +41,14 @@ class ElementwiseDivKernel : public framework::OpKernel { }; template -struct ElementwiseDivGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto y_e = framework::EigenVector::Flatten(*y); - auto z_e = framework::EigenVector::Flatten(*z); - auto dz_e = framework::EigenVector::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e / y_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = -1.0 * dz_e * z_e / y_e; - } - } -}; - -template -struct ElementwiseDivBroadCastGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) - .broadcast(Eigen::DSizes(pre, 1)) - .reshape(Eigen::DSizes(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e / y_e_bcast; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast)) - .reshape(Eigen::DSizes(pre, n)) - .sum(Eigen::array{{0}}); - } - } +struct 
DivGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } }; template -struct ElementwiseDivBroadCast2GradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) - .broadcast(Eigen::DSizes(pre, 1, post)) - .reshape(Eigen::DSizes(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e / y_e_bcast; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast)) - .reshape(Eigen::DSizes(pre, n, post)) - .sum(Eigen::array{{0, 2}}); - } +struct DivGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return -dout * x / (y * y); } }; @@ -128,10 +65,8 @@ class ElementwiseDivGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElementwiseGradCompute, - ElementwiseDivBroadCastGradFunctor, - ElementwiseDivBroadCast2GradFunctor>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute, DivGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX(), DivGradDY()); } }; diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise_max_op.h index ab3a3d58275e8a48a92aa0c8a50af072e83ac968..527a18ee3ba88a158a13266a7fbcdafe59ec69d9 100644 --- a/paddle/fluid/operators/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise_max_op.h @@ -41,76 +41,16 @@ class ElementwiseMaxKernel : public framework::OpKernel { }; template -struct ElementwiseMaxGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e > y_e).template cast() * dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (x_e <= y_e).template cast() * dz_e; - } +struct MaxGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x > y); } }; template -struct ElementwiseMaxBroadCastGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) - .broadcast(Eigen::DSizes(pre, 1)) - .reshape(Eigen::DSizes(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e > y_e_bcast).template cast() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = ((x_e <= y_e_bcast).template cast() * dz_e) - .reshape(Eigen::DSizes(pre, n)) - .sum(Eigen::array{{0}}); - } - } -}; - -template -struct ElementwiseMaxBroadCast2GradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = 
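The rewritten div gradient keeps the same math as the removed Eigen functors: dX = dout / y and dY = -dout * x / (y * y), which equals the old -dz * z / y because out = x / y. A rough standalone check of the pointwise form for same-shape inputs, using plain vectors instead of tensors and functor bodies copied from the diff:

#include <cassert>
#include <cmath>
#include <vector>

template <typename T>
struct DivGradDX {  // d(x / y) / dx = 1 / y, scaled by the upstream gradient
  T operator()(T x, T y, T out, T dout) const { return dout / y; }
};

template <typename T>
struct DivGradDY {  // d(x / y) / dy = -x / y^2, scaled by the upstream gradient
  T operator()(T x, T y, T out, T dout) const { return -dout * x / (y * y); }
};

int main() {
  std::vector<double> x = {1.0, 4.0, 9.0}, y = {2.0, 4.0, 3.0};
  std::vector<double> dout = {1.0, 1.0, 1.0}, dx(3), dy(3);
  DivGradDX<double> dx_op;
  DivGradDY<double> dy_op;
  for (size_t i = 0; i < x.size(); ++i) {
    double out = x[i] / y[i];
    dx[i] = dx_op(x[i], y[i], out, dout[i]);
    dy[i] = dy_op(x[i], y[i], out, dout[i]);
  }
  assert(std::abs(dx[0] - 0.5) < 1e-12);   // 1 / y  = 1 / 2
  assert(std::abs(dy[0] + 0.25) < 1e-12);  // -x / y^2 = -1 / 4
  return 0;
}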
y_e.reshape(Eigen::DSizes(1, n, 1)) - .broadcast(Eigen::DSizes(pre, 1, post)) - .reshape(Eigen::DSizes(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e > y_e_bcast).template cast() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = ((x_e <= y_e_bcast).template cast() * dz_e) - .reshape(Eigen::DSizes(pre, n, post)) - .sum(Eigen::array{{0, 2}}); - } +struct MaxGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x <= y); } }; @@ -127,12 +67,9 @@ class ElementwiseMaxGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElementwiseGradCompute, - ElementwiseMaxBroadCastGradFunctor, - ElementwiseMaxBroadCast2GradFunctor>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute, MaxGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx(), MaxGradDy()); } }; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise_min_op.h index f0eec9d2468b519473f813d71127d670adb08250..d4e5831463f3e54c72789b6876ea696cf1b4ef4b 100644 --- a/paddle/fluid/operators/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise_min_op.h @@ -41,76 +41,16 @@ class ElementwiseMinKernel : public framework::OpKernel { }; template -struct ElementwiseMinGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e < y_e).template cast() * dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (x_e >= y_e).template cast() * dz_e; - } +struct MinGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x < y); } }; template -struct ElementwiseMinBroadCastGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) - .broadcast(Eigen::DSizes(pre, 1)) - .reshape(Eigen::DSizes(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e < y_e_bcast).template cast() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = ((x_e >= y_e_bcast).template cast() * dz_e) - .reshape(Eigen::DSizes(pre, n)) - .sum(Eigen::array{{0}}); - } - } -}; - -template -struct ElementwiseMinBroadCast2GradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector::Flatten(*x); - auto y_e = framework::EigenVector::Flatten(*y); - auto dz_e = framework::EigenVector::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) - .broadcast(Eigen::DSizes(pre, 1, post)) - .reshape(Eigen::DSizes(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e < y_e_bcast).template cast() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = ((x_e >= y_e_bcast).template cast() * 
dz_e) - .reshape(Eigen::DSizes(pre, n, post)) - .sum(Eigen::array{{0, 2}}); - } +struct MinGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x >= y); } }; @@ -127,12 +67,9 @@ class ElementwiseMinGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElementwiseGradCompute, - ElementwiseMinBroadCastGradFunctor, - ElementwiseMinBroadCast2GradFunctor>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute, MinGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx(), MinGradDy()); } }; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h index e2b59b31120964d49d2de29ed327ffe0e6d44ddc..dc73cb6f23614504640283af01981d3f69e89126 100644 --- a/paddle/fluid/operators/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise_mul_op.h @@ -40,14 +40,15 @@ class ElementwiseMulKernel : public framework::OpKernel { }; template -struct IdentityGrad_DX { +struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } }; template -struct IdentityGrad_DY { +struct MulGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; + template class ElementwiseMulGradKernel : public framework::OpKernel { public: @@ -61,10 +62,8 @@ class ElementwiseMulGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElemwiseGradCompute, - IdentityGrad_DY>(ctx, *x, *y, *out, *dout, axis, dx, - dy, IdentityGrad_DX(), - IdentityGrad_DY()); + ElemwiseGradCompute, MulGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), MulGradDY()); } }; } // namespace operators diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h index a8fc242ed79ecee77dc5bb56882a1a15f872fccf..fe088b8203722a43b9aba7be3878b8f4ca68ba12 100644 --- a/paddle/fluid/operators/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise_sub_op.h @@ -40,61 +40,13 @@ class ElementwiseSubKernel : public framework::OpKernel { }; template -struct ElementwiseSubGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto dz_e = framework::EigenVector::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (-1.0) * dz_e; - } - } +struct SubGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } }; template -struct ElementwiseSubBroadCastGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto dz_e = framework::EigenVector::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (-1.0) * - dz_e.reshape(Eigen::DSizes(pre, n)) - .sum(Eigen::array{{0}}); - } - } -}; - -template -struct ElementwiseSubBroadCast2GradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto dz_e = framework::EigenVector::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; - } - - if (dy) { - auto dy_e = 
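MulGradDX and MulGradDY are the product rule written pointwise (dX = dout * y, dY = dout * x), and the same four-argument signature is shared by the max, min and sub functors in this patch. A quick numerical sanity check against finite differences, independent of the framework:

#include <cassert>
#include <cmath>

// Pointwise gradients of out = x * y, mirroring MulGradDX / MulGradDY above.
template <typename T>
T mul_grad_dx(T x, T y, T out, T dout) { return dout * y; }
template <typename T>
T mul_grad_dy(T x, T y, T out, T dout) { return dout * x; }

int main() {
  const double x = 3.0, y = -2.0, dout = 0.5, eps = 1e-6;
  const double out = x * y;
  // Compare against central finite differences of f(x, y) = x * y.
  double num_dx = dout * (((x + eps) * y) - ((x - eps) * y)) / (2 * eps);
  double num_dy = dout * ((x * (y + eps)) - (x * (y - eps))) / (2 * eps);
  assert(std::abs(mul_grad_dx(x, y, out, dout) - num_dx) < 1e-6);
  assert(std::abs(mul_grad_dy(x, y, out, dout) - num_dy) < 1e-6);
  return 0;
}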
framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (-1.0) * - dz_e.reshape(Eigen::DSizes(pre, n, post)) - .sum(Eigen::array{{0, 2}}); - } - } +struct SubGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; } }; template @@ -110,12 +62,9 @@ class ElementwiseSubGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElementwiseGradCompute, - ElementwiseSubBroadCastGradFunctor, - ElementwiseSubBroadCast2GradFunctor>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute, SubGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX(), SubGradDY()); } }; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 768106fadf355ea6fb148491e232dc0ef1453a75..a181d802262d15b188060dae4330cec0e24714ab 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -1,46 +1,59 @@ add_subdirectory(detail) -if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context framework_proto) - nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) - nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function) - nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) - nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context) - nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context) - nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) - nv_library(depthwise_conv SRCS depthwise_conv.cu DEPS device_context) - nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) - nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor) - nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) - nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor math_function) - nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context) - nv_library(sequence_scale SRCS sequence_scale.cc sequence_scale.cu DEPS lod_tensor device_context) - nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) - nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) - nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) - nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) - nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) -else() - cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) - cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) - cc_library(softmax SRCS softmax.cc DEPS device_context) - cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context) - cc_library(pooling SRCS pooling.cc DEPS device_context) - cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) - cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor) - cc_library(context_project SRCS context_project.cc DEPS 
device_context math_function) - cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor math_function) - cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context) - cc_library(sequence_scale SRCS sequence_scale.cc DEPS lod_tensor device_context) - cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) - cc_library(maxouting SRCS maxouting.cc DEPS device_context) - cc_library(unpooling SRCS unpooling.cc DEPS device_context) - cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) - cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) -endif() +function(math_library TARGET) + # math_library is a function to create math library. + # The interface is the same as cc_library. + # But it handle split GPU/CPU code and link some common library. + set(cc_srcs) + set(cu_srcs) + set(math_common_deps device_context framework_proto) + set(multiValueArgs DEPS) + cmake_parse_arguments(math_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) + list(APPEND cc_srcs ${TARGET}.cc) + endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) + list(APPEND cu_srcs ${TARGET}.cu) + endif() + + list(LENGTH cc_srcs cc_srcs_len) + if (WITH_GPU) + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) + elseif(${cc_srcs_len} GREATER 0) + cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) + endif() +endfunction() -cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) +# please add new math_library in alphabetical order +math_library(concat) +math_library(context_project DEPS im2col math_function) +math_library(cross_entropy) +math_library(cos_sim_functor) +math_library(depthwise_conv) +math_library(gru_compute DEPS activation_functions math_function) +math_library(im2col) +math_library(lstm_compute DEPS activation_functions) +math_library(math_function DEPS cblas) +math_library(maxouting) +math_library(pooling) +math_library(selected_rows_functor DEPS selected_rows) +math_library(sequence2batch) +math_library(sequence_padding) +math_library(sequence_pooling DEPS math_function) +math_library(sequence_scale) +math_library(softmax) +math_library(unpooling) +math_library(vol2col) + +cc_test(math_function_test SRCS math_function_test.cc) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) -cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) -cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) +cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) +cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) +if(WITH_GPU) + nv_test(math_function_gpu_test SRCS math_function_test.cu) + nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) +endif() +cc_test(concat_test SRCS concat_test.cc DEPS concat) diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..b542143419e05e9baf29e9a2322447f32ddd9829 --- /dev/null +++ b/paddle/fluid/operators/math/concat.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/concat.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. + */ +template +class ConcatFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const std::vector& input, const int axis, + framework::Tensor* output) { + // TODO(zcd): Add input data validity checking + int num = input.size(); + + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + auto& cpu_place = boost::get(context.GetPlace()); + + // computation + for (int k = 0; k < out_rows; ++k) { + T* dst_ptr = output->data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + const T* src_prt = input[j].data() + k * col_len; + memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt, + sizeof(T) * col_len); + col_idx += col_len; + } + } + } +}; + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. 
+ */ +template +class ConcatGradFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, const int axis, + std::vector& outputs) { + // TODO(zcd): Add input data validity checking + int num = outputs.size(); + + int input_rows = 1; + auto dim_0 = outputs[0].dims(); + for (int i = 0; i < axis; ++i) { + input_rows *= dim_0[i]; + } + int input_cols = 0; + + std::vector output_cols(outputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = outputs[i].numel() / input_rows; + input_cols += t_cols; + output_cols[i] = t_cols; + } + auto& cpu_place = boost::get(context.GetPlace()); + + // computation + for (int k = 0; k < input_rows; ++k) { + const T* src_ptr = input.data() + k * input_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = output_cols[j]; + T* dst_ptr = outputs[j].data() + k * col_len; + memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx, + sizeof(T) * col_len); + col_idx += col_len; + } + } + } +}; + +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; + +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu new file mode 100644 index 0000000000000000000000000000000000000000..60b266f08fb2d4217c5933902d69de96fc2abe22 --- /dev/null +++ b/paddle/fluid/operators/math/concat.cu @@ -0,0 +1,280 @@ +/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
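The CPU ConcatFunctor above flattens every input to rows x cols, where rows is the product of the dimensions before the concat axis and cols_i = numel_i / rows, then copies one column block per input into each output row. A simplified standalone version of that indexing, with std::vector and std::copy standing in for framework::Tensor and memory::Copy:

#include <algorithm>
#include <cassert>
#include <vector>

// Concatenate flattened inputs along an axis:
// rows = prod(dims before axis), cols[i] = numel_i / rows.
std::vector<int> ConcatRows(const std::vector<std::vector<int>>& inputs,
                            const std::vector<int>& cols, int rows) {
  int out_cols = 0;
  for (int c : cols) out_cols += c;
  std::vector<int> out(rows * out_cols);
  for (int k = 0; k < rows; ++k) {
    int col_idx = 0;
    for (size_t j = 0; j < inputs.size(); ++j) {
      const int* src = inputs[j].data() + k * cols[j];
      std::copy(src, src + cols[j], out.data() + k * out_cols + col_idx);
      col_idx += cols[j];
    }
  }
  return out;
}

int main() {
  // Two tensors of shape [2, 2] and [2, 1], concatenated along axis 1:
  // rows = 2, cols = {2, 1}, output shape [2, 3].
  std::vector<int> a = {1, 2, 3, 4};  // [[1,2],[3,4]]
  std::vector<int> b = {5, 6};        // [[5],[6]]
  std::vector<int> out = ConcatRows({a, b}, {2, 1}, 2);
  std::vector<int> expected = {1, 2, 5, 3, 4, 6};
  assert(out == expected);
  return 0;
}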
*/ + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__device__ T upper_bound(const T* first, T count, T val) { + const T* orig = first; + const T* it = nullptr; + T step = 0; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (!(val < *it)) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + return first - orig; +} + +template +__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, + const int output_rows, const int output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int segment = upper_bound(input_cols, col_size, tid_x) - 1; + + int curr_offset = input_cols[segment]; + int curr_segment = segment; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + T curr_col_offset; + while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* input_ptr = inputs[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * segment_width + local_col]; + } +} + +template +__global__ void KernelConcat(T** inputs, const int input_col, + const int output_rows, const int output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + double inv_input_col = 1.0 / input_col; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * inv_input_col; + int in_offset = tid_x - split * input_col; + T* input_ptr = inputs[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) { + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * input_col + in_offset]; + } + } +} + +template +__global__ void KernelConcatGrad(const T* input, const int input_row, + const int input_col, const int* output_cols, + int col_size, T** outputs) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int segment = upper_bound(output_cols, col_size, tid_x) - 1; + int curr_offset = output_cols[segment]; + int curr_segment = segment; + for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + T curr_col_offset; + while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* output_ptr = outputs[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * segment_width + local_col] = + input[tid_y * input_col + tid_x]; + } +} + +template +__global__ void KernelConcatGrad(const T* input, const int input_row, + const int input_col, const int output_cols, + T** outputs) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + double inv_input_col = 1.0 / input_col; + for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * inv_input_col; + int in_offset = tid_x - split * input_col; + T* output_ptr = outputs[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * output_cols + in_offset] = + 
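The __device__ upper_bound above is a binary search over the prefix sums of the per-input column widths; KernelConcat uses it to find which input segment owns a given output column. A host-side sketch of the same loop, checked against std::upper_bound on hypothetical widths {3, 2, 4}:

#include <algorithm>
#include <cassert>

// Same loop as the __device__ upper_bound in concat.cu: index of the first
// element greater than val in a sorted array of `count` elements.
template <typename T>
T upper_bound_idx(const T* first, T count, T val) {
  const T* orig = first;
  while (count > 0) {
    const T* it = first;
    T step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return static_cast<T>(first - orig);
}

int main() {
  // Prefix sums of the per-input column widths {3, 2, 4}.
  int cols[] = {0, 3, 5, 9};
  for (int v = 0; v < 9; ++v) {
    int idx = upper_bound_idx(cols, 4, v);
    assert(idx == static_cast<int>(std::upper_bound(cols, cols + 4, v) - cols));
    // segment = idx - 1 is the input that owns column v.
  }
  return 0;
}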
input[tid_y * input_col + tid_x]; + } +} + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. + */ +template +class ConcatFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const std::vector& input, const int axis, + framework::Tensor* output) { + // TODO(zcd): Add input data validity checking + int num = input.size(); + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int cols = input[0].numel() / rows; + int out_rows = rows, out_cols = 0; + + framework::Vector inputs_data(num * sizeof(T*) / 2); + framework::Vector inputs_cols(num + 1); + inputs_cols[0] = 0; + T** inputs_ptr = reinterpret_cast(inputs_data.data()); + + bool sameShape = true; + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + if (sameShape) { + if (t_cols != cols) sameShape = false; + } + out_cols += t_cols; + inputs_cols[i + 1] = out_cols; + inputs_ptr[i] = const_cast(input[i].data()); + } + + T** ins_gpu = + reinterpret_cast(inputs_data.CUDAMutableData(context.GetPlace())); + const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace()); + + // computation + // set the thread block and grid according to CurrentDeviceId + const int kThreadsPerBlock = 1024; + int block_cols = kThreadsPerBlock; + if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((out_cols + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; + dim3 block_size = dim3(block_cols, block_rows, 1); + + int max_threads = context.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((out_cols + block_cols - 1) / block_cols, max_blocks); + int grid_rows = + std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + if (sameShape) { + KernelConcat<<>>( + ins_gpu, cols, out_rows, out_cols, output->data()); + } else { + KernelConcat<<>>( + ins_gpu, ins_col_gpu, static_cast(inputs_cols.size()), out_rows, + out_cols, output->data()); + } + } +}; + +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. + */ +template +class ConcatGradFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, const int axis, + std::vector& outputs) { + // TODO(zcd): Add input data validity checking + int num = outputs.size(); + int input_row = 1; + auto dim_0 = outputs[0].dims(); + for (int i = 0; i < axis; ++i) { + input_row *= dim_0[i]; + } + + int output_col_0 = outputs[0].numel() / input_row; + int input_col = 0; + bool sameShape = true; + + framework::Vector outputs_data(num * sizeof(T*) / 2); + framework::Vector outputs_cols(num + 1); + outputs_cols[0] = 0; + T** outputs_ptr = reinterpret_cast(outputs_data.data()); + + for (int i = 0; i < num; ++i) { + int t_col = outputs[i].numel() / input_row; + if (sameShape) { + if (t_col != output_col_0) sameShape = false; + } + input_col += t_col; + outputs_cols[i + 1] = input_col; + outputs_ptr[i] = outputs[i].data(); + } + + T** outs_gpu = + reinterpret_cast(outputs_data.CUDAMutableData(context.GetPlace())); + const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace()); + + // computation + const int kThreadsPerBlock = 1024; + int block_cols = kThreadsPerBlock; + if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32. 
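The launch configuration here caps a block at 1024 threads, rounds block_cols up to a multiple of 32 when the output is narrow, and limits the grid by the device's physical thread capacity. A plain C++ sketch of that arithmetic, assuming a hypothetical device with 56 * 2048 resident threads:

#include <algorithm>
#include <cstdio>

struct LaunchConfig { int block_cols, block_rows, grid_cols, grid_rows; };

LaunchConfig ConcatLaunchConfig(int out_rows, int out_cols, int max_threads) {
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (out_cols < kThreadsPerBlock) {  // align block_cols to a multiple of 32
    block_cols = ((out_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  int grid_cols = std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
  return {block_cols, block_rows, grid_cols, grid_rows};
}

int main() {
  // Hypothetical device: 56 multiprocessors * 2048 threads each.
  int max_threads = 56 * 2048;
  LaunchConfig cfg = ConcatLaunchConfig(/*out_rows=*/1000, /*out_cols=*/48,
                                        max_threads);
  // 48 columns -> block (64, 16); max_blocks = 112, so grid = (1, 62).
  std::printf("block=(%d,%d) grid=(%d,%d)\n", cfg.block_cols, cfg.block_rows,
              cfg.grid_cols, cfg.grid_rows);
  return 0;
}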
+ block_cols = ((input_col + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; + dim3 block_size = dim3(block_cols, block_rows, 1); + + int max_threads = context.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((input_col + block_cols - 1) / block_cols, max_blocks); + int grid_rows = + std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1)); + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + if (sameShape) { + KernelConcatGrad<<>>( + input.data(), input_row, input_col, output_col_0, outs_gpu); + } else { + KernelConcatGrad<<>>( + input.data(), input_row, input_col, outs_col_gpu, + static_cast(outputs_cols.size()), outs_gpu); + } + } +}; + +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; +template class ConcatFunctor; + +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; +template class ConcatGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h new file mode 100644 index 0000000000000000000000000000000000000000..22147d79e4b1eeee76f7445dd963bf5062049a34 --- /dev/null +++ b/paddle/fluid/operators/math/concat.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * \brief Concatenate the input tensors along the dimension axis. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input[0] = [[1,2],[3,4]] + * Input[1] = [[5,6]] + * axis = 0 + * + * Output = [[1,2], + * [3,4], + * [5,6]] + */ +template +class ConcatFunctor { + public: + void operator()(const DeviceContext& context, + const std::vector& input, const int axis, + framework::Tensor* output); +}; + +/* + * \brief Split the input tensors along the dimension axis into outputs. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input = [[1,2], + * [3,4], + * [5,6]] + * axis = 0 + * + * Output[0] = [[1,2],[3,4]] + * Output[1] = [[5,6]] + */ +template +class ConcatGradFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const int axis, std::vector& outputs); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1741af8148bb90863f294ba4930006a58b5ddbf9 --- /dev/null +++ b/paddle/fluid/operators/math/concat_test.cc @@ -0,0 +1,336 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
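ConcatGradFunctor is the inverse of the concat above: it slices each flattened input row back into per-output column ranges, matching the Input/Output example in the concat.h comment. A small standalone sketch of that splitting under the same rows/cols flattening assumption:

#include <algorithm>
#include <cassert>
#include <vector>

// Split a rows x sum(cols) matrix back into per-output column blocks.
std::vector<std::vector<int>> SplitRows(const std::vector<int>& in,
                                        const std::vector<int>& cols, int rows) {
  int in_cols = 0;
  for (int c : cols) in_cols += c;
  std::vector<std::vector<int>> outs(cols.size());
  for (size_t j = 0; j < cols.size(); ++j) outs[j].resize(rows * cols[j]);
  for (int k = 0; k < rows; ++k) {
    int col_idx = 0;
    for (size_t j = 0; j < cols.size(); ++j) {
      const int* src = in.data() + k * in_cols + col_idx;
      std::copy(src, src + cols[j], outs[j].data() + k * cols[j]);
      col_idx += cols[j];
    }
  }
  return outs;
}

int main() {
  // Splitting along axis 0 degenerates to rows = 1 and contiguous ranges:
  // [[1,2],[3,4],[5,6]] -> [[1,2],[3,4]] and [[5,6]], as in the concat.h comment.
  std::vector<int> in = {1, 2, 3, 4, 5, 6};
  auto outs = SplitRows(in, /*cols=*/{4, 2}, /*rows=*/1);
  assert((outs[0] == std::vector<int>{1, 2, 3, 4}));
  assert((outs[1] == std::vector<int>{5, 6}));
  return 0;
}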
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/concat.h" +#include +#include +#include "paddle/fluid/framework/tensor_util.h" + +using namespace paddle::framework; +using namespace paddle::platform; + +template +void testConcat() { + Tensor input_a_cpu; + Tensor input_b_cpu; + Tensor out_cpu; + Tensor input_a; + Tensor input_b; + Tensor out; + + DeviceContext* context = new DeviceContext(Place()); + // DeviceContext context(Place()); + + /** + * cast1: + * inputs: + * t_a.shape: [2, 3, 4] + * t_b.shape: [3, 3, 4] + * output: + * out.shape: [5, 3, 4] + */ + auto dim_a = make_ddim({2, 3, 4}); + auto dim_b = make_ddim({3, 3, 4}); + auto dim_out = make_ddim({5, 3, 4}); + + input_a.mutable_data(dim_a, Place()); + input_b.mutable_data(dim_b, Place()); + out.mutable_data(dim_out, Place()); + + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.mutable_data(dim_a, CPUPlace()); + input_b_cpu.mutable_data(dim_b, CPUPlace()); + out_cpu.mutable_data(dim_out, CPUPlace()); + } + + int* a_ptr; + int* b_ptr; + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 3 * 3 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + std::vector input; + input.push_back(input_a); + input.push_back(input_b); + + paddle::operators::math::ConcatFunctor concat_functor; + concat_functor(*context, input, 0, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + int* out_ptr; + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + int cols = 2 * 3 * 4; + int idx_a = 0, idx_b = 0; + for (int j = 0; j < 5 * 3 * 4; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a]); + ++idx_a; + } + } + // + /** + * cast2: + * inputs: + * t_a.shape: [2, 3, 4] + * t_b.shape: [2, 4, 4] + * output: + * out.shape: [2, 7, 4] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 4, 4}); + dim_out = make_ddim({2, 7, 4}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 4 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + 
input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 1, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + cols = 3 * 4; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 28; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } + + /** + * cast3: + * inputs: + * t_a.shape: [2, 3, 5] + * t_b.shape: [2, 3, 4] + * output: + * out.shape: [2, 3, 9] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 3, 5}); + dim_out = make_ddim({2, 3, 9}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 3 * 5; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 2, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + // check the data + cols = 4; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 6; ++i) { + for (int j = 0; j < 9; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } + + /** + * cast4: + * inputs: + * axis = 1 + * t_a.shape: [2, 3, 4] + * t_b.shape: [2, 3, 4] + * output: + * out.shape: [2, 6, 4] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 3, 4}); + dim_out = make_ddim({2, 6, 4}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data(); + b_ptr = input_b_cpu.data(); + } else { + a_ptr = input_a.data(); + b_ptr = input_b.data(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 3 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 1, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, 
&out_cpu); + out_ptr = out_cpu.data(); + } else { + out_ptr = out.data(); + } + + // check the data + cols = 12; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 24; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } +} + +TEST(math, concat) { + testConcat(); +#ifdef PADDLE_WITH_CUDA + testConcat(); +#endif +} diff --git a/paddle/fluid/operators/math/sequence2batch.cc b/paddle/fluid/operators/math/sequence2batch.cc index 72bf2ab17016157f22bcf180090660346ed4f60d..8899abff360ea867872d3433722cdb37ef358500 100644 --- a/paddle/fluid/operators/math/sequence2batch.cc +++ b/paddle/fluid/operators/math/sequence2batch.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 3580932356fd5f29d5e4d00a70e64c207c64e41e..832509641cc3d5178ff090e05437484d395bfe51 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -84,6 +84,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("shape", "(vector) " "Target shape of reshape operator."); + AddAttr("inplace", + "Change the source tensor's shape without copy memory.") + .SetDefault(true); AddComment(R"DOC( Reshape Operator. diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index 1357bce4b7e7dcde616ec15df6617db5573df1d3..eacb0a0cf21a60ffbdef5787434859ac549388bc 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -26,10 +26,16 @@ class ReshapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); + bool inplace = ctx.Attr("inplace"); auto out_dims = out->dims(); - out->mutable_data(ctx.GetPlace()); - framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); - out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace()); + framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } } }; @@ -40,10 +46,16 @@ class ReshapeGradKernel : public framework::OpKernel { auto* d_out = ctx.Input(framework::GradVarName("Out")); auto* d_x = ctx.Output(framework::GradVarName("X")); d_x->mutable_data(ctx.GetPlace()); + bool inplace = ctx.Attr("inplace"); auto in_dims = d_x->dims(); - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - d_x->Resize(in_dims); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } } }; } // namespace operators diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 48c967de1155a705b76f0285a4baed52d7ce1f58..1842ecd745e3f5cb75600ce00d89018f81682632 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -15,6 +15,8 @@ limitations under the License. 
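With the new inplace attribute (default true), the reshape kernels above share the input's storage via ShareDataWith instead of duplicating it with TensorCopy, so later writes through either tensor are visible to both. A rough analogy using shared_ptr-backed buffers, a toy stand-in rather than the real framework::Tensor API:

#include <cassert>
#include <memory>
#include <vector>

// Toy tensor: a shared buffer plus a shape that can be changed freely.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> buf;
  std::vector<int> shape;
};

ToyTensor Reshape(const ToyTensor& in, std::vector<int> new_shape, bool inplace) {
  ToyTensor out;
  if (inplace) {
    out.buf = in.buf;  // like ShareDataWith: alias the same storage
  } else {
    out.buf = std::make_shared<std::vector<float>>(*in.buf);  // like TensorCopy
  }
  out.shape = std::move(new_shape);
  return out;
}

int main() {
  ToyTensor x{std::make_shared<std::vector<float>>(6, 1.0f), {2, 3}};
  ToyTensor shared = Reshape(x, {3, 2}, /*inplace=*/true);
  ToyTensor copied = Reshape(x, {3, 2}, /*inplace=*/false);
  (*x.buf)[0] = 42.0f;                // later write through the source
  assert((*shared.buf)[0] == 42.0f);  // inplace output sees the write
  assert((*copied.buf)[0] == 1.0f);   // copied output does not
  return 0;
}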
*/ #pragma once #include + +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" @@ -282,5 +284,17 @@ class ScopedPoolingDescriptor { DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor); }; +inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { + bool use_cudnn = ctx.Attr("use_cudnn"); + use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (use_cudnn) { + auto& dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif + return use_cudnn; +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7da6e04d0a8b81bcb5fb6b105ebdd5b908cf8f1d..bb9fbd468f38fffc94107e321e777fc0e772fbe6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -33,9 +33,15 @@ DeviceContextPool::DeviceContextPool( PADDLE_ENFORCE_GT(places.size(), 0); for (size_t i = 0; i < places.size(); i++) { if (platform::is_cpu_place(places[i])) { +#ifdef PADDLE_WITH_MKLDNN + device_contexts_.emplace(places[i], + new platform::MKLDNNDeviceContext( + boost::get(places[i]))); +#else device_contexts_.emplace(places[i], new platform::CPUDeviceContext( boost::get(places[i]))); +#endif } else if (platform::is_gpu_place(places[i])) { #ifdef PADDLE_WITH_CUDA device_contexts_.emplace(places[i], @@ -121,6 +127,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { SetDeviceId(place_.device); + multi_process = GetCUDAMultiProcessors(place_.device); + max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); @@ -154,6 +162,10 @@ void CUDADeviceContext::Wait() const { PADDLE_ENFORCE(cudaGetLastError()); } +int CUDADeviceContext::GetMaxPhysicalThreadCount() const { + return multi_process * max_threads_per_mp; +} + Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } @@ -170,64 +182,38 @@ cudaStream_t CUDADeviceContext::stream() const { return stream_; } #ifdef PADDLE_WITH_MKLDNN MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) - : CPUDeviceContext(place), ready_(false) { - stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); - engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0)); + : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() { + p_blobs_.reset(new std::unordered_map>()); } -template -void MKLDNNDeviceContext::AddElement(const std::string& op_key, - const T& value) { - if (GetElement(op_key)) { - return; - } - GetElementPool().emplace(op_key, std::move(value)); -} +void MKLDNNDeviceContext::SetBlob(const std::string& name, + std::shared_ptr data) const { + std::unordered_map>* p; + p = p_blobs_.get(); -template -const T& MKLDNNDeviceContext::GetElement(const std::string& op_key) const { - auto it = GetElementPool().find(op_key); - return it == GetElementPool().end() ? 
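GetMaxPhysicalThreadCount is simply multi_process * max_threads_per_mp, the two attributes cached in the CUDADeviceContext constructor, and kernels such as the concat ones use it to bound their grids. Illustrative arithmetic for a hypothetical device; the real values come from the new gpu_info queries:

#include <cstdio>

int main() {
  // Hypothetical attribute values; the real ones come from
  // GetCUDAMultiProcessors / GetCUDAMaxThreadsPerMultiProcessor.
  int multi_process = 56;         // multiprocessor count
  int max_threads_per_mp = 2048;  // resident threads per multiprocessor
  int max_physical_threads = multi_process * max_threads_per_mp;  // 114688
  int k_threads_per_block = 1024;
  int max_blocks = max_physical_threads / k_threads_per_block;    // 112
  std::printf("max threads = %d, grid cap = %d blocks\n",
              max_physical_threads, max_blocks);
  return 0;
}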
nullptr : it->second; -} + auto it = p->find(name); -template <> -const std::unordered_map>& -MKLDNNDeviceContext::GetElementPool() const { - return memory_pool_; -} + if (it == p->end()) { + (*p)[name] = data; // create new blob + } else { + it->second = data; // set data to existing blob + } -template <> -const std::unordered_map>& -MKLDNNDeviceContext::GetElementPool() const { - return primitive_pool_; + return; } -template <> -const std::unordered_map>& -MKLDNNDeviceContext::GetElementPool() const { - return primitive_desc_pool_; -} +std::shared_ptr MKLDNNDeviceContext::GetBlob( + const std::string& name) const { + std::unordered_map>* p; + p = p_blobs_.get(); -void MKLDNNDeviceContext::Execute(bool block) { - if (pipeline_.empty()) { - return; - } - ResetStream(); - stream_->submit(pipeline_).wait(block); - ready_ = false; - pipeline_.clear(); -} + auto it = p->find(name); -void MKLDNNDeviceContext::ResetStream() { - if (ready_) { - return; + if (it != p->end()) { + return it->second; } - // TODO(TJ): change me when mkldnn have specific method to reset this state - stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); - ready_ = true; + + return nullptr; } #endif diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index a294ba5101528c9ac0007bdcfc5255a0c2674aad..e779644190de1246cd650fbf91eeaeb03494643f 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -22,7 +22,7 @@ limitations under the License. */ #endif #ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" +#include #endif #include "paddle/fluid/platform/enforce.h" @@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return place in the device context. */ Place GetPlace() const override; + /*! \brief Return the max physical thread count in the device context */ + int GetMaxPhysicalThreadCount() const; + /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; @@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream_; cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; + + int multi_process; + int max_threads_per_mp; }; template <> @@ -114,46 +120,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext { public: explicit MKLDNNDeviceContext(CPUPlace place); - /* \brief Add new element: memory, primitive or primitive desc */ - template - void AddElement(const std::string& op_key, const T& value); - - /* \brief Get existed element: memory, primitive or primitive desc */ - template - const T& GetElement(const std::string& op_key) const; - - /* \brief Get element pool: memory, primitive or primitive desc pool */ - template - const std::unordered_map>& - GetElementPool() const; - /* \brief Get the active engine */ - const MKLDNNEngine& engine() const { return *engine_; } - - /* \brief Submit primitive to pipeline */ - void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); } + const mkldnn::engine& GetEngine() const { return engine_; } - /*! \brief Execute all submitted primitives in pipeline */ - void Execute(bool block = true); + // Set data to blob (i.e. name/data pair). Create blob if not existing + void SetBlob(const std::string& name, std::shared_ptr data) const; - protected: - /*! \brief Reset the stream to prepare next exectue */ - void ResetStream(); + // Find a saved blob. 
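SetBlob and GetBlob replace the three typed pools with one name to shared_ptr<void> map: SetBlob creates or overwrites an entry, GetBlob returns nullptr on a miss, and callers such as the conv kernel cast the pointer back themselves. A minimal standalone version of that contract:

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

class BlobMap {
 public:
  void SetBlob(const std::string& name, std::shared_ptr<void> data) {
    blobs_[name] = data;  // creates the entry or overwrites the existing one
  }
  std::shared_ptr<void> GetBlob(const std::string& name) const {
    auto it = blobs_.find(name);
    if (it == blobs_.end()) return nullptr;  // miss
    return it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

int main() {
  BlobMap ctx;
  assert(ctx.GetBlob("conv_pd") == nullptr);  // miss -> nullptr
  auto pd = std::make_shared<int>(7);         // stands in for a primitive_desc
  ctx.SetBlob("conv_pd", pd);
  auto found = std::static_pointer_cast<int>(ctx.GetBlob("conv_pd"));
  assert(found && *found == 7);               // same object, no copy
  return 0;
}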
Return nullptr if not found + std::shared_ptr GetBlob(const std::string& name) const; private: - std::unordered_map> - memory_pool_; - std::unordered_map> - primitive_pool_; - std::unordered_map> - primitive_desc_pool_; - std::vector pipeline_; - MKLDNNStreamPtr stream_; - MKLDNNEnginePtr engine_; - bool ready_; + mkldnn::engine engine_; + std::shared_ptr>> + p_blobs_; }; #endif diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 05e1eae853e20b3fd86438c03f52628179a311ca..da4041bad0d82fe1c8c7a12fd0c7177e6dbddef3 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -33,6 +33,26 @@ int GetCUDADeviceCount() { return count; } +int GetCUDAMultiProcessors(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + int count; + PADDLE_ENFORCE( + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id), + "cudaDeviceGetAttribute failed in " + "paddle::platform::GetCUDAMultiProcessors"); + return count; +} + +int GetCUDAMaxThreadsPerMultiProcessor(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + int count; + PADDLE_ENFORCE(cudaDeviceGetAttribute( + &count, cudaDevAttrMaxThreadsPerMultiProcessor, id), + "cudaDeviceGetAttribute failed in " + "paddle::platform::GetCUDAMaxThreadsPerMultiProcessor"); + return count; +} + int GetCurrentDeviceId() { int device_id; PADDLE_ENFORCE( diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 3d4883d8078daa2b55d8ea792b47e93e4f4feec8..c38ccf0f2ade1d2405177b541b33fd84283726ff 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse = //! Get the total number of GPU devices in system. int GetCUDADeviceCount(); +//! Get the MultiProcessors of the ith GPU. +int GetCUDAMultiProcessors(int i); + +//! Get the MaxThreads of each MultiProcessor of the ith GPU. +int GetCUDAMaxThreadsPerMultiProcessor(int i); + //! Get the current GPU device id in system. int GetCurrentDeviceId(); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 6d71f352c6eda12e2bc032ecbc26cfefe320f703..90b78142b845e7e12c0c7dfb391f6aa3bd848436 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -16,12 +16,15 @@ limitations under the License. 
*/ #include +#include "paddle/fluid/framework/operator.h" + namespace paddle { namespace platform { using MKLDNNStream = mkldnn::stream; using MKLDNNEngine = mkldnn::engine; using MKLDNNMemory = mkldnn::memory; +using MKLDNNMemoryDescriptor = mkldnn::memory::desc; using MKLDNNPrimitive = mkldnn::primitive; using MKLDNNPrimitiveDesc = mkldnn::handle; @@ -31,5 +34,17 @@ typedef std::unique_ptr MKLDNNMemoryPtr; typedef std::unique_ptr MKLDNNPrimitivePtr; typedef std::unique_ptr MKLDNNPrimitiveDescPtr; +inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector& dims, + mkldnn::memory::data_type data_type, + mkldnn::memory::format format) { + mkldnn::memory::dims tz = dims; + return mkldnn::memory::desc({tz}, data_type, format); +} + +inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { + bool use_mkldnn = ctx.Attr("use_mkldnn"); + return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 094f9224f719f6656fd0c44aa4a620730432ccac..28ef3e04b1c50e0d42eeb27608259c6449429da5 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -178,7 +178,7 @@ void EnableProfiler(ProfilerState state) { } #ifdef PADDLE_WITH_CUDA if (g_state == ProfilerState::kCUDA) { - // Generate some dummy evenets first to reduce the startup overhead. + // Generate some dummy events first to reduce the startup overhead. for (int i = 0; i < 5; i++) { ForEachDevice([](int d) { DeviceContext* dev_ctx = new CUDADeviceContext(CUDAPlace(d)); diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 06319fc638984f8f8ed897c362f516e1534bc8db..6be2bd8fad9e33cf4e1dcafdd6b8f39111bdbe88 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -213,7 +213,7 @@ function gen_fluid_inference_lib() { if [ ${WITH_C_API:-OFF} == "OFF" ] ; then cat <