diff --git a/.gitignore b/.gitignore index 9e3a0b499f9f42856429f3a42bef313ea3df3699..b92bb9cc129659fa502b4a9b55548992412e5429 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/ python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/ python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/ *.DS_Store +*.vs build/ build_doc/ *.user @@ -15,6 +16,7 @@ build_doc/ .cproject .pydevproject .settings/ +CMakeSettings.json Makefile .test_env/ third_party/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 920c20d6f813c14df8e02593b1c4a5a13cc11ef0..68447727118a91a2a8c0d06404353c7ccb734c6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,11 +204,12 @@ include(external/snappy) # download snappy include(external/snappystream) include(external/threadpool) -set(WITH_ANAKIN OFF CACHE STRING "Disable Anakin first, will add it later." FORCE) if(WITH_GPU) include(cuda) include(tensorrt) include(external/anakin) +else() + set(WITH_ANAKIN OFF CACHE STRING "Anakin only supports GPU for now." FORCE) endif() include(cudnn) # set cudnn libraries, must before configure diff --git a/cmake/configure.cmake b/cmake/configure.cmake index c35096e09b5685ee30f20648e7dd461f71e0b1c4..d14162e0a662afe63152bfc2132e5dfd54f5a86c 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -56,6 +56,10 @@ if(NOT CMAKE_CROSSCOMPILING) set(SIMD_FLAG ${SSE3_FLAG}) endif() endif() +if(UNIX AND NOT APPLE) + # exclude Apple from the UNIX-like OS family + set(LINUX TRUE) +endif(UNIX AND NOT APPLE) if(NOT WITH_GOLANG) add_definitions(-DPADDLE_WITHOUT_GOLANG) @@ -104,6 +108,10 @@ if(WITH_GPU) if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) message(FATAL_ERROR "Anakin needs CUDNN >= 7.0 to compile") endif() + set(ENV{CUDNN_INCLUDE_DIR} ${CUDNN_INCLUDE_DIR}) + set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY}) + message(STATUS "cudnn include header is ${CUDNN_INCLUDE_DIR}/cudnn.h") + message(STATUS "cudnn library is ${CUDNN_LIBRARY}") endif() elseif(WITH_AMD_GPU) add_definitions(-DPADDLE_WITH_HIP) diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 403873a51017b9bb7c888bed77d89ca9b15b68d2..455ef91ac5e9d98def959256d15ae836bcf0befc 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -35,9 +35,8 @@ set(ANAKIN_COMPILE_EXTRA_FLAGS ExternalProject_Add( extern_anakin ${EXTERNAL_PROJECT_LOG_ARGS} - # TODO(luotao): use PaddlePaddle/Anakin later - GIT_REPOSITORY "https://github.com/luotao1/Anakin" - GIT_TAG "3957ae9263eaa0b1986758dac60a88852afb09be" + GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin" + GIT_TAG "04256ba78fa3da0beb74e8036c8efd68c12824d6" PREFIX ${ANAKIN_SOURCE_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DUSE_GPU_PLACE=YES diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index c020ff45ad3f3a72bf8a88622df333c1765a3d21..e963902a50200b785284e8f233fcca1abf459140 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -155,10 +155,11 @@ paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) -paddle.fluid.layers.log
ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 83867e0a2c198f265cb36be3a71546796dfbec67..a72e27d651d0591815a9d93354d2aea8aa216de6 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -60,7 +60,7 @@ cc_library(paddle_inference_tensorrt_subgraph_engine inference_api_test(test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_engine_tester.cc ARGS test_word2vec) endif() -if (WITH_ANAKIN) # only needed in CI +if (WITH_ANAKIN AND WITH_GPU) # only needed in CI # compile the libinference_anakin_api.a and anakin.so. 
nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) #nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 27487b396ccf63d962defa6b270063ccb409164e..d3a7ceed466a9b5e4d773f1531d198adff97eac2 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -26,6 +26,8 @@ namespace plat = paddle::platform; act_type##_grad, ops::ActivationGradKernel>, \ ops::ActivationGradKernel>); + ops::grad_functor>, \ + ops::ActivationGradKernel>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 912415192659dc004f54a76e9cd1a20581d512a6..48f3b5a5bc06fbc211895a1a6d1521cfd97e0086 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - const Out out_conj = Eigen::numext::conj(out); - dx.device(d) = static_cast(0.5) * dout / out_conj; + dx.device(d) = static_cast(0.5) * dout / out; } }; @@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor { typename dX> void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * - x.pow(static_cast(factor - static_cast(1))); + x.pow(static_cast(factor) - static_cast(1)); } }; @@ -863,10 +862,11 @@ struct SwishGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + T b = static_cast(beta); auto temp1 = static_cast(1) / - (static_cast(1) + (static_cast(-beta) * x).exp()); - auto temp2 = temp1 * (static_cast(1) - (beta * out)); - dx.device(d) = dout * ((beta * out) + temp2); + (static_cast(1) + (static_cast(-b) * x).exp()); + auto temp2 = temp1 * (static_cast(1) - (b * out)); + dx.device(d) = dout * ((b * out) + temp2); } }; diff --git a/paddle/fluid/operators/assign_value_op.cu.cc b/paddle/fluid/operators/assign_value_op.cu.cc index 08bfde5dc92de9c675e5b9b85f8e65a3bab8631c..0ff174b3884df63d54d6486b017cc1a15ab23103 100644 --- a/paddle/fluid/operators/assign_value_op.cu.cc +++ b/paddle/fluid/operators/assign_value_op.cu.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/assign_value_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel, - ops::AssignValueKernel); + ops::AssignValueKernel, + ops::AssignValueKernel); diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 22cbf680c0670552fb014043c69fcadc56863529..59bfe8f61d8ebb530ba617006650c0ef9215e2a6 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -39,6 +39,27 @@ using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; +template +// bool EnableFp16(const T& dummy, const DeviceContext& dev_ctx, +bool EnableFp16(const DeviceContext& dev_ctx, + cudnnConvolutionDescriptor_t cudnn_conv_desc) { +#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) + // Tensor core is supported since the volta GPU and + // is only enabled when input and filter data are float16 + if (dev_ctx.GetComputeCapability() >= 70 && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); + return true; + } else { + PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + } +#endif + return false; +} + template class CUDNNConvOpKernel : public framework::OpKernel { public: @@ -128,27 +149,14 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo; auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - -#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) - // Tensor core is supported since the volta GPU and - // is only enabled when input and filter data are float16 - if (dev_ctx.GetComputeCapability() >= 70 && - std::type_index(typeid(T)) == - std::type_index(typeid(platform::float16))) { - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); - // Currently tensor core is only enabled using this algo + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; } else { - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); } -#endif // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( @@ -288,6 +296,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( @@ -307,6 +318,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } + if 
(EnableFp16(dev_ctx, cudnn_conv_desc)) { + filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( @@ -362,7 +376,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel); REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, @@ -370,4 +385,5 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel) diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu index 30dbd5bd3d39dd2992c3dd91364003bb7715a2eb..65fd3a5dbc9ffed4c5d1114346fcc0660c183dae 100644 --- a/paddle/fluid/operators/cross_entropy_op.cu +++ b/paddle/fluid/operators/cross_entropy_op.cu @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cross_entropy_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; using CUDACtx = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, - ops::CrossEntropyOpKernel); -REGISTER_OP_CUDA_KERNEL(cross_entropy_grad, - ops::CrossEntropyGradientOpKernel, - ops::CrossEntropyGradientOpKernel); + ops::CrossEntropyOpKernel, + ops::CrossEntropyOpKernel); +REGISTER_OP_CUDA_KERNEL( + cross_entropy_grad, ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index dfff518f170b56d180b6883c363effb8dbd677b6..f9f5c66d34fa1d73db00173e493f9953b8579518 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel); diff --git a/paddle/fluid/operators/elementwise_div_op.cu b/paddle/fluid/operators/elementwise_div_op.cu index 588d1f7420241ba1697e5141e4e4a2870f2dc87c..4cc7ba0f43c6031bf4a27222a17eca84bad5a668 100644 --- a/paddle/fluid/operators/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise_div_op.cu @@ -14,19 +14,24 @@ limitations under the License. 
*/ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_div_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_div, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel); + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel); REGISTER_OP_CUDA_KERNEL( elementwise_div_grad, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel); + plat::float16>); diff --git a/paddle/fluid/operators/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise_mul_op.cu index 2fb1b4bee689c059625e3dbd59f80c541ace83a0..350d43168dea7e88127b0d28d663e680458e1dba 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise_mul_op.cu @@ -14,19 +14,25 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_mul_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_CUDA_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel); diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index bc3e95e904f8b6c2cdd2ae6685bf67580178e6b6..7223a972d23119c8ef93fb49bfe42922cc14571d 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -350,7 +350,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; - T val = 0; + T val(0); do { int x_offset = i * w + j; @@ -418,7 +418,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( int tid = threadIdx.x; int j = blockIdx.x; - T val = 0; + T val(0); int ttid = tid; while (true) { diff --git a/paddle/fluid/operators/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise_sub_op.cu index 8709f686f9af1bf4dacbc2dfc3e2d5dcc1c59b9a..ff3f6f8a2cb542c2fb6b43d539f6413b39250992 100644 --- a/paddle/fluid/operators/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise_sub_op.cu @@ -14,19 +14,25 @@ limitations under the License. 
*/ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_sub_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_sub, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, - ops::ElementwiseSubKernel); + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel); REGISTER_OP_CUDA_KERNEL( elementwise_sub_grad, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 130f18dde4f979a6a9925ede9cbf745fcec14d48..862249269eaecdac262a691c884ea59f89f54061 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -12,48 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -class FillConstantInferShape : public framework::InferShapeBase { +class FillConstantOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FillConstantOp should not be null."); - auto &shape = ctx->Attrs().Get>("shape"); + auto& shape = ctx->Attrs().Get>("shape"); ctx->SetOutputDim("Out", framework::make_ddim(shape)); } -}; - -class FillConstantOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - auto data_type = - static_cast(Attr("dtype")); - auto value = Attr("value"); - auto force_cpu = Attr("force_cpu"); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - out.Resize(framework::make_ddim(Attr>("shape"))); - if (force_cpu) { - auto cpu = platform::CPUPlace(); - out.mutable_data(cpu, framework::ToTypeIndex(data_type)); - } else { - out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); - } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - math::set_constant(dev_ctx, &out, value); + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + static_cast(ctx.Attr("dtype")), + ctx.device_context()); } }; @@ -87,6 +67,11 @@ Fill up a variable with specified constant value. 
} // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, - ops::FillConstantInferShape, ops::FillConstantOpMaker, +REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker, paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..51ccaefa4338dfa18d26441a59d5fed2b9fa0c39 --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b2a2a7b2faedf9b94e01ed908ff39749973be1df --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + auto value = ctx.Attr("value"); + auto force_cpu = ctx.Attr("force_cpu"); + auto* out = ctx.Output("Out"); + out->Resize(framework::make_ddim(ctx.Attr>("shape"))); + if (force_cpu) { + auto cpu = platform::CPUPlace(); + out->mutable_data(cpu, framework::ToTypeIndex(data_type)); + } else { + out->mutable_data(ctx.GetPlace(), framework::ToTypeIndex(data_type)); + } + + math::set_constant(ctx.template device_context(), out, + value); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 925dc19061e2196a40411f415eb6e5ad59ab52ff..352a17c927bc70bdd6e4307951f0e0ac3d10ac2d 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -69,7 +70,6 @@ class FillOp : public framework::OperatorBase { framework::VisitDataType( dtype, FillOpVisitor(&tensor, Attr>("value"))); - if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 7784856417e579fd43f79fa331d46df8af6c36b8..b4907237954ba478197d5ca8bdcbc3e1915e9dcf 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -60,6 +61,7 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(gaussian_random, paddle::operators::GPUGaussianRandomKernel, paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 0de58d5fddd84d33f708c4c73e5a19dc2fe8a86b..58b85abf822741905a4e9547823b6cdbe645d39a 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -15,11 +15,25 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { +template +HOSTDEVICE T log(const T& val) { + return std::log(val); +} + +template <> +HOSTDEVICE platform::float16 log(const platform::float16& val) { + // strange bug: hlog does not exist.
+ return static_cast(0); + // half tmp = static_cast(val); + // return static_cast(hlog(tmp)); +} + namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, @@ -35,12 +49,12 @@ template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, const int class_num) { int tid = threadIdx.x; - T val = 0; + T val(0); int idx = blockIdx.x * class_num + tid; int end = blockIdx.x * class_num + class_num; for (; idx < end; idx += blockDim.x) { - val += math::TolerableValue()(std::log(X[idx])) * label[idx]; + val += math::TolerableValue()(log(X[idx])) * label[idx]; } val = paddle::platform::reduceSum(val, tid, blockDim.x); @@ -84,6 +98,8 @@ class CrossEntropyFunctor { template class CrossEntropyFunctor; template class CrossEntropyFunctor; +template class CrossEntropyFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index adc5b3fe47cd3bf524eb56747b6bd51e345a2eb6..2e4e4781c2eee1d9a0fc6760093a424ab3d5eb9d 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { @@ -33,6 +35,21 @@ struct TolerableValue { } }; +// float16 value clip behave different. +using paddle::platform::float16; +using paddle::platform::isfinite; +template <> +struct TolerableValue { + HOSTDEVICE float16 operator()(const float16& x) const { + if (isfinite(x)) + return x; + else if (x > static_cast(0)) + return std::numeric_limits::max(); + else + return std::numeric_limits::min(); + } +}; + template class CrossEntropyFunctor { public: diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a92762c7fea865fad2c7784736cce93a8af21892..00dbfc11a239da70ec81e3498d2f4d5e5bf1c63f 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -76,6 +77,7 @@ struct SelectedRowsAdd { template struct SelectedRowsAdd; template struct SelectedRowsAdd; +template struct SelectedRowsAdd; namespace { template @@ -120,7 +122,7 @@ struct SelectedRowsAddTensor { auto* out_data = output->data(); SetConstant functor; - functor(context, output, 0.0); + functor(context, output, static_cast(0)); const int block_size = 256; dim3 threads(block_size, 1); @@ -138,6 +140,8 @@ struct SelectedRowsAddTensor { template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; template struct SelectedRowsAddTo { @@ -177,6 +181,8 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; namespace { template @@ -229,6 +235,8 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { @@ -276,7 +284,7 @@ struct MergeAdd { context.GetPlace()); math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); + constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); auto* input_data = input.value().data(); @@ -300,6 +308,7 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 3effe776258cb541dbba32f63eda457d917011f4..785c4baecbf056d08930f4bb704aec067a2db4a2 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -94,12 +94,15 @@ void SoftmaxGradCUDNNFunctor::operator()( template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; +template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 91e0ab28efc21d4376524c8ecf66b429d51d8847..07aa23754f9786c56c0be14c2a71d5290d2cccf7 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanKernel, - ops::MeanKernel); + ops::MeanKernel, + ops::MeanKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel); + ops::MeanGradKernel, + ops::MeanGradKernel); diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h index 362e9f9ae8b2f0f77198e3f3939211ae1117b27b..a41d50ae0b99797800078184f7ffeb366367f493 100644 --- a/paddle/fluid/operators/mean_op.h +++ b/paddle/fluid/operators/mean_op.h @@ -55,7 +55,7 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = static_cast(IG->numel()); - Eigen::DSizes bcast(ig_size); + Eigen::DSizes bcast(static_cast(ig_size)); EigenVector::Flatten(*IG).device( *context.template device_context().eigen_device()) = diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 81f3e42bf412fa4d2cb48405f2f8ee49b6aa0b67..6c5a83c6a50c463502171f09bbf18e17e43917b5 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -20,6 +20,7 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, ops::MulKernel, ops::MulKernel); -REGISTER_OP_CUDA_KERNEL(mul_grad, - ops::MulGradKernel, - ops::MulGradKernel); +REGISTER_OP_CUDA_KERNEL( + mul_grad, ops::MulGradKernel, + ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 31f083565fddee66aea1485ed71f41b6199f4502..9fdbee818a217842e47c8ab11b84c6d5513ad219 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -174,7 +174,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, @@ -182,4 +183,5 @@ REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index db040509bc08c3f6ad031c5b97c93574e31337e0..23d9ea88f6701f9f9e5e02948e996878a849ddd6 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { + std::string mode = ctx->Attrs().Get("mode"); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); - PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, - "Size of weight Alpha must be one."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + + auto x_dim = ctx->GetInputDim("X"); + if (mode == "all") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, + "For mode 'all', size of weight Alpha must be one."); + } else if (mode == "channel") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1], + "For channel-wise mode, size of weight Alpha must be " + "equal to the number of channels, should be %d", + x_dim[1]); + } else if (mode == "element") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim), + "For element-wise mode, size of weight Alpha must be " + "equal to the number of input elements, should be %d", + product(x_dim)); + } else { + PADDLE_THROW("Unknown mode %s", mode); + } + ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } }; class PReluOpMaker : public framework::OpProtoAndCheckerMaker { @@ -44,9 +67,7 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of prelu operator."); AddComment(R"DOC( PRelu Operator. - The equation is: - $$ f(x) = \begin{cases} @@ -54,11 +75,15 @@ f(x) = x, \qquad \text{if} \ x >= 0 \end{cases} $$ - The input `X` can carry the LoD (Level of Details) information, or not. And the output shares the LoD information with input `X`.
- +There are three modes: + all: all elements share the same weight + channel: elements in the same channel share the same weight + element: each element has its own weight )DOC"); + AddAttr("mode", "The mode for weight sharing.") .SetDefault("all"); } }; @@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->SetOutputDim(framework::GradVarName("Alpha"), - ctx->GetInputDim("Alpha")); + auto x_grad_name = framework::GradVarName("X"); + auto alpha_grad_name = framework::GradVarName("Alpha"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + } + if (ctx->HasOutput(alpha_grad_name)) { + ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu deleted file mode 100644 index 37d934a29046be04a1721b7330c813f663f61aed..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/prelu_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/prelu_op.h" - -REGISTER_OP_CUDA_KERNEL( - prelu, - paddle::operators::PReluKernel); -REGISTER_OP_CUDA_KERNEL(prelu_grad, - paddle::operators::PReluGradKernel< - paddle::platform::CUDADeviceContext, float>); diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index a6197d354833a2f4173003ad2a970c487ad9a65b..f9076cbc678534fd5490fa0d7adeac0e50909a39 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,32 +10,16 @@ See the License for the specific language governing permissions and limitations under the License.
*/ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" - namespace paddle { namespace operators { using Tensor = framework::Tensor; using platform::Transform; -template -class PReluFunctor { - public: - explicit PReluFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& x) const { - if (x > 0) - return x; - else - return x * (*alpha_); - } - - private: - const T* alpha_; -}; - template class PReluKernel : public framework::OpKernel { public: @@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel { const T* x_ptr = x->data(); T* o_ptr = out->mutable_data(context.GetPlace()); - auto* alpha_ptr = alpha->data(); + const T* alpha_ptr = alpha->data(); + std::string mode = context.Attr("mode"); int numel = x->numel(); - - Transform trans; - trans(context.template device_context(), x_ptr, - x_ptr + numel, o_ptr, PReluFunctor(alpha_ptr)); - } -}; - -template -class PReluGradFunctor { - public: - explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& out, const T& dout) const { - if (out > 0) - return dout; - else - return dout * (*alpha_); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i]; + } + } } - - private: - const T* alpha_; }; template class PReluGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); auto* dx = context.Output(framework::GradVarName("X")); auto* dout = context.Input(framework::GradVarName("Out")); - + auto* dalpha = context.Output(framework::GradVarName("Alpha")); auto* out = context.Input("Out"); auto* alpha = context.Input("Alpha"); - auto* alpha_ptr = alpha->data(); - - T* dx_ptr = dx->mutable_data(context.GetPlace()); + const T* alpha_ptr = alpha->data(); + const T* x_ptr = x->data(); const T* dout_ptr = dout->data(); const T* out_ptr = out->data(); - int numel = dx->numel(); - - Transform trans; - trans(context.template device_context(), out_ptr, - out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor(alpha_ptr)); - - // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready + std::string mode = context.Attr("mode"); + int numel = x->numel(); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (dx) { + T* dx_ptr = dx->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dx_ptr[i] = + out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? 
dout_ptr[i] : alpha_ptr[0] * dout_ptr[i]; + } + } + } + + index = 0; + if (dalpha) { + T* dalpha_ptr = dalpha->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } + } + + // TODO(Guanzhong): add GPU kernels } }; diff --git a/paddle/fluid/operators/scale_op.cu b/paddle/fluid/operators/scale_op.cu index 04c802da12958a53626f533833c2709110531136..d266867046334f95eaaf4b7a9acb3fec20f1e439 100644 --- a/paddle/fluid/operators/scale_op.cu +++ b/paddle/fluid/operators/scale_op.cu @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/scale_op.h" +#include "paddle/fluid/platform/float16.h" +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( scale, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel); + int64_t>, + paddle::operators::ScaleKernel); diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index d29947b55e751a3e7993f765198364f4debe2472..181bb1af5cce7fad228e61e1a76ed66a9bd61b3e 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -35,7 +35,7 @@ class ScatterOpKernel : public framework::OpKernel { auto *Out = ctx.Output("Out"); // In place output: Out = X, Out[Ids] += Updates - Out->ShareDataWith(*X); + framework::TensorCopySync(*X, ctx.GetPlace(), Out); // Apply ScatterUpdate: Out[index] += Updates[:] ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); } @@ -53,7 +53,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *dOut = ctx.Input(framework::GradVarName("Out")); // In place gradient: dX = dO - dX->ShareDataWith(*dOut); + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Ids] CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e999621b10799b5163f326bc4b66a437e6..c2d45c3d2ef82683352afe0e72f0330f7cd753f6 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -78,4 +78,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc index 5fb4f011d9b47cebc4a23bcce47eada825263343..19359b7eef5126d84f0707d39095a74ae4561186 100644 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ b/paddle/fluid/operators/softmax_op.cu.cc @@ -23,4 +23,5 @@ REGISTER_OP_CUDA_KERNEL( ops::SoftmaxKernel); REGISTER_OP_CUDA_KERNEL( softmax_grad, ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); + ops::SoftmaxGradKernel, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index 
89bcd1bbc86dc29cb7b98cbef3057a8f98c74555..db4c2d6c115f04b436db00854ca4b02fea09866b 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -11,10 +11,13 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( sum, ops::SumKernel, ops::SumKernel, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 49a4afb3a8a19c97e844e66477c6288772ece807..dda6772796c821ffb813e73da0c34370e5339001 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -46,7 +46,7 @@ class SumKernel : public framework::OpKernel { if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, - 0.0); + static_cast(0)); } math::SelectedRowsAddToTensor functor; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2d7ea66ad434c42b54522432095ce29..5fc0784f665f9f4a4422ca9b70f7dc6001833a8f 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -11,16 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using paddle::platform::float16; template struct Pair { @@ -32,6 +35,11 @@ struct Pair { id = id; } + __device__ __forceinline__ void clear() { + v = -INFINITY; + id = -1; + } + __device__ __forceinline__ void operator=(const Pair& in) { v = in.v; id = in.id; @@ -53,6 +61,12 @@ struct Pair { int64_t id; }; +template <> +__device__ __forceinline__ void Pair::clear() { + v = platform::raw_uint16_to_float16(0x400); + id = -1; +} + template __device__ __forceinline__ void AddTo(Pair topk[], const Pair& p, int beam_size) { @@ -150,7 +164,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - (*beam)) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } } if (!(*is_empty)) { @@ -160,7 +174,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, } *max = topk[MaxLength - 1]; - if ((*max).v == -1) *is_empty = true; + if ((*max).v == static_cast(-1)) *is_empty = true; *beam = 0; } } @@ -181,7 +195,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - *beam) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].set(std::numeric_limits::min(), -1); } } if (!(*is_empty)) { @@ -273,7 +287,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, bool firststep = true; for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } while (k) { ThreadGetTopK(topk, &beam, k, @@ -325,5 +339,7 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel, 
- paddle::operators::TopkOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + top_k, paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index e1c7323a30233f4ec4f60e46aa6088ee6d8601b7..2b8039a0c1bea07402435958608ea035ba862c90 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -11,10 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { @@ -36,6 +40,11 @@ struct UniformGenerator { } }; +template +struct CastFunctor { + HOSTDEVICE V operator()(const T& a) { return static_cast(a); } +}; + // It seems that Eigen::Tensor::random in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. @@ -66,18 +75,50 @@ class GPUUniformRandomKernel : public framework::OpKernel { T max = static_cast(context.Attr("max")); thrust::counting_iterator index_sequence_begin(0); int64_t size = tensor->numel(); - thrust::transform(index_sequence_begin, index_sequence_begin + size, - thrust::device_ptr(data), - UniformGenerator(min, max, seed)); + if (out_var->IsType() && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + framework::Tensor master_copy_tensor; + master_copy_tensor.Resize(tensor->dims()); + float* master_copy_tensor_data = + master_copy_tensor.mutable_data(context.GetPlace()); + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(master_copy_tensor_data), + UniformGenerator(static_cast(min), + static_cast(max), seed)); + platform::Transform trans; + auto* in_begin = master_copy_tensor.data(); + auto* in_end = in_begin + master_copy_tensor.numel(); + auto* out_begin = tensor->mutable_data(context.GetPlace()); + trans(context.template device_context(), + in_begin, in_end, out_begin, CastFunctor()); + } else { + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } + if (VLOG_IS_ON(5)) { + framework::Tensor cpu_tensor; + framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor); + auto& dev_ctx = + *platform::DeviceContextPool::Instance().Get(context.GetPlace()); + dev_ctx.Wait(); + auto x = framework::EigenVector::Flatten(cpu_tensor); + VLOG(5) << "The Uniform output " << x; + } } }; } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); -REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + uniform_random, paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL( + uniform_random_batch_size_like, + 
paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 25900811509aee8b37fdaf09cf902ea2ae3eee57..9cdcb87df5dd1669066c204c86c269973df506f1 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -97,10 +97,11 @@ if(APPLE) if(NOT INSTALL_NAME_TOOL_EXECUTABLE) message(FATAL_ERROR "install_name_tool not found, please check.\n") endif() -else(APPLE) +endif() +if(LINUX) find_program(PATCHELF_EXECUTABLE patchelf) if(NOT PATCHELF_EXECUTABLE) message(FATAL_ERROR "patchelf not found, please install it.\n" "For Ubuntu, the command is: apt-get install -y patchelf.") endif() -endif(APPLE) +endif(LINUX) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c75e7eeb4384f28ca0dd95e3b79b7de5a3031351..be852b67119182cc817495b5e993c872cb9a88bf 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -112,6 +112,7 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'prelu', 'flatten', ] @@ -5089,7 +5090,7 @@ def random_crop(x, shape, seed=None): return out -def log(x): +def log(x, name=None): """ Calculates the natural log of the given input tensor, element-wise. @@ -5099,6 +5100,8 @@ def log(x): Args: x (Variable): Input tensor. + name (str|None, default None): A name for this layer. If set None, + the layer will be named automatically. Returns: Variable: The natural log of the input tensor computed element-wise. @@ -5116,7 +5119,7 @@ def log(x): return out -def relu(x): +def relu(x, name=None): """ Relu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, y = max(0, x), is applied to @@ -5128,6 +5131,8 @@ def relu(x): Args: x (Variable): The input tensor. + name (str|None, default None): A name for this layer. If set None, + the layer will be named automatically. Returns: Variable: The output tensor with the same shape as input. @@ -5364,6 +5369,59 @@ def rank_loss(label, left, right, name=None): return out +def prelu(x, mode, param_attr=None, name=None): + """ + Equation: + + y = \max(0, x) + \alpha * \min(0, x) + + Args: + x (Variable): The input tensor. + mode (string): The mode for weight sharing: + all: all elements share the same weight + channel: elements in the same channel share the same weight + element: each element has its own weight + param_attr (ParamAttr|None): The parameter attribute for the learnable + weight (alpha). + name (str|None): A name for this layer (optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The output tensor with the same shape as input. + + Examples: + + ..
code-block:: python + + x = fluid.layers.data(name="x", shape=[10,10], dtype="float32") + mode = 'channel' + output = fluid.layers.prelu(x,mode) + """ + helper = LayerHelper('prelu', **locals()) + if mode not in ['all', 'channel', 'element']: + raise ValueError('mode should be one of all, channel, element.') + alpha_shape = [1] + if mode == 'channel': + alpha_shape = [1, x.shape[1], 1, 1] + elif mode == 'element': + alpha_shape = x.shape + dtype = helper.input_dtype(input_param_name='x') + alpha = helper.create_parameter( + attr=param_attr, + shape=alpha_shape, + dtype='float32', + is_bias=False, + default_initializer=Constant(1.0)) + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="prelu", + inputs={"X": x, + 'Alpha': alpha}, + attrs={"mode": mode}, + outputs={"Out": out}) + return out + + def flatten(x, axis=1, name=None): """ **Flatten layer** diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 38a138a8faee7746d9e7630d39206066634956f8..07fd0575d333dacf309620a883e4052c6126739f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -21,6 +21,7 @@ import paddle.fluid.nets as nets from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.param_attr import ParamAttr import decorators +from paddle.fluid.initializer import Constant class TestBook(unittest.TestCase): @@ -485,6 +486,20 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_prelu(self): + program = Program() + with program_guard(program): + input = layers.data( + name="input", shape=[5, 200, 100, 100], dtype="float32") + mode = 'channel' + out = layers.prelu( + input, + mode, + param_attr=ParamAttr(initializer=Constant(1.0)), + name='prelu') + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index ae19a553bb826002c562c15ee07759391d51b4d8..cb7de3fc93c0379ea50c88044876d6a8ee617a69 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -20,30 +20,58 @@ from op_test import OpTest class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" - x_np = np.random.normal(size=(10, 10)).astype("float32") - - for pos, val in np.ndenumerate(x_np): - # Since zero point in prelu is not differentiable, avoid randomize - # zero. - while abs(val) < 1e-3: - x_np[pos] = np.random.normal() - val = x_np[pos] - - x_np_sign = np.sign(x_np) - x_np = x_np_sign * np.maximum(x_np, .005) - alpha_np = np.array([.1], dtype="float32") - self.inputs = {'X': x_np, 'Alpha': alpha_np} + self.initTestCase() + x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") + + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + x_np[np.abs(x_np) < 0.005] = 0.02 + + if self.attrs == {'mode': "all"}: + alpha_np = np.random.rand(1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + elif self.attrs == {'mode': "channel"}: + alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + else: + alpha_np = np.random.rand(*x_np.shape).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + out_np = np.maximum(self.inputs['X'], 0.) out_np = out_np + np.minimum(self.inputs['X'], 0.) 
* self.inputs['Alpha'] assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} + def initTestCase(self): + self.attrs = {'mode': "channel"} + def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X', 'Alpha'], 'Out') + + def test_check_grad_ignore_x(self): + self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) + + def test_check_grad_ignore_alpha(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) + + +class TestCase1(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "all"} + + +class TestCase2(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "channel"} + + +class TestCase3(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "element"} if __name__ == "__main__":
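    unittest.main()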