From 8d74782e72c7460d3c8ed2f49babb82d60abc61d Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Wed, 16 Sep 2020 06:17:08 +0000 Subject: [PATCH] Enable uniform_random_op and gaussian_random_op to support the float16 data type. --- paddle/fluid/operators/gaussian_random_op.cu | 34 ++++++++-- .../fluid/operators/optimizers/momentum_op.h | 56 ++++++++++------ paddle/fluid/operators/uniform_random_op.cu | 16 +++-- paddle/fluid/platform/float16.h | 4 ++ .../contrib/mixed_precision/fp16_utils.py | 65 ++++++++----------- 5 files changed, 105 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 7a0c93eb1b2..e2c0e55ef86 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -11,16 +11,31 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include #include +#include #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { +namespace details { +template +struct RandomDistributionType { + using Type = T; +}; + +template <> +struct RandomDistributionType { + using Type = float; +}; +} // namespace details + template struct GaussianGenerator { T mean_, std_; @@ -34,12 +49,16 @@ struct GaussianGenerator { : mean_(mean), std_(std), seed_(seed), offset_(offset) {} __host__ __device__ T operator()(const unsigned int n) const { + using DataType = typename details::RandomDistributionType::Type; + thrust::minstd_rand rng; rng.seed(seed_); - thrust::normal_distribution dist(mean_, std_); + thrust::normal_distribution dist(static_cast(mean_), + static_cast(std_)); unsigned int new_n = n + offset_; rng.discard(new_n); - return dist(rng); + T out = static_cast(dist(rng)); + return out; } }; @@ -122,10 +141,13 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(gaussian_random, - paddle::operators::GPUGaussianRandomKernel, - paddle::operators::GPUGaussianRandomKernel); +REGISTER_OP_CUDA_KERNEL( + gaussian_random, paddle::operators::GPUGaussianRandomKernel, + paddle::operators::GPUGaussianRandomKernel, + paddle::operators::GPUGaussianRandomKernel); REGISTER_OP_CUDA_KERNEL( gaussian_random_batch_size_like, paddle::operators::GPUGaussianRandomBatchSizeLikeKernel, - paddle::operators::GPUGaussianRandomBatchSizeLikeKernel); + paddle::operators::GPUGaussianRandomBatchSizeLikeKernel, + paddle::operators::GPUGaussianRandomBatchSizeLikeKernel< + paddle::platform::float16>); diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index cdcf4cec7ea..9e7094ba79c 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -19,11 +19,27 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { +namespace details { +template +struct LearningRateType { + using Type = T; +}; + +template <> +struct LearningRateType { + using Type = float; +}; +} // namespace details + +template +using DataType = typename details::LearningRateType::Type; + using framework::Tensor; using framework::SelectedRows; struct NoNesterov; @@ -124,7 +140,7 @@ class CPUDenseMomentumFunctor { auto p = framework::EigenVector::Flatten(*param); auto v = framework::EigenVector::Flatten(*velocity); auto g = framework::EigenVector::Flatten(*grad); - const float* lr = learning_rate->data(); + const auto* lr = learning_rate->data>(); v_out = v * mu + g; if (use_nesterov) { @@ -147,7 +163,7 @@ class DenseMomentumFunctor { const T* p_; const T* g_; const T* v_; - const float* lr_; + const DataType* lr_; const T mu_; const int64_t num_; T* p_out_; @@ -155,7 +171,7 @@ class DenseMomentumFunctor { public: DenseMomentumFunctor(const T* p, const T* g, const T* v, - const float* learning_rate, const T mu, + const DataType* learning_rate, const T mu, const int64_t num, T* p_out, T* v_out) : p_(p), g_(g), @@ -169,7 +185,7 @@ class DenseMomentumFunctor { // put memory access in register const T p = p_[i]; const T g = g_[i]; - const float lr = lr_[0]; + const auto lr = lr_[0]; const T v = v_[i]; T v_out = v * mu_ + g; T p_out = p - (g + v_out * mu_) * static_cast(lr); @@ -185,7 +201,7 @@ class DenseMomentumFunctor { const T* p_; const T* g_; const T* v_; - const float* lr_; + const DataType* lr_; const T mu_; const int64_t num_; T* p_out_; @@ -193,7 +209,7 @@ class DenseMomentumFunctor { public: DenseMomentumFunctor(const T* p, const T* g, const T* v, - const float* learning_rate, const T mu, + const DataType* learning_rate, const T mu, const int64_t num, T* p_out, T* v_out) : p_(p), g_(g), @@ -226,7 +242,7 @@ class SparseMomentumFunctor { const T* p_; const T* g_; const T* v_; - const float* lr_; + const DataType* lr_; const T mu_; const int64_t* rows_; const int64_t row_numel_; @@ -235,9 +251,10 @@ class SparseMomentumFunctor { T* v_out_; public: - SparseMomentumFunctor(const T* p, const T* g, const T* v, const float* lr, - const T mu, const int64_t* rows, int64_t row_numel, - int64_t row_height, T* p_out, T* v_out) + SparseMomentumFunctor(const T* p, const T* g, const T* v, + const DataType* lr, const T mu, const int64_t* rows, + int64_t row_numel, int64_t row_height, T* p_out, + T* v_out) : p_(p), g_(g), v_(v), @@ -256,7 +273,7 @@ class SparseMomentumFunctor { : static_cast(0); // put memory access in register const T p = p_[i]; - const float lr = lr_[0]; + const auto lr = lr_[0]; const T v = v_[i]; T v_out = v * mu_ + g; T p_out = p - (g + v_out * mu_) * static_cast(lr); @@ -272,7 +289,7 @@ class SparseMomentumFunctor { const T* p_; const T* g_; const T* v_; - const float* lr_; + const DataType* lr_; const T mu_; const int64_t* rows_; const int64_t row_numel_; @@ -281,9 +298,10 @@ class SparseMomentumFunctor { T* v_out_; public: - SparseMomentumFunctor(const T* p, const T* g, const T* v, const float* lr, - const T mu, const int64_t* rows, int64_t row_numel, - int64_t row_height, T* p_out, T* v_out) + SparseMomentumFunctor(const T* p, const T* g, const T* v, + const DataType* lr, const T mu, const int64_t* rows, + int64_t row_numel, int64_t row_height, T* p_out, + T* v_out) : p_(p), g_(g), v_(v), @@ -342,7 +360,7 @@ class MomentumOpKernel : public framework::OpKernel { if (use_nesterov) { DenseMomentumFunctor functor( param->data(), grad->data(), velocity->data(), - learning_rate->data(), mu, param->numel(), + learning_rate->data>(), mu, param->numel(), param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); @@ -350,7 +368,7 @@ class MomentumOpKernel : public framework::OpKernel { } else { DenseMomentumFunctor functor( param->data(), grad->data(), velocity->data(), - learning_rate->data(), mu, param->numel(), + learning_rate->data>(), mu, param->numel(), param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); @@ -382,7 +400,7 @@ class MomentumOpKernel : public framework::OpKernel { if (use_nesterov) { SparseMomentumFunctor functor( param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), mu, rows, + velocity->data(), learning_rate->data>(), mu, rows, row_numel, static_cast(merged_grad->rows().size()), param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); @@ -391,7 +409,7 @@ class MomentumOpKernel : public framework::OpKernel { } else { SparseMomentumFunctor functor( param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), mu, rows, + velocity->data(), learning_rate->data>(), mu, rows, row_numel, static_cast(merged_grad->rows().size()), param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index 563a6c165b7..cec574a6ada 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -163,9 +164,12 @@ class GPUUniformRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); -REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL( + uniform_random, paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL( + uniform_random_batch_size_like, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 496eb78f20e..79291e2da88 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -131,6 +131,10 @@ struct PADDLE_ALIGN(2) float16 { #endif } + HOSTDEVICE inline float16(int32_t val) : float16(static_cast(val)) {} + + HOSTDEVICE inline float16(uint32_t val) : float16(static_cast(val)) {} + HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {} template diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index a345949270b..f60e66a37cf 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -267,48 +267,35 @@ def cast_net_to_fp16(program): op._set_attr('dtype', core.VarDesc.VarType.FP16) -def cast_parameters_to_fp16(exe, program): +def cast_parameters_to_fp16(program): global_block = program.global_block() all_parameters = global_block.all_parameters() + is_bn_params = lambda param: (param.name.find('bn') != -1 and (param.name.endswith('_offset') or param.name.endswith('_mean') or param.name.endswith('_scale') or param.name.endswith('_variance'))) + all_param_names = {p.name for p in all_parameters if not is_bn_params(p)} + ops = global_block.ops + for param in all_parameters: - if not (param.name.find('bn') != -1 and - (param.name.endswith('_offset') or param.name.endswith('_mean') - or param.name.endswith('_scale') or - param.name.endswith('_variance'))): - param_t = global_scope().find_var(param.name).get_tensor() - data = np.array(param_t) - param_t.set(np.float16(data), exe.place) - - -# def cast_parameters_to_fp16(program): -# global_block = program.global_block() -# all_parameters = global_block.all_parameters() -# is_bn_params = lambda param: (param.name.find('bn') != -1 and (param.name.endswith('_offset') or param.name.endswith('_mean') or param.name.endswith('_scale') or param.name.endswith('_variance'))) -# all_param_names = {p.name for p in all_parameters if not is_bn_params(p)} -# ops = global_block.ops - -# for param in all_parameters: -# if param.name in all_param_names: -# param_var = global_block.var(param.name) -# if param_var.dtype == core.VarDesc.VarType.FP32: -# param_var.desc.set_dtype(core.VarDesc.VarType.FP16) - -# for op in ops: -# target_op = False -# for out_name in op.output_names: -# for out_var_name in op.output(out_name): -# if out_var_name in all_param_names: -# target_op = True -# if target_op: -# if op.has_attr('in_dtype') and op.attr( -# 'in_dtype') == core.VarDesc.VarType.FP32: -# op._set_attr('in_dtype', core.VarDesc.VarType.FP16) -# if op.has_attr('out_dtype') and op.attr( -# 'out_dtype') == core.VarDesc.VarType.FP32: -# op._set_attr('out_dtype', core.VarDesc.VarType.FP16) -# if op.has_attr('dtype') and op.attr( -# 'dtype') == core.VarDesc.VarType.FP32: -# op._set_attr('dtype', core.VarDesc.VarType.FP16) + if param.name in all_param_names: + param_var = global_block.var(param.name) + if param_var.dtype == core.VarDesc.VarType.FP32: + param_var.desc.set_dtype(core.VarDesc.VarType.FP16) + + for op in ops: + target_op = False + for out_name in op.output_names: + for out_var_name in op.output(out_name): + if out_var_name in all_param_names: + target_op = True + if target_op: + if op.has_attr('in_dtype') and op.attr( + 'in_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('in_dtype', core.VarDesc.VarType.FP16) + if op.has_attr('out_dtype') and op.attr( + 'out_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('out_dtype', core.VarDesc.VarType.FP16) + if op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.FP16) def rewrite_program(main_prog, amp_lists): -- GitLab