diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index a25cff5fc567f22d4573625487f31bd4192bb172..5759e5c489724332793bf103b7aacf7ffb068611 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -36,7 +36,8 @@ MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") -INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) +INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers. +INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include mkldnn.h IF(${CBLAS_PROVIDER} STREQUAL "MKLML") SET(MKLDNN_DEPENDS ${MKLML_PROJECT}) diff --git a/doc/fluid/design/muti_devices/kernel_hint_design.md b/doc/fluid/design/muti_devices/kernel_hint_design.md index 728c8f0b964c02c1efa019945f7427fa879d3aa1..58e44b64169d8c942174de86986403570b271641 100644 --- a/doc/fluid/design/muti_devices/kernel_hint_design.md +++ b/doc/fluid/design/muti_devices/kernel_hint_design.md @@ -1,4 +1,6 @@ -# Problem +# Kernel Hint Design + +## Problem In PaddlePaddle's [Design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md), one Operator may have multiple kernels. Users may have some personal preference to choose a certain type of kernel for an operator, such as `force_cpu` to choose a CPU kernel, `use_cudnn` to choose a CUDNN kernel, we need to provide a way for users to do this. In the current design, we use KernelType to describe one kernel. diff --git a/doc/fluid/design/muti_devices/kernel_selection.md b/doc/fluid/design/muti_devices/kernel_selection.md index 39ea2b00090a864f95610d6d2846ca5e5c904e78..967317d5d2eeb818ab14faabca342cc8c4ed717e 100644 --- a/doc/fluid/design/muti_devices/kernel_selection.md +++ b/doc/fluid/design/muti_devices/kernel_selection.md @@ -1,4 +1,6 @@ -# Background +# Kernel Selection + +## Background Every operator has many kernels because there are multiple data types, places, data layout, library type that Fluid supports. We use the `OpKernelType ` to describe kernel types that operators can hold. The `OpKernelType ` is as follows: diff --git a/doc/v2/build_and_install/index_en.rst b/doc/v2/build_and_install/index_en.rst index 7e0ca5bcbdbad0a3c97c0045bb57b51137668161..5b3de0f8c3e5496060646b5ddb080d0d338a8bfa 100644 --- a/doc/v2/build_and_install/index_en.rst +++ b/doc/v2/build_and_install/index_en.rst @@ -1,32 +1,56 @@ -Install and Build -================= +install and Compile +========== .. _install_steps: -Install Steps -++++++++ +PaddlePaddle provides various methods of installation for many different users -You can choose either pip or Docker to complete your install: +Focus on Deep Learning Model Development +----------------- + +PaddlePaddle provides lots of packages of python wheel , that pip can install: .. toctree:: - :maxdepth: 1 + :maxdepth: 1 - pip_install_en.rst - docker_install_en.rst + pip_install_en.rst -Build from Source ------------------ +This is the most convenient way of installation. Please choose the right installation package with machine configure and system. + +Follow the Bottom Frame +---------- + +PaddlePaddle also supports installation using Docker. Please refer to the tutorial below: + +.. toctree:: + :maxdepth: 1 + + docker_install_en.rst -.. warning:: +We recommend running PaddlePaddle in Docker. This method has the following advantages: - We recommend to directly install via above installation steps, you'll only need to build PaddlePaddle from source when you need a modifed binary. +- Does not require installation of third-party dependencies. +- Easy to share runtime environment. -.. toctree:: +Lastly, users can also compile and install PaddlePaddle from source code. The instructions are below: + +.. toctree:: :maxdepth: 1 - build_from_source_en.md + build_from_source_en.rst + +.. warning:: + + One caveat with this approach is that developers will have to download, compile and install all third-party dependencies. Thus this process of installation is more time consuming. + FAQ -++++++++++ +----------- + +For any problems during installation, please refer to the page below for answers: + +:ref:`常见问题解答 ` + +If the problem still persists, you are welcome to seek assistance from the PaddlePaddle community: -`FAQ `_ +`创建issue `_ diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/paddle/fluid/framework/data_device_transform_test.cu index e896a06162527ed0289767901f4b4a33fcd2875f..a66525303da58601f85c40c41854edaf22c3d4ea 100644 --- a/paddle/fluid/framework/data_device_transform_test.cu +++ b/paddle/fluid/framework/data_device_transform_test.cu @@ -105,7 +105,7 @@ static void BuildVar(const std::string& param_name, TEST(Operator, CPUtoGPU) { using namespace paddle::framework; using namespace paddle::platform; - InitDevices(); + InitDevices(true); paddle::framework::Scope scope; paddle::platform::CPUPlace cpu_place; diff --git a/paddle/fluid/framework/init.cc b/paddle/fluid/framework/init.cc index 3c0d93642ac41e8d90f9a248e81cea7a4fe12293..75c557fa4243f4bd984314fac298e9335108e7a9 100644 --- a/paddle/fluid/framework/init.cc +++ b/paddle/fluid/framework/init.cc @@ -64,7 +64,7 @@ void InitP2P(int count) { #endif } -void InitDevices() { +void InitDevices(bool init_p2p) { /*Init all avaiable devices by default */ std::vector places; @@ -85,7 +85,9 @@ void InitDevices() { for (int i = 0; i < count; ++i) { places.emplace_back(platform::CUDAPlace(i)); } - InitP2P(count); + if (init_p2p) { + InitP2P(count); + } platform::DeviceContextPool::Init(places); } diff --git a/paddle/fluid/framework/init.h b/paddle/fluid/framework/init.h index 7d86d1581190780f513776c69b18ad41eb2ce14d..fae98a60b5111465375404609905980177f613b1 100644 --- a/paddle/fluid/framework/init.h +++ b/paddle/fluid/framework/init.h @@ -24,7 +24,7 @@ void InitGflags(std::vector &argv); void InitGLOG(const std::string &prog_name); -void InitDevices(); +void InitDevices(bool init_p2p); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/init_test.cc b/paddle/fluid/framework/init_test.cc index 2a03f0afe657e4b3ac173e8718dd6f6f81ee5e6a..928e2d14abea604cf483f4bc1e1c58fbae04dd21 100644 --- a/paddle/fluid/framework/init_test.cc +++ b/paddle/fluid/framework/init_test.cc @@ -21,7 +21,7 @@ TEST(InitDevices, CPU) { using paddle::platform::DeviceContextPool; #ifndef PADDLE_WITH_CUDA - InitDevices(); + InitDevices(true); DeviceContextPool& pool = DeviceContextPool::Instance(); ASSERT_EQ(pool.size(), 1U); #endif @@ -33,7 +33,7 @@ TEST(InitDevices, CUDA) { #ifdef PADDLE_WITH_CUDA int count = paddle::platform::GetCUDADeviceCount(); - InitDevices(); + InitDevices(true); DeviceContextPool& pool = DeviceContextPool::Instance(); ASSERT_EQ(pool.size(), 1U + static_cast(count)); #endif diff --git a/paddle/fluid/framework/lod_tensor_test.cu b/paddle/fluid/framework/lod_tensor_test.cu index be65da5ba230e4bb15b09a07431d3107ffe19522..e3efbe4c464493af87e33510647d8c67d457a76d 100644 --- a/paddle/fluid/framework/lod_tensor_test.cu +++ b/paddle/fluid/framework/lod_tensor_test.cu @@ -30,7 +30,7 @@ __global__ void test(size_t* a, int size) { } TEST(LoD, data) { - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); paddle::framework::LoD lod{{0, 1, 2}}; lod.push_back({0, 2, 4, 5}); @@ -46,7 +46,7 @@ TEST(LoD, data) { } TEST(LoDTensor, LoDInGPU) { - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); paddle::framework::LoDTensor lod_tensor; paddle::platform::CUDAPlace place(0); diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 44ca4d7ca564515ae267c5949d29feaf22790251..25f622b725277ac9bcca4622902162f3edf147e8 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -72,7 +72,7 @@ REGISTER_OP_WITHOUT_GRADIENT(test_operator, paddle::framework::OpWithoutKernelCheckerMaker); TEST(OperatorBase, all) { - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); paddle::framework::proto::OpDesc op_desc; op_desc.set_type("test_operator"); BuildVar("input", {"IN1"}, op_desc.add_inputs()); @@ -198,7 +198,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, // test with single input TEST(OpKernel, all) { - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); paddle::framework::proto::OpDesc op_desc; op_desc.set_type("op_with_kernel"); BuildVar("x", {"IN1"}, op_desc.add_inputs()); @@ -228,7 +228,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, TEST(OpKernel, multi_inputs) { using namespace paddle::framework; - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); proto::OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); @@ -269,7 +269,7 @@ class OperatorClone : public paddle::framework::OperatorBase { }; TEST(Operator, Clone) { - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); OperatorClone a("ABC", paddle::framework::VariableNameMap{}, paddle::framework::VariableNameMap{}, paddle::framework::AttributeMap{}); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 049731c7216e542dedcf8754eef79f0a672291d6..77d17fbbccca0292e21acd5e8fa90448527b95c0 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -85,9 +85,9 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { } const std::vector ProgramDesc::GetFeedTargetNames() { - BlockDesc *global_block = blocks_[0].get(); + auto &global_block = Block(0); std::vector feed_target_names; - for (auto *op : global_block->AllOps()) { + for (auto *op : global_block.AllOps()) { if (op->Type() == kFeedOpType) { feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]); } @@ -96,9 +96,9 @@ const std::vector ProgramDesc::GetFeedTargetNames() { } const std::vector ProgramDesc::GetFetchTargetNames() { - BlockDesc *global_block = blocks_[0].get(); + auto &global_block = Block(0); std::vector fetch_target_names; - for (auto *op : global_block->AllOps()) { + for (auto *op : global_block.AllOps()) { if (op->Type() == kFetchOpType) { fetch_target_names.push_back(op->Input("X")[0]); } @@ -106,5 +106,43 @@ const std::vector ProgramDesc::GetFetchTargetNames() { return fetch_target_names; } +void ProgramDesc::SetFeedHolderName(const std::string &feed_holder_name) { + auto *global_block = MutableBlock(0); + int index = 0; + for (auto *op : global_block->AllOps()) { + if (op->Type() == kFeedOpType) { + // Unify the input's name of all feed_ops to feed_holder_name + global_block->RemoveVar(op->Input("X")[0]); + op->SetInput("X", {feed_holder_name}); + op->SetAttr("col", {index}); + op->CheckAttrs(); + index++; + } + } + + auto *feed_holder = global_block->Var(feed_holder_name); + feed_holder->SetType(proto::VarType::FEED_MINIBATCH); + feed_holder->SetPersistable(true); +} + +void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) { + auto *global_block = MutableBlock(0); + int index = 0; + for (auto *op : global_block->AllOps()) { + if (op->Type() == kFetchOpType) { + // Unify the output's name of all fetch_ops to fetch_holder_name + global_block->RemoveVar(op->Output("Out")[0]); + op->SetOutput("Out", {fetch_holder_name}); + op->SetAttr("col", {index}); + op->CheckAttrs(); + index++; + } + } + + auto *fetch_holder = global_block->Var(fetch_holder_name); + fetch_holder->SetType(proto::VarType::FETCH_LIST); + fetch_holder->SetPersistable(true); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h index 538a0372116e6f90fd2fae5f00097b8efc5dcb5c..4288081be72c44c0fc3584b50c41a270eac9e204 100644 --- a/paddle/fluid/framework/program_desc.h +++ b/paddle/fluid/framework/program_desc.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/framework.pb.h" @@ -52,9 +53,26 @@ class ProgramDesc { proto::ProgramDesc *Proto(); + // The output variable of feed_op is referenced as feed_target. + // This function is used to collect the output variable's name of all + // feed_ops. const std::vector GetFeedTargetNames(); + + // The input variable of fetch_op is referenced as fetch_target. + // This function is used to collect the input variable's name of all + // fetch_ops. const std::vector GetFetchTargetNames(); + // The input variable of feed_op that holds input Tensor provided by users is + // referenced as feed_holder. + // This function is used to change or unify the feed_holder variables' name. + void SetFeedHolderName(const std::string &feed_holder_name); + + // The output variable of fetch_op that holds output Tensor needed by users is + // referenced as fetch_holder. + // This function is used to change or unify the fetch_holder variables' name. + void SetFetchHolderName(const std::string &fetch_holder_name); + private: proto::ProgramDesc desc_; diff --git a/paddle/fluid/inference/tests/book/test_inference_fit_a_line.cc b/paddle/fluid/inference/tests/book/test_inference_fit_a_line.cc index 3e77dc166c355bc141563eda4705ca8d75226ac4..2c5b66a32903f4ffdedb074b31aec53ae6cacaf3 100644 --- a/paddle/fluid/inference/tests/book/test_inference_fit_a_line.cc +++ b/paddle/fluid/inference/tests/book/test_inference_fit_a_line.cc @@ -12,6 +12,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" +#include "paddle/fluid/inference/tests/test_multi_thread_helper.h" DEFINE_string(dirname, "", "Directory of the inference model."); @@ -26,32 +27,63 @@ TEST(inference, fit_a_line) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - paddle::framework::LoDTensor input; - // The second dim of the input tensor should be 13 - // The input data should be >= 0 - int64_t batch_size = 10; - SetupTensor(&input, {batch_size, 13}, static_cast(0), - static_cast(10)); - std::vector cpu_feeds; - cpu_feeds.push_back(&input); + for (int num_threads : {1, 2}) { + std::vector> cpu_feeds; + cpu_feeds.resize(num_threads); + for (int i = 0; i < num_threads; ++i) { + auto* input = new paddle::framework::LoDTensor(); + // The second dim of the input tensor should be 13 + // The input data should be >= 0 + int64_t batch_size = 10; + SetupTensor(input, {batch_size, 13}, static_cast(0), + static_cast(10)); + cpu_feeds[i].push_back(input); + } - paddle::framework::LoDTensor output1; - std::vector cpu_fetchs1; - cpu_fetchs1.push_back(&output1); + std::vector> cpu_fetchs1; + cpu_fetchs1.resize(num_threads); + for (int i = 0; i < num_threads; ++i) { + auto* output = new paddle::framework::LoDTensor(); + cpu_fetchs1[i].push_back(output); + } - // Run inference on CPU - TestInference(dirname, cpu_feeds, cpu_fetchs1); - LOG(INFO) << output1.dims(); + // Run inference on CPU + LOG(INFO) << "--- CPU Runs (num_threads: " << num_threads << "): ---"; + if (num_threads == 1) { + TestInference(dirname, cpu_feeds[0], + cpu_fetchs1[0]); + } else { + TestMultiThreadInference( + dirname, cpu_feeds, cpu_fetchs1, num_threads); + } #ifdef PADDLE_WITH_CUDA - paddle::framework::LoDTensor output2; - std::vector cpu_fetchs2; - cpu_fetchs2.push_back(&output2); + std::vector> cpu_fetchs2; + cpu_fetchs2.resize(num_threads); + for (int i = 0; i < num_threads; ++i) { + auto* output = new paddle::framework::LoDTensor(); + cpu_fetchs2[i].push_back(output); + } - // Run inference on CUDA GPU - TestInference(dirname, cpu_feeds, cpu_fetchs2); - LOG(INFO) << output2.dims(); + // Run inference on CUDA GPU + LOG(INFO) << "--- GPU Runs (num_threads: " << num_threads << "): ---"; + if (num_threads == 1) { + TestInference(dirname, cpu_feeds[0], + cpu_fetchs2[0]); + } else { + TestMultiThreadInference( + dirname, cpu_feeds, cpu_fetchs2, num_threads); + } - CheckError(output1, output2); + for (int i = 0; i < num_threads; ++i) { + CheckError(*cpu_fetchs1[i][0], *cpu_fetchs2[i][0]); + delete cpu_fetchs2[i][0]; + } #endif + + for (int i = 0; i < num_threads; ++i) { + delete cpu_feeds[i][0]; + delete cpu_fetchs1[i][0]; + } + } // num_threads-loop } diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index aae34ceda07fea6e881cf61b3755ec45d1d6f2dc..064e400f0c750872ab2142c5fc8e28dd3da85b1a 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -25,7 +25,8 @@ limitations under the License. */ template void SetupTensor(paddle::framework::LoDTensor* input, paddle::framework::DDim dims, T lower, T upper) { - std::mt19937 rng(100); // An arbitrarily chosen but fixed seed. + static unsigned int seed = 100; + std::mt19937 rng(seed++); std::uniform_real_distribution uniform_dist(0, 1); T* input_ptr = input->mutable_data(dims, paddle::platform::CPUPlace()); diff --git a/paddle/fluid/inference/tests/test_multi_thread_helper.h b/paddle/fluid/inference/tests/test_multi_thread_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..56745f115db231d4350da72b7de7967175ac9fe8 --- /dev/null +++ b/paddle/fluid/inference/tests/test_multi_thread_helper.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/inference/io.h" + +void ThreadedRunInference( + const std::unique_ptr& inference_program, + paddle::framework::Executor* executor, paddle::framework::Scope* scope, + const int thread_id, + const std::vector& cpu_feeds, + const std::vector& cpu_fetchs) { + auto copy_program = std::unique_ptr( + new paddle::framework::ProgramDesc(*inference_program)); + + std::string feed_holder_name = "feed_" + paddle::string::to_string(thread_id); + std::string fetch_holder_name = + "fetch_" + paddle::string::to_string(thread_id); + copy_program->SetFeedHolderName(feed_holder_name); + copy_program->SetFetchHolderName(fetch_holder_name); + + // 3. Get the feed_target_names and fetch_target_names + const std::vector& feed_target_names = + copy_program->GetFeedTargetNames(); + const std::vector& fetch_target_names = + copy_program->GetFetchTargetNames(); + + // 4. Prepare inputs: set up maps for feed targets + std::map feed_targets; + for (size_t i = 0; i < feed_target_names.size(); ++i) { + // Please make sure that cpu_feeds[i] is right for feed_target_names[i] + feed_targets[feed_target_names[i]] = cpu_feeds[i]; + } + + // 5. Define Tensor to get the outputs: set up maps for fetch targets + std::map fetch_targets; + for (size_t i = 0; i < fetch_target_names.size(); ++i) { + fetch_targets[fetch_target_names[i]] = cpu_fetchs[i]; + } + + // 6. Run the inference program + executor->Run(*copy_program, scope, feed_targets, fetch_targets, true, + feed_holder_name, fetch_holder_name); +} + +template +void TestMultiThreadInference( + const std::string& dirname, + const std::vector>& cpu_feeds, + const std::vector>& cpu_fetchs, + const int num_threads) { + // 1. Define place, executor, scope + auto place = Place(); + auto executor = paddle::framework::Executor(place); + auto* scope = new paddle::framework::Scope(); + + // 2. Initialize the inference_program and load parameters + std::unique_ptr inference_program = + paddle::inference::Load(executor, *scope, dirname); + + std::vector threads; + for (int i = 0; i < num_threads; ++i) { + threads.push_back(new std::thread( + ThreadedRunInference, std::ref(inference_program), &executor, scope, i, + std::ref(cpu_feeds[i]), std::ref(cpu_fetchs[i]))); + } + for (int i = 0; i < num_threads; ++i) { + threads[i]->join(); + delete threads[i]; + } + + delete scope; +} diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index a6d9ce0f041b859ecf6b3de902a9d1f132a4c76e..b261144f3d7836801e0b7a45a1478d3b801db86d 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -662,14 +662,3 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad, ops::grad_functor>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); - -REGISTER_OP_CPU_KERNEL(relu, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CPU_KERNEL( - relu_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>); diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 7709a551dc155e1f3cd2a19a689999608f497beb..4f745553c14fc1391bc65d4f7e4f9bd3b5a881c2 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,31 +14,19 @@ limitations under the License. */ #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; - -#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \ - REGISTER_OP_CUDA_KERNEL( \ - act_type, ops::ActivationKernel>, \ - ops::ActivationKernel>); \ - REGISTER_OP_CUDA_KERNEL( \ - act_type##_grad, \ - ops::ActivationGradKernel>, \ - ops::ActivationGradKernel>, \ + ops::ActivationKernel>, \ + ops::ActivationKernel>); \ + REGISTER_OP_CUDA_KERNEL( \ + act_type##_grad, ops::ActivationGradKernel>, \ + ops::ActivationGradKernel>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); - -REGISTER_OP_CUDA_KERNEL( - relu, ops::ActivationKernel>, - ops::ActivationKernel>, - ops::ActivationKernel>); -REGISTER_OP_CUDA_KERNEL( - relu_grad, ops::ActivationGradKernel>, - ops::ActivationGradKernel>); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index c4efbcd3f977ee285e13223d7e0ca420aec63b98..43856780bf9357281ac4af2968950da15426e5c8 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,9 +12,11 @@ limitations under the License. */ #pragma once #include #include + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/platform/float16.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -338,11 +337,25 @@ struct Sine { HOSTDEVICE T operator()(const T& val) const { return sin(val); } }; +template <> +struct Sine { + HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { + return platform::float16(sin(static_cast(val))); + } +}; + template struct Cosine { HOSTDEVICE T operator()(const T& val) const { return cos(val); } }; +template <> +struct Cosine { + HOSTDEVICE platform::float16 operator()(const platform::float16& val) const { + return platform::float16(cos(static_cast(val))); + } +}; + // cosine'(x) = -sin(x) template struct CosGradFunctor : public BaseActivationFunctor { @@ -826,6 +839,7 @@ struct SwishGradFunctor : public BaseActivationFunctor { __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, ExpFunctor, ExpGradFunctor); \ + __macro(relu, ReluFunctor, ReluGradFunctor); \ __macro(tanh, TanhFunctor, TanhGradFunctor); \ __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index feb4f367008d76d86a93c561a8eec1f2485c99d6..f03165fae5ca16c5c263ce0683af7ec56e6a3766 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -8,10 +8,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "paddle/fluid/platform/device_context.h" + +#include #include +#include + #include "paddle/fluid/memory/memory.h" + namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 6b796d92d09cdde2db60c7651c03d3782ff013e6..b17558337914e0ca8fdba283edf4024d94e85f0f 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -8,11 +8,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once #include +#include #include +#include #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/dynload/cublas.h" diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 9d8d07362ce3a0d0c2a009c9844db0a3bdaf01cb..fa806aba6d8747beebc3eed2c661b326dd62fd76 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#include "gtest/gtest.h" #include "paddle/fluid/platform/device_context.h" +#include + #include "glog/logging.h" +#include "gtest/gtest.h" TEST(Device, Init) { using paddle::platform::DeviceContext; diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 3b4437f576e1c2e931a86ec6d5e823ec1f344c52..c9e10631680a6ea3876f555a3a6e6c12f79b39d5 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -11,15 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "paddle/fluid/platform/device_tracer.h" -#include + +#include #include #include -#include +#include // NOLINT #include -#include +#include +#include // NOLINT +#include + #include "glog/logging.h" +#include "google/protobuf/text_format.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/string/printf.h" @@ -123,7 +127,7 @@ void DisableActivity() { void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size, size_t *maxNumRecords) { - uint8_t *buf = (uint8_t *)malloc(kBufSize + kAlignSize); + uint8_t *buf = reinterpret_cast(malloc(kBufSize + kAlignSize)); *size = kBufSize; *buffer = ALIGN_BUFFER(buf, kAlignSize); *maxNumRecords = 0; diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index deb3d23f786353b8e7a2f28d094e364158885a34..0375c7439c29d4122e8ff6b58734dad4f504b7a2 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -11,8 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once + +#include + #include "paddle/fluid/platform/dynload/cupti.h" #include "paddle/fluid/platform/profiler.pb.h" diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index e77f768bf9f437a289b16d2ec9597c570b0a9ad2..673e1bcae4af6d039bc969f1de6e4bcab3748cb5 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -1003,6 +1003,46 @@ HOSTDEVICE inline float16 exp(const float16& a) { return float16(::expf(static_cast(a))); } +template <> +HOSTDEVICE inline float16 log(const float16& a) { + return float16(::logf(static_cast(a))); +} + +template <> +HOSTDEVICE inline float16 tanh(const float16& a) { + return float16(::tanhf(static_cast(a))); +} + +template <> +HOSTDEVICE inline float16 sqrt(const float16& a) { + return float16(::sqrtf(static_cast(a))); +} + +template <> +HOSTDEVICE inline float16 ceil(const float16& a) { + return float16(::ceilf(static_cast(a))); +} + +template <> +HOSTDEVICE inline float16 floor(const float16& a) { + return float16(::floorf(static_cast(a))); +} + +template <> +HOSTDEVICE inline float16 round(const float16& a) { + return float16(::roundf(static_cast(a))); +} + +template <> +HOSTDEVICE inline float16 pow(const float16& a, const float16& b) { + return float16(::powf(static_cast(a), static_cast(b))); +} + +template <> +HOSTDEVICE inline float16 abs(const float16& a) { + return float16(::fabs(static_cast(a))); +} + } // namespace numext } // namespace Eigen diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 90b78142b845e7e12c0c7dfb391f6aa3bd848436..de8056237fb022f62488e0fedf9a4f67e4601072 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once -#include +#include +#include "mkldnn/include/mkldnn.hpp" #include "paddle/fluid/framework/operator.h" namespace paddle { diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index b25206ff35cc87dcdd363bc0de54530f629d73ed..412cdda286c3a77af002fdc5eb6a5ae440606b82 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -15,8 +15,11 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #include #include +#include #include #include +#include // NOLINT +#include #ifdef PADDLE_WITH_CUDA #include #endif // PADDLE_WITH_CUDA @@ -28,10 +31,10 @@ limitations under the License. */ namespace paddle { namespace platform { +struct EventList; + // The profiler state, the initial value is ProfilerState::kDisabled static ProfilerState g_state = ProfilerState::kDisabled; -// To record which timer the profiler used, CUDA or CPU. -static std::string g_profiler_place = ""; // The thread local event list only can be accessed by the specific thread // The thread index of each thread static thread_local int32_t g_thread_id; @@ -45,6 +48,39 @@ static std::list> g_all_event_lists; // The thread local event list only can be accessed by the specific thread static thread_local std::shared_ptr g_event_list; +struct EventList { + constexpr static size_t kMB = 1024 * 1024; + constexpr static size_t kEventBlockSize = 16 * kMB; + constexpr static size_t kEventSize = sizeof(Event); + constexpr static size_t kEventAlign = alignof(Event); + constexpr static size_t kNumBlock = + kEventBlockSize / + ((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign); + + template + void Record(Args&&... args) { + if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) { + event_blocks.emplace_front(); + event_blocks.front().reserve(kNumBlock); + } + event_blocks.front().emplace_back(std::forward(args)...); + } + + std::vector Reduce() { + std::vector result; + for (auto& block : event_blocks) { + result.insert(result.begin(), std::make_move_iterator(block.begin()), + std::make_move_iterator(block.end())); + } + event_blocks.clear(); + return result; + } + + void Clear() { event_blocks.clear(); } + + std::forward_list> event_blocks; +}; + inline uint64_t GetTimeInNsec() { using clock = std::conditional(tv.tv_sec) * 1000000 + tv.tv_usec); } -Event::Event(EventKind kind, std::string name, uint32_t thread_id, +Event::Event(EventType type, std::string name, uint32_t thread_id, const DeviceContext* dev_ctx) - : kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) { + : type_(type), name_(name), thread_id_(thread_id), has_cuda_(false) { #ifdef PADDLE_WITH_CUDA has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false; if (has_cuda_) { @@ -76,17 +112,7 @@ Event::Event(EventKind kind, std::string name, uint32_t thread_id, cpu_ns_ = GetTimeInNsec(); } -std::string Event::kind() const { - switch (kind_) { - case EventKind::kMark: - return "mark"; - case EventKind::kPushRange: - return "push"; - case EventKind::kPopRange: - return "pop"; - } - PADDLE_THROW("Unknown EventKind."); -} +const EventType& Event::type() const { return type_; } double Event::CpuElapsedMs(const Event& e) const { return (e.cpu_ns_ - cpu_ns_) / (1000000.0); @@ -129,15 +155,15 @@ inline EventList& GetEventList() { } void Mark(const std::string& name, const DeviceContext* dev_ctx) { - GetEventList().Record(EventKind::kMark, name, g_thread_id, dev_ctx); + GetEventList().Record(EventType::kMark, name, g_thread_id, dev_ctx); } void PushEvent(const std::string& name, const DeviceContext* dev_ctx) { - GetEventList().Record(EventKind::kPushRange, name, g_thread_id, dev_ctx); + GetEventList().Record(EventType::kPushRange, name, g_thread_id, dev_ctx); } void PopEvent(const std::string& name, const DeviceContext* dev_ctx) { - GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx); + GetEventList().Record(EventType::kPopRange, name, g_thread_id, dev_ctx); } RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx) @@ -197,12 +223,7 @@ void EnableProfiler(ProfilerState state) { "The profiling state should be disabled when calling ", "EnableProfiler."); g_state = state; - if (g_state == ProfilerState::kCUDA) { - g_profiler_place = "CUDA"; - } else if (g_state == ProfilerState::kCPU) { - g_profiler_place = "CPU"; - } else { - g_profiler_place = "All"; + if (g_state == ProfilerState::kAll) { GetDeviceTracer()->Enable(); } #ifdef PADDLE_WITH_CUDA @@ -240,27 +261,63 @@ std::vector> GetAllEvents() { return result; } -void DisableProfiler(EventSortingKey sorted_key, - const std::string& profile_path) { - PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, - "Can't disable profiling, since it's not starting."); - // Mark the profiling stop. - Mark("_stop_profiler_", nullptr); - g_state = ProfilerState::kDisabled; +// The information of each event given in the profiling report +struct EventItem { + std::string name; + int calls; + double total_time; + double min_time; + double max_time; + double ave_time; +}; + +// Print results +void PrintProfiler(const std::vector>& events_table, + const std::string& sorted_domain, const size_t name_width, + const size_t data_width) { + // Output header information + std::cout << "\n------------------------->" + << " Profiling Report " + << "<-------------------------\n\n"; + std::string place; + if (g_state == ProfilerState::kCPU) { + place = "CPU"; + } else if (g_state == ProfilerState::kCUDA) { + place = "CUDA"; + } else if (g_state == ProfilerState::kAll) { + place = "All"; + } else { + PADDLE_THROW("Invalid profiler state"); + } - std::vector> all_events = GetAllEvents(); - ParseEvents(all_events, sorted_key); - ResetProfiler(); - DeviceTracer* tracer = GetDeviceTracer(); - if (g_profiler_place == "All" && tracer && tracer->IsEnabled()) { - tracer->Disable(); - tracer->GenProfile(profile_path); + std::cout << "Place: " << place << std::endl; + std::cout << "Time unit: ms" << std::endl; + std::cout << "Sorted by " << sorted_domain + << " in descending order in the same thread\n\n"; + // Output events table + std::cout.setf(std::ios::left); + std::cout << std::setw(name_width) << "Event" << std::setw(data_width) + << "Calls" << std::setw(data_width) << "Total" + << std::setw(data_width) << "Min." << std::setw(data_width) + << "Max." << std::setw(data_width) << "Ave." << std::endl; + for (size_t i = 0; i < events_table.size(); ++i) { + for (size_t j = 0; j < events_table[i].size(); ++j) { + const EventItem& event_item = events_table[i][j]; + std::cout << std::setw(name_width) << event_item.name + << std::setw(data_width) << event_item.calls + << std::setw(data_width) << event_item.total_time + << std::setw(data_width) << event_item.min_time + << std::setw(data_width) << event_item.max_time + << std::setw(data_width) << event_item.ave_time << std::endl; + } } + std::cout << std::endl; } -void ParseEvents(std::vector>& events, - EventSortingKey sorted_by) { - if (g_profiler_place == "") return; +// Parse the event list and output the profiling report +void ParseEvents(const std::vector>& events, + EventSortingKey sorted_by = EventSortingKey::kDefault) { + if (g_state == ProfilerState::kDisabled) return; std::string sorted_domain; std::function sorted_func; @@ -307,9 +364,9 @@ void ParseEvents(std::vector>& events, std::unordered_map event_idx; for (size_t j = 0; j < events[i].size(); j++) { - if (events[i][j].kind() == "push") { + if (events[i][j].type() == EventType::kPushRange) { pushed_events.push_back(events[i][j]); - } else if (events[i][j].kind() == "pop") { + } else if (events[i][j].type() == EventType::kPopRange) { std::list::reverse_iterator rit = pushed_events.rbegin(); while (rit != pushed_events.rend() && rit->name() != events[i][j].name()) { @@ -317,10 +374,10 @@ void ParseEvents(std::vector>& events, } if (rit != pushed_events.rend()) { - double event_time = - (g_profiler_place == "CUDA" || g_profiler_place == "All") - ? rit->CudaElapsedMs(events[i][j]) - : rit->CpuElapsedMs(events[i][j]); + double event_time = (g_state == ProfilerState::kCUDA || + g_state == ProfilerState::kAll) + ? rit->CudaElapsedMs(events[i][j]) + : rit->CpuElapsedMs(events[i][j]); std::string event_name = "thread" + std::to_string(rit->thread_id()) + "::" + rit->name(); @@ -376,35 +433,22 @@ void ParseEvents(std::vector>& events, PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12); } -void PrintProfiler(std::vector>& events_table, - std::string& sorted_domain, const size_t name_width, - const size_t data_width) { - // Output header information - std::cout << "\n------------------------->" - << " Profiling Report " - << "<-------------------------\n\n"; - std::cout << "Place: " << g_profiler_place << std::endl; - std::cout << "Time unit: ms" << std::endl; - std::cout << "Sorted by " << sorted_domain - << " in descending order in the same thread\n\n"; - // Output events table - std::cout.setf(std::ios::left); - std::cout << std::setw(name_width) << "Event" << std::setw(data_width) - << "Calls" << std::setw(data_width) << "Total" - << std::setw(data_width) << "Min." << std::setw(data_width) - << "Max." << std::setw(data_width) << "Ave." << std::endl; - for (size_t i = 0; i < events_table.size(); ++i) { - for (size_t j = 0; j < events_table[i].size(); ++j) { - EventItem& event_item = events_table[i][j]; - std::cout << std::setw(name_width) << event_item.name - << std::setw(data_width) << event_item.calls - << std::setw(data_width) << event_item.total_time - << std::setw(data_width) << event_item.min_time - << std::setw(data_width) << event_item.max_time - << std::setw(data_width) << event_item.ave_time << std::endl; - } +void DisableProfiler(EventSortingKey sorted_key, + const std::string& profile_path) { + PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, + "Can't disable profiling, since it's not starting."); + // Mark the profiling stop. + Mark("_stop_profiler_", nullptr); + + std::vector> all_events = GetAllEvents(); + ParseEvents(all_events, sorted_key); + ResetProfiler(); + DeviceTracer* tracer = GetDeviceTracer(); + if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) { + tracer->Disable(); + tracer->GenProfile(profile_path); } - std::cout << std::endl; + g_state = ProfilerState::kDisabled; } } // namespace platform diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index de9a5cc20d76bf84778e0933831f218abb66c465..b07427c8f6903e0100ca9a478881444d86501bcc 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include #include -#include +#include #include #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/profiler.pb.h" @@ -23,16 +23,16 @@ limitations under the License. */ namespace paddle { namespace platform { -enum EventKind { kMark, kPushRange, kPopRange }; +enum EventType { kMark, kPushRange, kPopRange }; class Event { public: // The DeviceContext is used to get the cuda stream. // If CPU profiling mode, can pass nullptr. - Event(EventKind kind, std::string name, uint32_t thread_id, + Event(EventType type, std::string name, uint32_t thread_id, const DeviceContext* dev_ctx); - std::string kind() const; + const EventType& type() const; std::string name() const { return name_; } uint32_t thread_id() const { return thread_id_; } bool has_cuda() const { return has_cuda_; } @@ -46,7 +46,7 @@ class Event { double CudaElapsedMs(const Event& e) const; private: - EventKind kind_; + EventType type_; std::string name_; uint32_t thread_id_; int64_t cpu_ns_; @@ -57,39 +57,6 @@ class Event { #endif }; -struct EventList { - constexpr static size_t kMB = 1024 * 1024; - constexpr static size_t kEventBlockSize = 16 * kMB; - constexpr static size_t kEventSize = sizeof(Event); - constexpr static size_t kEventAlign = alignof(Event); - constexpr static size_t kNumBlock = - kEventBlockSize / - ((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign); - - template - void Record(Args&&... args) { - if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) { - event_blocks.emplace_front(); - event_blocks.front().reserve(kNumBlock); - } - event_blocks.front().emplace_back(std::forward(args)...); - } - - std::vector Reduce() { - std::vector result; - for (auto& block : event_blocks) { - result.insert(result.begin(), std::make_move_iterator(block.begin()), - std::make_move_iterator(block.end())); - } - event_blocks.clear(); - return result; - } - - void Clear() { event_blocks.clear(); } - - std::forward_list> event_blocks; -}; - enum ProfilerState { kDisabled, // disabled state kCPU, // CPU profiling state @@ -136,16 +103,6 @@ struct RecordThread { // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. std::vector> GetAllEvents(); -// The information of each event given in the profiling report -struct EventItem { - std::string name; - int calls; - double total_time; - double min_time; - double max_time; - double ave_time; -}; - // Candidate keys to sort the profiling report enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve }; @@ -158,14 +115,5 @@ void ResetProfiler(); void DisableProfiler(EventSortingKey sorted_key, const std::string& profile_path); -// Parse the event list and output the profiling report -void ParseEvents(std::vector>&, - EventSortingKey sorted_by = EventSortingKey::kDefault); - -// Print results -void PrintProfiler(std::vector>& events_table, - std::string& sorted_domain, const size_t name_width, - const size_t data_width); - } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/profiler_test.cc b/paddle/fluid/platform/profiler_test.cc index 45cc271bb888fc3a07ecc5daea6b549cb88b6d21..61f467814ba4a24c8b73f1bc614cda0ab8c4debd 100644 --- a/paddle/fluid/platform/profiler_test.cc +++ b/paddle/fluid/platform/profiler_test.cc @@ -13,22 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/profiler.h" +#include #ifdef PADDLE_WITH_CUDA -#include "cuda_runtime.h" +#include #endif #include "gtest/gtest.h" TEST(Event, CpuElapsedTime) { using paddle::platform::Event; - using paddle::platform::EventKind; + using paddle::platform::EventType; - Event start_event(EventKind::kPushRange, "test", 0, nullptr); + Event start_event(EventType::kPushRange, "test", 0, nullptr); EXPECT_TRUE(start_event.has_cuda() == false); int counter = 0; while (counter != 1000) { counter++; } - Event stop_event(EventKind::kPopRange, "test", 0, nullptr); + Event stop_event(EventType::kPopRange, "test", 0, nullptr); EXPECT_GT(start_event.CpuElapsedMs(stop_event), 0); } @@ -38,16 +39,16 @@ TEST(Event, CudaElapsedTime) { using paddle::platform::CUDADeviceContext; using paddle::platform::CUDAPlace; using paddle::platform::Event; - using paddle::platform::EventKind; + using paddle::platform::EventType; DeviceContext* dev_ctx = new CUDADeviceContext(CUDAPlace(0)); - Event start_event(EventKind::kPushRange, "test", 0, dev_ctx); + Event start_event(EventType::kPushRange, "test", 0, dev_ctx); EXPECT_TRUE(start_event.has_cuda() == true); int counter = 0; while (counter != 1000) { counter++; } - Event stop_event(EventKind::kPopRange, "test", 0, dev_ctx); + Event stop_event(EventType::kPopRange, "test", 0, dev_ctx); EXPECT_GT(start_event.CudaElapsedMs(stop_event), 0); } #endif @@ -55,7 +56,7 @@ TEST(Event, CudaElapsedTime) { TEST(RecordEvent, RecordEvent) { using paddle::platform::DeviceContext; using paddle::platform::Event; - using paddle::platform::EventKind; + using paddle::platform::EventType; using paddle::platform::RecordEvent; using paddle::platform::ProfilerState; using paddle::platform::EventSortingKey; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index bd8446df6650f5fb1c62e5370fd48216dbf31e17..392404045578489014f2283b885c388d5a4586cf 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -465,7 +465,8 @@ All parameter, weight, gradient are variables in Paddle. m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG); - m.def("init_devices", &framework::InitDevices); + m.def("init_devices", + [](bool init_p2p) { framework::InitDevices(init_p2p); }); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); #ifdef PADDLE_WITH_CUDA diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 0fea6a80794a64abc2dbf1428d534840febcd450..586ec48477f085a14d2f15b265a95d596705694f 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -41,6 +41,6 @@ int main(int argc, char** argv) { paddle::memory::Used(paddle::platform::CUDAPlace(0)); #endif - paddle::framework::InitDevices(); + paddle::framework::InitDevices(true); return RUN_ALL_TESTS(); } diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index f01d638efddd471d5667fded183b90c2d7d0a856..a5a3884750cce8cf19b92f1e5f131b50a18d3c97 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -85,6 +85,8 @@ def __bootstrap__(): import core import os + in_test = 'unittest' in sys.modules + try: num_threads = int(os.getenv('OMP_NUM_THREADS', '1')) except ValueError: @@ -109,8 +111,11 @@ def __bootstrap__(): core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) core.init_glog(sys.argv[0]) - core.init_devices() + # don't init_p2p when in unittest to save time. + core.init_devices(not in_test) +# TODO(panyx0718): Avoid doing complex initialization logic in __init__.py. +# Consider paddle.init(args) or paddle.main(args) layers.monkey_patch_variable() __bootstrap__() diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index c5b53902bca90ae2260a7cda43e6866f897233b3..57d4a50e913c0d2994c62600f4e479056ed4c306 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -22,221 +22,504 @@ from scipy.special import expit class TestExp(OpTest): def setUp(self): self.op_type = "exp" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': np.exp(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.exp(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Exp(TestExp): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSigmoid(OpTest): def setUp(self): self.op_type = "sigmoid" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': 1 / (1 + np.exp(-self.inputs['X']))} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = 1 / (1 + np.exp(-x)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out', max_relative_error=0.008) + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out', max_relative_error=0.01) + + def init_dtype(self): + pass + + +class TestFP16Sigmoid(TestSigmoid): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestLogSigmoid(OpTest): def setUp(self): self.op_type = "logsigmoid" - self.inputs = { - 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': np.log(1 / (1 + np.exp(-self.inputs['X'])))} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = np.log(1 / (1 + np.exp(-x))) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.008) + def init_dtype(self): + pass + + +class TestFP16LogSigmoid(TestLogSigmoid): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestTanh(OpTest): def setUp(self): self.op_type = "tanh" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': np.tanh(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.tanh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Tanh(TestTanh): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestTanhShrink(OpTest): def setUp(self): self.op_type = "tanh_shrink" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32") - } - self.outputs = {'Out': self.inputs['X'] - np.tanh(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [10, 17]).astype(self.dtype) + out = x - np.tanh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.008) + def init_dtype(self): + pass + + +class TestFP16TanhShrink(TestTanhShrink): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestHardShrink(OpTest): def setUp(self): self.op_type = "hard_shrink" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + threshold = 0.5 + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) + out = np.copy(x) + out[(out >= -threshold) & (out <= threshold)] = 0 - self.inputs = {'X': x} self.attrs = {'lambda': threshold} - - t = np.copy(x) - t[(t >= -threshold) & (t <= threshold)] = 0 - self.outputs = {'Out': t} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.005) + def init_dtype(self): + pass + + +class TestFP16HardShrink(TestHardShrink): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSoftShrink(OpTest): def setUp(self): self.op_type = "softshrink" + self.dtype = np.float32 + self.init_dtype() + lambda_val = 0.1 + x = np.random.uniform(0.25, 10, [4, 4]).astype(self.dtype) + out = np.copy(x) + out = (out < -lambda_val) * (out + lambda_val) + (out > lambda_val) * ( + out - lambda_val) + self.attrs = {'lambda': lambda_val} - self.inputs = { - 'X': np.random.uniform(0.25, 10, [4, 4]).astype("float32") - } - y = np.copy(self.inputs['X']) - y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * ( - y - lambda_val) - self.outputs = {'Out': y} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16SoftShrink(TestSoftShrink): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSqrt(OpTest): def setUp(self): self.op_type = "sqrt" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': np.sqrt(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.sqrt(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Sqrt(TestSqrt): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestAbs(OpTest): def setUp(self): self.op_type = "abs" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) # Because we set delta = 0.005 in caculating numeric gradient, # if x is too small, such as 0.002, x_neg will be -0.003 # x_pos will be 0.007, so the numeric gradient is unaccurate. # we should avoid this x[np.abs(x) < 0.005] = 0.02 - self.inputs = {'X': x} - self.outputs = {'Out': np.abs(self.inputs['X'])} + out = np.abs(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Abs(TestAbs): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestCeil(OpTest): def setUp(self): self.op_type = "ceil" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - self.inputs = {'X': x} - self.outputs = {'Out': np.ceil(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) + out = np.ceil(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Ceil(TestCeil): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestFloor(OpTest): def setUp(self): self.op_type = "floor" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - self.inputs = {'X': x} - self.outputs = {'Out': np.floor(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) + out = np.floor(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Floor(TestFloor): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestCos(OpTest): def setUp(self): self.op_type = "cos" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - self.inputs = {'X': x} - self.outputs = {'Out': np.cos(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) + out = np.cos(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Cos(TestCos): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSin(OpTest): def setUp(self): self.op_type = "sin" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - self.inputs = {'X': x} - self.outputs = {'Out': np.sin(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) + out = np.sin(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Sin(TestSin): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestRound(OpTest): def setUp(self): self.op_type = "round" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - self.inputs = {'X': x} - self.outputs = {'Out': np.round(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) + out = np.round(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Round(TestRound): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestRelu(OpTest): def setUp(self): @@ -278,222 +561,463 @@ class TestFP16Relu(TestRelu): class TestBRelu(OpTest): def setUp(self): self.op_type = "brelu" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) t_min = 1.0 t_max = 4.0 # The same with TestAbs x[np.abs(x - t_min) < 0.005] = t_min + 0.02 x[np.abs(x - t_max) < 0.005] = t_max + 0.02 - - self.inputs = {'X': x} - self.attrs = {'t_min': t_min, 't_max': t_max} t = np.copy(x) t[t < t_min] = t_min t[t > t_max] = t_max + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'t_min': t_min, 't_max': t_max} self.outputs = {'Out': t} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.02) + def init_dtype(self): + pass + + +class TestFP16BRelu(TestBRelu): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestRelu6(OpTest): def setUp(self): self.op_type = "relu6" - x = np.random.uniform(-1, 1, [4, 10]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 10]).astype(self.dtype) threshold = 6.0 # The same with TestAbs x[np.abs(x) < 0.005] = 0.02 x[np.abs(x - threshold) < 0.005] = threshold + 0.02 + out = np.minimum(np.maximum(x, 0), threshold) - self.inputs = {'X': x} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.attrs = {'threshold': threshold} - self.outputs = { - 'Out': np.minimum(np.maximum(self.inputs['X'], 0), threshold) - } + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.02) + def init_dtype(self): + pass + + +class TestFP16Relu6(TestRelu6): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSoftRelu(OpTest): def setUp(self): self.op_type = "soft_relu" - x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-3, 3, [4, 4]).astype(self.dtype) threshold = 2.0 # The same reason with TestAbs x[np.abs(x - threshold) < 0.005] = threshold + 0.02 x[np.abs(x + threshold) < 0.005] = -threshold + 0.02 - self.inputs = {'X': x} - self.attrs = {'threshold': threshold} t = np.copy(x) t[t < -threshold] = -threshold t[t > threshold] = threshold - self.outputs = {'Out': np.log((np.exp(t) + 1))} + out = np.log((np.exp(t) + 1)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'threshold': threshold} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.02) + def init_dtype(self): + pass + + +class TestFP16SoftRelu(TestSoftRelu): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestELU(OpTest): def setUp(self): self.op_type = "elu" - x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-3, 3, [4, 4]).astype(self.dtype) alpha = 1. + out = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1) # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here self.inputs = {'X': x} self.attrs = {'alpha': alpha} - self.outputs = { - 'Out': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) - } + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.02) + def init_dtype(self): + pass + + +class TestFP16ELU(TestELU): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestReciprocal(OpTest): def setUp(self): self.op_type = "reciprocal" - self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} - self.outputs = {'Out': np.reciprocal(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np.reciprocal(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.01) + def init_dtype(self): + pass + + +class TestFP16Reciprocal(TestReciprocal): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestLog(OpTest): def setUp(self): self.op_type = "log" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': np.log(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.log(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Log(TestLog): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSquare(OpTest): def setUp(self): self.op_type = "square" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Out': np.square(self.inputs['X'])} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.square(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Square(TestSquare): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestPow(OpTest): def setUp(self): self.op_type = "pow" - self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np.power(x, 3) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.attrs = {'factor': 3.0} - self.outputs = {'Out': np.power(self.inputs['X'], 3)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.02) + def init_dtype(self): + pass + + +class TestFP16Pow(TestPow): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=5e-2) + class TestSTanh(OpTest): def setUp(self): self.op_type = "stanh" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) scale_a = 2.0 / 3.0 scale_b = 1.7159 + out = scale_b * np.tanh(x * scale_a) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.attrs = {'scale_a': scale_a, 'scale_b': scale_b} - self.outputs = {'Out': scale_b * np.tanh(self.inputs['X'] * scale_a)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16STanh(TestSTanh): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSoftplus(OpTest): def setUp(self): self.op_type = "softplus" - self.inputs = { - 'X': np.random.uniform(-1, 1, [11, 17]).astype("float64") - } - self.outputs = {'Out': np.log(1 + np.exp(self.inputs['X']))} + self.dtype = np.float64 + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = np.log(1 + np.exp(x)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Softplus(TestSoftplus): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSoftsign(OpTest): def setUp(self): self.op_type = "softsign" - self.inputs = { - 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X'])) - } + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = np.divide(x, 1 + np.abs(x)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Softsign(TestSoftsign): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestThresholdedRelu(OpTest): def setUp(self): self.op_type = "thresholded_relu" + self.dtype = np.float32 + self.init_dtype() + threshold = 0.25 self.relative_error = 0.005 - X = np.random.uniform(-1, 1, [11, 17]).astype("float32") + X = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) # Same reason as TestAbs X[np.abs(X - threshold) < self.relative_error] = threshold + 0.2 + out = (X > threshold) * X - self.inputs = {'X': X} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)} self.attrs = {'threshold': threshold} - self.outputs = {'Out': (X > threshold) * X} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=self.relative_error) + def init_dtype(self): + pass + + +class TestFP16ThresholdedRelu(TestThresholdedRelu): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestHardSigmoid(OpTest): def setUp(self): self.op_type = "hard_sigmoid" + self.dtype = np.float32 + self.init_dtype() + self.relative_error = 0.002 X = np.random.uniform(-5, 5, [2, 2]).astype("float32") @@ -502,7 +1026,6 @@ class TestHardSigmoid(OpTest): lower_threshold = -offset / slope upper_threshold = (1 - offset) / slope - self.inputs = {'X': X} # Same reason as TestAbs X[np.abs(X - lower_threshold) < self.relative_error] = \ lower_threshold + 0.2 @@ -510,29 +1033,70 @@ class TestHardSigmoid(OpTest): upper_threshold - 0.2 temp = X * slope + offset - self.outputs = {'Out': np.maximum(0.0, np.minimum(1.0, temp))} + out = np.maximum(0.0, np.minimum(1.0, temp)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.002) + def init_dtype(self): + pass + + +class TestFP16HardSigmoid(TestHardSigmoid): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestSwish(OpTest): def setUp(self): self.op_type = "swish" - X = np.random.uniform(0.1, 1, [11, 17]).astype("float32") - self.inputs = {'X': X} - self.attrs = {'beta': 2.3} - self.outputs = {'Out': X * expit(self.attrs['beta'] * X)} + self.dtype = np.float32 + self.init_dtype() + + X = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + beta = 2.3 + out = X * expit(beta * X) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)} + self.attrs = {'beta': beta} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.008) + def init_dtype(self): + pass + + +class TestFP16Swish(TestSwish): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + #--------------------test MKLDNN-------------------- class TestMKLDNNReluDim2(TestRelu):