From bf3c34960f2a59a2616957f8fb4107b2ac7aa02b Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 16 Aug 2018 11:00:55 +0800 Subject: [PATCH] "cherry picked operators changes" (#12184) * "cherry picked operators changes" * "remove duplicated code" * "add constant setter" * "add get expected kernel" * "fix ci" * "add fill constant" --- paddle/fluid/operators/activation_op.cu | 4 +- paddle/fluid/operators/activation_op.h | 12 ++-- paddle/fluid/operators/assign_value_op.cu.cc | 5 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 56 +++++++++++------- paddle/fluid/operators/cross_entropy_op.cu | 12 ++-- paddle/fluid/operators/elementwise_add_op.cu | 3 +- paddle/fluid/operators/elementwise_div_op.cu | 9 ++- paddle/fluid/operators/elementwise_mul_op.cu | 8 ++- .../fluid/operators/elementwise_op_function.h | 4 +- paddle/fluid/operators/elementwise_sub_op.cu | 8 ++- paddle/fluid/operators/fill_constant_op.cc | 53 ++++++----------- paddle/fluid/operators/fill_constant_op.cu.cc | 26 ++++++++ paddle/fluid/operators/fill_constant_op.h | 48 +++++++++++++++ paddle/fluid/operators/fill_op.cc | 2 +- paddle/fluid/operators/gaussian_random_op.cu | 2 + paddle/fluid/operators/math/cross_entropy.cu | 20 ++++++- paddle/fluid/operators/math/cross_entropy.h | 17 ++++++ .../operators/math/selected_rows_functor.cu | 13 +++- paddle/fluid/operators/math/softmax.cu | 3 + paddle/fluid/operators/mean_op.cu | 10 ++-- paddle/fluid/operators/mean_op.h | 2 +- paddle/fluid/operators/mul_op.cu.cc | 7 ++- paddle/fluid/operators/pool_cudnn_op.cu.cc | 6 +- paddle/fluid/operators/scale_op.cu | 6 +- paddle/fluid/operators/softmax_cudnn_op.cu.cc | 3 +- paddle/fluid/operators/softmax_op.cu.cc | 3 +- paddle/fluid/operators/sum_op.cu | 5 +- paddle/fluid/operators/sum_op.h | 2 +- paddle/fluid/operators/top_k_op.cu | 28 +++++++-- paddle/fluid/operators/uniform_random_op.cu | 59 ++++++++++++++++--- 30 files changed, 328 insertions(+), 108 deletions(-) create mode 100644 paddle/fluid/operators/fill_constant_op.cu.cc create mode 100644 paddle/fluid/operators/fill_constant_op.h diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 27487b396cc..d3a7ceed466 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -26,6 +26,8 @@ namespace plat = paddle::platform; act_type##_grad, ops::ActivationGradKernel>, \ ops::ActivationGradKernel>); + ops::grad_functor>, \ + ops::ActivationGradKernel>); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 91241519265..48f3b5a5bc0 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - const Out out_conj = Eigen::numext::conj(out); - dx.device(d) = static_cast(0.5) * dout / out_conj; + dx.device(d) = static_cast(0.5) * dout / out; } }; @@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor { typename dX> void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(factor) * - x.pow(static_cast(factor - static_cast(1))); + x.pow(static_cast(factor) - static_cast(1)); } }; @@ -863,10 +862,11 @@ struct SwishGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + T b = static_cast(beta); auto temp1 = static_cast(1) / - (static_cast(1) + (static_cast(-beta) * x).exp()); - auto temp2 = temp1 * (static_cast(1) - (beta * out)); - dx.device(d) = dout * ((beta * out) + temp2); + (static_cast(1) + (static_cast(-b) * x).exp()); + auto temp2 = temp1 * (static_cast(1) - (b * out)); + dx.device(d) = dout * ((b * out) + temp2); } }; diff --git a/paddle/fluid/operators/assign_value_op.cu.cc b/paddle/fluid/operators/assign_value_op.cu.cc index 08bfde5dc92..0ff174b3884 100644 --- a/paddle/fluid/operators/assign_value_op.cu.cc +++ b/paddle/fluid/operators/assign_value_op.cu.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/assign_value_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel, - ops::AssignValueKernel); + ops::AssignValueKernel, + ops::AssignValueKernel); diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 22cbf680c06..59bfe8f61d8 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -39,6 +39,27 @@ using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; +template +// bool EnableFp16(const T& dummy, const DeviceContext& dev_ctx, +bool EnableFp16(const DeviceContext& dev_ctx, + cudnnConvolutionDescriptor_t cudnn_conv_desc) { +#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) + // Tensor core is supported since the volta GPU and + // is only enabled when input and filter data are float16 + if (dev_ctx.GetComputeCapability() >= 70 && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); + return true; + } else { + PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + } +#endif + return false; +} + template class CUDNNConvOpKernel : public framework::OpKernel { public: @@ -128,27 +149,14 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo; auto& dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - -#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) - // Tensor core is supported since the volta GPU and - // is only enabled when input and filter data are float16 - if (dev_ctx.GetComputeCapability() >= 70 && - std::type_index(typeid(T)) == - std::type_index(typeid(platform::float16))) { - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); - // Currently tensor core is only enabled using this algo + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; } else { - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); } -#endif // get workspace size able to allocate CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( @@ -288,6 +296,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( @@ -307,6 +318,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } else { filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } + if (EnableFp16(dev_ctx, cudnn_conv_desc)) { + filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } CUDNN_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( @@ -362,7 +376,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel); REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, @@ -370,4 +385,5 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel) diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu index 30dbd5bd3d3..65fd3a5dbc9 100644 --- a/paddle/fluid/operators/cross_entropy_op.cu +++ b/paddle/fluid/operators/cross_entropy_op.cu @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cross_entropy_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; using CUDACtx = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, - ops::CrossEntropyOpKernel); -REGISTER_OP_CUDA_KERNEL(cross_entropy_grad, - ops::CrossEntropyGradientOpKernel, - ops::CrossEntropyGradientOpKernel); + ops::CrossEntropyOpKernel, + ops::CrossEntropyOpKernel); +REGISTER_OP_CUDA_KERNEL( + cross_entropy_grad, ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index dfff518f170..f9f5c66d34f 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel); diff --git a/paddle/fluid/operators/elementwise_div_op.cu b/paddle/fluid/operators/elementwise_div_op.cu index 588d1f74202..4cc7ba0f43c 100644 --- a/paddle/fluid/operators/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise_div_op.cu @@ -14,19 +14,24 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_div_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_div, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, - ops::ElementwiseDivKernel); + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel); REGISTER_OP_CUDA_KERNEL( elementwise_div_grad, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel); + plat::float16>); diff --git a/paddle/fluid/operators/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise_mul_op.cu index 2fb1b4bee68..350d43168de 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise_mul_op.cu @@ -14,19 +14,25 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_mul_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_CUDA_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel); diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index bc3e95e904f..7223a972d23 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -350,7 +350,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; - T val = 0; + T val(0); do { int x_offset = i * w + j; @@ -418,7 +418,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( int tid = threadIdx.x; int j = blockIdx.x; - T val = 0; + T val(0); int ttid = tid; while (true) { diff --git a/paddle/fluid/operators/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise_sub_op.cu index 8709f686f9a..ff3f6f8a2cb 100644 --- a/paddle/fluid/operators/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise_sub_op.cu @@ -14,19 +14,25 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise_sub_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( elementwise_sub, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, - ops::ElementwiseSubKernel); + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel); REGISTER_OP_CUDA_KERNEL( elementwise_sub_grad, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 130f18dde4f..862249269ea 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -12,48 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -class FillConstantInferShape : public framework::InferShapeBase { +class FillConstantOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *ctx) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FillConstantOp should not be null."); - auto &shape = ctx->Attrs().Get>("shape"); + auto& shape = ctx->Attrs().Get>("shape"); ctx->SetOutputDim("Out", framework::make_ddim(shape)); } -}; - -class FillConstantOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - auto data_type = - static_cast(Attr("dtype")); - auto value = Attr("value"); - auto force_cpu = Attr("force_cpu"); - auto &out = - *scope.FindVar(Output("Out"))->GetMutable(); - out.Resize(framework::make_ddim(Attr>("shape"))); - if (force_cpu) { - auto cpu = platform::CPUPlace(); - out.mutable_data(cpu, framework::ToTypeIndex(data_type)); - } else { - out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); - } - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - math::set_constant(dev_ctx, &out, value); + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + static_cast(ctx.Attr("dtype")), + ctx.device_context()); } }; @@ -87,6 +67,11 @@ Fill up a variable with specified constant value. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, - ops::FillConstantInferShape, ops::FillConstantOpMaker, +REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker, paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc new file mode 100644 index 00000000000..51ccaefa433 --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fill_constant, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel, + ops::FillConstantOpKernel) diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h new file mode 100644 index 00000000000..b2a2a7b2fae --- /dev/null +++ b/paddle/fluid/operators/fill_constant_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + auto value = ctx.Attr("value"); + auto force_cpu = ctx.Attr("force_cpu"); + auto* out = ctx.Output("Out"); + out->Resize(framework::make_ddim(ctx.Attr>("shape"))); + if (force_cpu) { + auto cpu = platform::CPUPlace(); + out->mutable_data(cpu, framework::ToTypeIndex(data_type)); + } else { + out->mutable_data(ctx.GetPlace(), framework::ToTypeIndex(data_type)); + } + + math::set_constant(ctx.template device_context(), out, + value); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index 925dc19061e..352a17c927b 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -69,7 +70,6 @@ class FillOp : public framework::OperatorBase { framework::VisitDataType( dtype, FillOpVisitor(&tensor, Attr>("value"))); - if (!force_cpu && platform::is_gpu_place(place)) { // Copy tensor to out platform::DeviceContextPool &pool = diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 7784856417e..b4907237954 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -60,6 +61,7 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(gaussian_random, paddle::operators::GPUGaussianRandomKernel, paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 0de58d5fddd..58b85abf822 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -15,11 +15,25 @@ limitations under the License. */ #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { +template +HOSTDEVICE T log(const T& val) { + return std::log(val); +} + +template <> +HOSTDEVICE platform::float16 log(const platform::float16& val) { + // strage bug, hlog is not exists. + return static_cast(0); + // half tmp = static_cast(val); + // return static_cast(hlog(tmp)); +} + namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, @@ -35,12 +49,12 @@ template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, const int class_num) { int tid = threadIdx.x; - T val = 0; + T val(0); int idx = blockIdx.x * class_num + tid; int end = blockIdx.x * class_num + class_num; for (; idx < end; idx += blockDim.x) { - val += math::TolerableValue()(std::log(X[idx])) * label[idx]; + val += math::TolerableValue()(log(X[idx])) * label[idx]; } val = paddle::platform::reduceSum(val, tid, blockDim.x); @@ -84,6 +98,8 @@ class CrossEntropyFunctor { template class CrossEntropyFunctor; template class CrossEntropyFunctor; +template class CrossEntropyFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index adc5b3fe47c..2e4e4781c2e 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { @@ -33,6 +35,21 @@ struct TolerableValue { } }; +// float16 value clip behave different. +using paddle::platform::float16; +using paddle::platform::isfinite; +template <> +struct TolerableValue { + HOSTDEVICE float16 operator()(const float16& x) const { + if (isfinite(x)) + return x; + else if (x > static_cast(0)) + return std::numeric_limits::max(); + else + return std::numeric_limits::min(); + } +}; + template class CrossEntropyFunctor { public: diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index a92762c7fea..00dbfc11a23 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -76,6 +77,7 @@ struct SelectedRowsAdd { template struct SelectedRowsAdd; template struct SelectedRowsAdd; +template struct SelectedRowsAdd; namespace { template @@ -120,7 +122,7 @@ struct SelectedRowsAddTensor { auto* out_data = output->data(); SetConstant functor; - functor(context, output, 0.0); + functor(context, output, static_cast(0)); const int block_size = 256; dim3 threads(block_size, 1); @@ -138,6 +140,8 @@ struct SelectedRowsAddTensor { template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; template struct SelectedRowsAddTo { @@ -177,6 +181,8 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template struct SelectedRowsAddTo; namespace { template @@ -229,6 +235,8 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { @@ -276,7 +284,7 @@ struct MergeAdd { context.GetPlace()); math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); + constant_functor(context, out.mutable_value(), static_cast(0)); auto* out_data = out.mutable_value()->data(); auto* input_data = input.value().data(); @@ -300,6 +308,7 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 3effe776258..785c4baecbf 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -94,12 +94,15 @@ void SoftmaxGradCUDNNFunctor::operator()( template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor; +template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 91e0ab28efc..07aa23754f9 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( mean, ops::MeanKernel, - ops::MeanKernel); + ops::MeanKernel, + ops::MeanKernel); REGISTER_OP_CUDA_KERNEL( mean_grad, ops::MeanGradKernel, - ops::MeanGradKernel); + ops::MeanGradKernel, + ops::MeanGradKernel); diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h index 362e9f9ae8b..a41d50ae0b9 100644 --- a/paddle/fluid/operators/mean_op.h +++ b/paddle/fluid/operators/mean_op.h @@ -55,7 +55,7 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = static_cast(IG->numel()); - Eigen::DSizes bcast(ig_size); + Eigen::DSizes bcast(static_cast(ig_size)); EigenVector::Flatten(*IG).device( *context.template device_context().eigen_device()) = diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 81f3e42bf41..6c5a83c6a50 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -20,6 +20,7 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, ops::MulKernel, ops::MulKernel); -REGISTER_OP_CUDA_KERNEL(mul_grad, - ops::MulGradKernel, - ops::MulGradKernel); +REGISTER_OP_CUDA_KERNEL( + mul_grad, ops::MulGradKernel, + ops::MulGradKernel, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 31f083565fd..9fdbee818a2 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -174,7 +174,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, @@ -182,4 +183,5 @@ REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel); REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, - ops::PoolCUDNNGradOpKernel); + ops::PoolCUDNNGradOpKernel, + ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/scale_op.cu b/paddle/fluid/operators/scale_op.cu index 04c802da129..d2668670463 100644 --- a/paddle/fluid/operators/scale_op.cu +++ b/paddle/fluid/operators/scale_op.cu @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/scale_op.h" +#include "paddle/fluid/platform/float16.h" +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( scale, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel, paddle::operators::ScaleKernel); + int64_t>, + paddle::operators::ScaleKernel); diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e9996..c2d45c3d2ef 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -78,4 +78,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc index 5fb4f011d9b..19359b7eef5 100644 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ b/paddle/fluid/operators/softmax_op.cu.cc @@ -23,4 +23,5 @@ REGISTER_OP_CUDA_KERNEL( ops::SoftmaxKernel); REGISTER_OP_CUDA_KERNEL( softmax_grad, ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); + ops::SoftmaxGradKernel, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index 89bcd1bbc86..db4c2d6c115 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -11,10 +11,13 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( sum, ops::SumKernel, ops::SumKernel, ops::SumKernel, - ops::SumKernel); + ops::SumKernel, + ops::SumKernel); diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 49a4afb3a8a..dda6772796c 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -46,7 +46,7 @@ class SumKernel : public framework::OpKernel { if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, - 0.0); + static_cast(0)); } math::SelectedRowsAddToTensor functor; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2d..5fc0784f665 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -11,16 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using paddle::platform::float16; template struct Pair { @@ -32,6 +35,11 @@ struct Pair { id = id; } + __device__ __forceinline__ void clear() { + v = -INFINITY; + id = -1; + } + __device__ __forceinline__ void operator=(const Pair& in) { v = in.v; id = in.id; @@ -53,6 +61,12 @@ struct Pair { int64_t id; }; +template <> +__device__ __forceinline__ void Pair::clear() { + v = platform::raw_uint16_to_float16(0x400); + id = -1; +} + template __device__ __forceinline__ void AddTo(Pair topk[], const Pair& p, int beam_size) { @@ -150,7 +164,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - (*beam)) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } } if (!(*is_empty)) { @@ -160,7 +174,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, } *max = topk[MaxLength - 1]; - if ((*max).v == -1) *is_empty = true; + if ((*max).v == static_cast(-1)) *is_empty = true; *beam = 0; } } @@ -181,7 +195,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, if (k < MaxLength - *beam) { topk[k] = topk[k + *beam]; } else { - topk[k].set(-INFINITY, -1); + topk[k].set(std::numeric_limits::min(), -1); } } if (!(*is_empty)) { @@ -273,7 +287,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, bool firststep = true; for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + topk[k].clear(); } while (k) { ThreadGetTopK(topk, &beam, k, @@ -325,5 +339,7 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + top_k, paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index e1c7323a302..2b8039a0c1b 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -11,10 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { @@ -36,6 +40,11 @@ struct UniformGenerator { } }; +template +struct CastFunctor { + HOSTDEVICE V operator()(const T& a) { return static_cast(a); } +}; + // It seems that Eigen::Tensor::random in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. @@ -66,18 +75,50 @@ class GPUUniformRandomKernel : public framework::OpKernel { T max = static_cast(context.Attr("max")); thrust::counting_iterator index_sequence_begin(0); int64_t size = tensor->numel(); - thrust::transform(index_sequence_begin, index_sequence_begin + size, - thrust::device_ptr(data), - UniformGenerator(min, max, seed)); + if (out_var->IsType() && + std::type_index(typeid(T)) == + std::type_index(typeid(platform::float16))) { + framework::Tensor master_copy_tensor; + master_copy_tensor.Resize(tensor->dims()); + float* master_copy_tensor_data = + master_copy_tensor.mutable_data(context.GetPlace()); + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(master_copy_tensor_data), + UniformGenerator(static_cast(min), + static_cast(max), seed)); + platform::Transform trans; + auto* in_begin = master_copy_tensor.data(); + auto* in_end = in_begin + master_copy_tensor.numel(); + auto* out_begin = tensor->mutable_data(context.GetPlace()); + trans(context.template device_context(), + in_begin, in_end, out_begin, CastFunctor()); + } else { + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } + if (VLOG_IS_ON(5)) { + framework::Tensor cpu_tensor; + framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor); + auto& dev_ctx = + *platform::DeviceContextPool::Instance().Get(context.GetPlace()); + dev_ctx.Wait(); + auto x = framework::EigenVector::Flatten(cpu_tensor); + VLOG(5) << "The Uniform output " << x; + } } }; } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); -REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + uniform_random, paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL( + uniform_random_batch_size_like, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); -- GitLab