// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cmath>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Dynamic loss scaling update: after decr_every_n_nan_or_inf consecutive
// steps containing inf/nan, the loss scaling is multiplied by decr_ratio
// (clamped to be no smaller than 1); after incr_every_n_steps consecutive
// finite steps, it is multiplied by incr_ratio (left unchanged if the
// result overflows).
template <typename T>
HOSTDEVICE void Update(const bool* found_inf_data,
                       const T* pre_loss_scaling_data, const int* good_in_data,
                       const int* bad_in_data, const int incr_every_n_steps,
                       const int decr_every_n_nan_or_inf,
                       const float incr_ratio, const float decr_ratio,
                       T* updated_loss_scaling_data, int* good_out_data,
                       int* bad_out_data) {
  if (*found_inf_data) {
    *good_out_data = 0;
    *bad_out_data = *bad_in_data + 1;
    if (*bad_out_data == decr_every_n_nan_or_inf) {
      T new_loss_scaling = *pre_loss_scaling_data * decr_ratio;
      *updated_loss_scaling_data = new_loss_scaling < static_cast<T>(1)
                                       ? static_cast<T>(1)
                                       : new_loss_scaling;
      *bad_out_data = 0;
    }
  } else {
    *bad_out_data = 0;
    *good_out_data = *good_in_data + 1;
    if (*good_out_data == incr_every_n_steps) {
      T new_loss_scaling = *pre_loss_scaling_data * incr_ratio;
      *updated_loss_scaling_data = std::isfinite(new_loss_scaling)
                                       ? new_loss_scaling
                                       : *pre_loss_scaling_data;
      *good_out_data = 0;
    }
  }
}

// Device-specific functor that applies Update on the given device context.
template <typename DeviceContext, typename T>
class UpdateLossScalingFunctor {
 public:
  void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
                  const T* pre_loss_scaling_data, const int* good_in_data,
                  const int* bad_in_data, const int incr_every_n_steps,
                  const int decr_every_n_nan_or_inf, const float incr_ratio,
                  const float decr_ratio, T* updated_loss_scaling_data,
                  int* good_out_data, int* bad_out_data) const;
};

// Zeros the outputs only when inf/nan was found in the inputs.
template <typename DeviceContext, typename T>
class LazyZeroInputs {
 public:
  void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
                  const std::vector<const framework::Tensor*>& xs,
                  const std::vector<framework::Tensor*>& outs) const;
};

template <typename DeviceContext, typename T>
class UpdateLossScalingKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto xs = ctx.MultiInput<framework::Tensor>("X");
    const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
    const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
    const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
    const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
    auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
    auto* bad_out = ctx.Output<Tensor>("OutBadSteps");

    PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
                      platform::errors::InvalidArgument(
                          "FoundInfinite must have only one element."));

    const bool* found_inf_data = found_inf->data<bool>();
    const T* pre_loss_scaling_data = pre_loss_scaling->data<T>();
    const int* good_in_data = good_in->data<int>();
    const int* bad_in_data = bad_in->data<int>();

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    T* updated_loss_scaling_data =
        updated_loss_scaling->mutable_data<T>(dev_ctx.GetPlace());
    int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
    int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());

    const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
    const int decr_every_n_nan_or_inf =
        ctx.Attr<int>("decr_every_n_nan_or_inf");
    const float incr_ratio = ctx.Attr<float>("incr_ratio");
    const float decr_ratio = ctx.Attr<float>("decr_ratio");
    UpdateLossScalingFunctor<DeviceContext, T>{}(
        dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data,
        bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
        decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data);

    // When inf/nan was found, the corresponding outputs are zeroed out.
    LazyZeroInputs<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
  }
};

}  // namespace operators
}  // namespace paddle