diff --git a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc deleted file mode 100644 index 7f0ca1493f712f7f4809a56bf6a23f8757f94c2d..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h" - -#include -#include - -namespace paddle { -namespace operators { - -class AmpCheckFiniteAndScaleOp : public framework::OperatorWithKernel { - public: - AmpCheckFiniteAndScaleOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", - "amp_check_finite_and_unscale"); - OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", - "amp_check_finite_and_unscale"); - PADDLE_ENFORCE_EQ( - ctx->Inputs("X").size(), ctx->Outputs("Out").size(), - platform::errors::InvalidArgument( - "The input(X) and output(Out) should have same size in " - "Operator(amp_check_finite_and_unscale), size of input(X) is %d " - "and size of output(Out) is %d.", - ctx->Inputs("X").size(), ctx->Outputs("Out").size())); - auto x_dims = ctx->GetInputsDim("X"); - ctx->SetOutputsDim("Out", x_dims); - ctx->SetOutputDim("FoundInfinite", {1}); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -class AmpCheckFiniteAndScaleOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "X", - "(Tensors) The input tensors of amp_check_finite_and_scale operator.") - .AsDuplicable(); - AddInput("Scale", - "(Tensor) 1-dim tensor, the scale of amp_check_finite_and_scale " - "operator."); - AddOutput("Out", - "(Tensors) The scaled output tensor of " - "amp_check_finite_and_unscale operator.") - .AsDuplicable(); - AddOutput("FoundInfinite", - "(Tensor) 1-dim tensor, contains a bool scalar, which indicates " - "if there there is infinite or nan item in input X."); - AddComment(R"DOC( -amp_check_finite_and_scale operator. -Check if input X contains all finite data, if yes, scale it by input Scale. - -$$Out = X * scale$$ - -If any tensor in X contains Inf or Nan, the Out will generate a indicator. -FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of -Out should not be used, and its data may not be deterministic. -Otherwise, FoundInfinite will be 0 (False). 
- -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR( - amp_check_finite_and_scale, ops::AmpCheckFiniteAndScaleOp, - ops::AmpCheckFiniteAndScaleOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -REGISTER_OP_CPU_KERNEL( - amp_check_finite_and_scale, - ops::AmpCheckFiniteAndScaleKernel, - ops::AmpCheckFiniteAndScaleKernel); diff --git a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h deleted file mode 100644 index 6c2c4eb8a615c4c04a98601c25b5de43b4262e6b..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/isfinite_op.h" - -namespace paddle { -namespace operators { - -template -class AmpCheckFiniteAndScaleKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto& dev_ctx = ctx.template device_context(); - const auto xs = ctx.MultiInput("X"); - const auto* scale = ctx.Input("Scale"); - auto outs = ctx.MultiOutput("Out"); - auto* found_inf = ctx.Output("FoundInfinite"); - - const T* scale_data = scale->data(); - bool* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); - - *found_inf_data = false; - framework::Tensor is_finite = - ctx.AllocateTmpTensor({1}, dev_ctx); - bool* is_finite_data = is_finite.template data(); - - auto& dev = *ctx.template device_context().eigen_device(); - for (size_t i = 0; i < xs.size(); ++i) { - const auto* x = xs[i]; - auto* out = outs[i]; - out->mutable_data(dev_ctx.GetPlace()); - if (!(*found_inf_data)) { - framework::TensorIsfinite(*x, &is_finite); - if (*is_finite_data) { - auto eigen_out = framework::EigenVector::Flatten(*out); - auto eigen_in = framework::EigenVector::Flatten(*x); - eigen_out.device(dev) = (*scale_data) * eigen_in; - } else { - *found_inf_data = true; - break; - } - } - } - return; - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..51c659d5db1c33d5e2db261b998a0673f5e766cb --- /dev/null +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc @@ -0,0 +1,141 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" +#include "paddle/fluid/framework/tensor_util.h" + +namespace paddle { +namespace operators { + +class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel { + public: + CheckFiniteAndUnscaleOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", + "check_finite_and_unscale"); + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", + "check_finite_and_unscale"); + PADDLE_ENFORCE_EQ( + ctx->Inputs("X").size(), ctx->Outputs("Out").size(), + platform::errors::InvalidArgument( + "The input(X) and output(Out) should have same size in " + "Operator(check_finite_and_unscale), size of input(X) is %d " + "and size of output(Out) is %d.", + ctx->Inputs("X").size(), ctx->Outputs("Out").size())); + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->SetOutputDim("FoundInfinite", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class CheckFiniteAndUnscaleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensors) The input tensors of check_finite_and_unscale operator.") + .AsDuplicable(); + AddInput("Scale", + "(Tensor) 1-dim tensor, the scale of check_finite_and_unscale " + "operator."); + AddOutput("Out", + "(Tensors) The scaled output tensor of " + "check_finite_and_unscale operator.") + .AsDuplicable(); + AddOutput("FoundInfinite", + "(Tensor) 1-dim tensor, contains a bool scalar, which indicates " + "whether there is any infinite or nan item in input X."); + AddComment(R"DOC( +check_finite_and_unscale operator. +Check if input X contains all finite data; if so, unscale it by input Scale. + +$$Out = X / scale$$ + +If any tensor in X contains Inf or Nan, FoundInfinite will be 1 (True), and Out will not be +unscaled. In this case, the data of Out should not be used, and its data may not be deterministic. +Otherwise, FoundInfinite will be 0 (False).
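To make the contract above concrete, here is a rough NumPy sketch of what the operator is documented to do; the helper name and the list-of-arrays interface are illustrative only and are not part of this patch:

    import numpy as np

    def check_finite_and_unscale_ref(xs, scale):
        # xs: list of scaled gradient arrays, scale: 1-element array
        found_infinite = any(not np.all(np.isfinite(x)) for x in xs)
        if not found_infinite:
            xs = [x / scale for x in xs]  # Out = X / scale
        # if an Inf/NaN was found, Out is undefined and must not be used
        return xs, np.array([found_infinite])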
+ +)DOC"); + } +}; + +template +class CheckFiniteAndUnscaleCpuKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = ctx.template device_context(); + const auto xs = ctx.MultiInput("X"); + const auto* scale = ctx.Input("Scale"); + auto outs = ctx.MultiOutput("Out"); + auto* found_inf = ctx.Output("FoundInfinite"); + + const T* scale_data = scale->data(); + bool* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); + + *found_inf_data = false; + framework::Tensor is_finite = + ctx.AllocateTmpTensor({1}, dev_ctx); + bool* is_finite_data = is_finite.template data(); + + auto& dev = *ctx.template device_context() + .eigen_device(); + + T inverse_scale = Inverse(*scale_data); + for (size_t i = 0; i < xs.size(); ++i) { + const auto* x = xs[i]; + auto* out = outs[i]; + out->mutable_data(dev_ctx.GetPlace()); + if (!(*found_inf_data)) { + framework::TensorIsfinite(*x, &is_finite); + *found_inf_data = !(*is_finite_data); + } + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*x); + if (!(*found_inf_data)) { + eigen_out.device(dev) = eigen_in * inverse_scale; + } else { + eigen_out.device(dev) = eigen_in * static_cast(0); + } + } + return; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + check_finite_and_unscale, ops::CheckFiniteAndUnscaleOp, + ops::CheckFiniteAndUnscaleOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(check_finite_and_unscale, + ops::CheckFiniteAndUnscaleCpuKernel, + ops::CheckFiniteAndUnscaleCpuKernel); diff --git a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu similarity index 63% rename from paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu rename to paddle/fluid/operators/amp/check_finite_and_unscale_op.cu index ee00d7c5f4499867c2c706ddcf314c1bfae0a866..cf9df34a2467f8461c4c284b4848c54b76edf452 100644 --- a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu @@ -14,28 +14,31 @@ limitations under the License. */ #include -#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h" -#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" namespace paddle { namespace operators { template -__global__ void AmpCheckFiniteAndScale(const T* in, const T* scale, int num, - bool* found_inf, T* out) { +__global__ void GpuInverse(const T* s, T* o) { + *o = Inverse(*s); +} + +template +__global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num, + bool* found_inf, T* out) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < num) { if (!isfinite(in[idx])) { - *found_inf = 1; + *found_inf = true; } - out[idx] = *found_inf ? in[idx] : in[idx] * scale[0]; + out[idx] = *found_inf ? 
in[idx] : in[idx] * (*scale); } } template -class AmpCheckFiniteAndScaleKernel - : public framework::OpKernel { +class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto& dev_ctx = ctx.template device_context(); @@ -48,6 +51,12 @@ class AmpCheckFiniteAndScaleKernel bool* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); cudaMemset(found_inf_data, false, found_inf->numel() * sizeof(bool)); + framework::Tensor inverse_scale = + ctx.AllocateTmpTensor({1}, dev_ctx); + T* inverse_scale_v = inverse_scale.template data(); + + GpuInverse<<<1, 1, 0, dev_ctx.stream()>>>(scale_data, inverse_scale_v); + for (size_t i = 0; i < xs.size(); ++i) { const auto* x = xs[i]; auto* out = outs[i]; @@ -55,11 +64,11 @@ class AmpCheckFiniteAndScaleKernel T* out_data = out->mutable_data(dev_ctx.GetPlace()); int num = x->numel(); - int block = 512; + int block = 1024; int grid = (num + block - 1) / block; VLOG(3) << "launch kernel"; - AmpCheckFiniteAndScale<<>>( - x_data, scale_data, num, found_inf_data, out_data); + CheckFiniteAndUnscale<<>>( + x_data, inverse_scale_v, num, found_inf_data, out_data); VLOG(3) << "finish kernel"; } } @@ -68,9 +77,6 @@ class AmpCheckFiniteAndScaleKernel } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - amp_check_finite_and_scale, - ops::AmpCheckFiniteAndScaleKernel, - ops::AmpCheckFiniteAndScaleKernel); +REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale, + ops::CheckFiniteAndUnscaleGpuKernel, + ops::CheckFiniteAndUnscaleGpuKernel); diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.h b/paddle/fluid/operators/amp/check_finite_and_unscale_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4fb8744d0eee3c58f2948c5a466e08c2700b4332 --- /dev/null +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/operators/isfinite_op.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +template +inline HOSTDEVICE T Inverse(T s) { + return 1.0 / s; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cc b/paddle/fluid/operators/amp/update_loss_scaling_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fca3c531b40550952273f03f41bbc62cbff170fc --- /dev/null +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cc @@ -0,0 +1,170 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/amp/update_loss_scaling_op.h" +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class UpdateLossScalingOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasInput("InGoodSteps"), "Input", "InGoodSteps", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps", + "update_loss_scaling"); + OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps", + "update_loss_scaling"); + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->SetOutputDim("LossScaling", {1}); + ctx->SetOutputDim("OutGoodSteps", {1}); + ctx->SetOutputDim("OutBadSteps", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"), + ctx.device_context()); + } +}; + +class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensors) The input tensors of update_loss_scaling operator.") + .AsDuplicable(); + AddInput("FoundInfinite", + "(Tensor) 1-dim tensor, contains a bool scalar, which indicates " + "whether there is any infinite gradient."); + AddInput("PrevLossScaling", + "(Tensor) 1-dim tensor, previous loss scaling."); + AddInput("InGoodSteps", + "(Tensor) 1-dim tensor, accumulates good steps in which all " + "gradients are finite."); + AddInput("InBadSteps", + "(Tensor) 1-dim tensor, accumulates bad steps in which some " + "gradients are infinite."); + AddOutput("Out", + "(Tensors) The output tensor of update_loss_scaling operator.") + .AsDuplicable(); + AddOutput("LossScaling", "(Tensor) 1-dim tensor, updated loss scaling."); + AddOutput("OutGoodSteps", "(Tensor) 1-dim tensor, updated good steps."); + AddOutput("OutBadSteps", "(Tensor) 1-dim tensor, updated bad steps."); + AddAttr("incr_every_n_steps", + "A value represents increasing loss scaling every n " + "consecutive steps with finite gradients."); + AddAttr("decr_every_n_nan_or_inf", + "A value represents decreasing loss scaling every n " + "accumulated steps with nan or inf gradients."); + AddAttr("incr_ratio", + "The multiplier to use when increasing the loss scaling.") + .AddCustomChecker([](float incr_ratio) { +
PADDLE_ENFORCE_EQ(incr_ratio > 1.0f, true, + platform::errors::InvalidArgument( + "'incr_ratio' should be greater than 1, but " + "the received is %f", + incr_ratio)); + }); + AddAttr( + "decr_ratio", + "The less-than-one-multiplier to use when decreasing loss scaling.") + .AddCustomChecker([](float decr_ratio) { + PADDLE_ENFORCE_EQ(decr_ratio > 0.0f && decr_ratio < 1.0f, true, + platform::errors::InvalidArgument( + "'decr_ratio' should be between 0 and 1, but " + "the received is %f", + decr_ratio)); + }); + AddComment(R"DOC( +Update loss scaling according to overall gradients. If all gradients are +finite after incr_every_n_steps consecutive steps, loss scaling will increase by incr_ratio. +Otherwise, loss scaling will decrease by decr_ratio after accumulating +decr_every_n_nan_or_inf steps in which some gradients are infinite or nan. + +)DOC"); + } +}; + +template +class UpdateLossScalingFunctor { + public: + void operator()(const platform::CPUDeviceContext& ctx, + const bool* found_inf_data, const T* pre_loss_scaling_data, + const int* good_in_data, const int* bad_in_data, + const int incr_every_n_steps, + const int decr_every_n_nan_or_inf, const float incr_ratio, + const float decr_ratio, T* updated_loss_scaling_data, + int* good_out_data, int* bad_out_data) const { + Update(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, + decr_ratio, updated_loss_scaling_data, good_out_data, + bad_out_data); + } +}; + +template +class LazyZeroInputs { + public: + void operator()(const platform::CPUDeviceContext& dev_ctx, + const bool* found_inf_data, + const std::vector& xs, + const std::vector& outs) const { + if (*found_inf_data) { + VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --"; + for (size_t i = 0; i < xs.size(); ++i) { + auto* out = outs[i]; + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + int num = out->numel(); + std::memset(out_data, 0, num * sizeof(T)); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR( + update_loss_scaling, ops::UpdateLossScalingOp, + ops::UpdateLossScalingOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(update_loss_scaling, + ops::UpdateLossScalingKernel, + ops::UpdateLossScalingKernel); diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cu b/paddle/fluid/operators/amp/update_loss_scaling_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..2bc60423d247447adf18eb3ef050ca9b395a2e2f --- /dev/null +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cu @@ -0,0 +1,84 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
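The CPU functor above, and the CUDA functor that follows, both defer to the shared Update() helper declared in update_loss_scaling_op.h further below. A rough Python rendering of that rule, for illustration only (the function name and scalar-argument interface are assumptions, not part of this patch):

    import math

    def update_loss_scaling_ref(found_inf, loss_scaling, good, bad,
                                incr_every_n_steps, decr_every_n_nan_or_inf,
                                incr_ratio, decr_ratio):
        if found_inf:
            good, bad = 0, bad + 1
            if bad == decr_every_n_nan_or_inf:
                # shrink, but never let the scale drop below 1.0
                loss_scaling = max(loss_scaling * decr_ratio, 1.0)
                bad = 0
        else:
            bad, good = 0, good + 1
            if good == incr_every_n_steps:
                new_scaling = loss_scaling * incr_ratio
                # only grow while the scale itself stays finite
                if math.isfinite(new_scaling):
                    loss_scaling = new_scaling
                good = 0
        return loss_scaling, good, bad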
*/ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/amp/update_loss_scaling_op.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +template +__global__ void GpuUpdateLossScaling( + const bool* found_inf_data, const T* pre_loss_scaling_data, + const int* good_in_data, const int* bad_in_data, + const int incr_every_n_steps, const int decr_every_n_nan_or_inf, + const float incr_ratio, const float decr_ratio, + T* updated_loss_scaling_data, int* good_out_data, int* bad_out_data) { + Update(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + updated_loss_scaling_data, good_out_data, bad_out_data); +} + +template +class UpdateLossScalingFunctor { + public: + void operator()(const platform::CUDADeviceContext& dev_ctx, + const bool* found_inf_data, const T* pre_loss_scaling_data, + const int* good_in_data, const int* bad_in_data, + const int incr_every_n_steps, + const int decr_every_n_nan_or_inf, const float incr_ratio, + const float decr_ratio, T* updated_loss_scaling_data, + int* good_out_data, int* bad_out_data) const { + GpuUpdateLossScaling<<<1, 1, 0, dev_ctx.stream()>>>( + found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + updated_loss_scaling_data, good_out_data, bad_out_data); + } +}; + +template +class LazyZeroInputs { + public: + void operator()(const platform::CUDADeviceContext& dev_ctx, + const bool* found_inf_data, + const std::vector& xs, + const std::vector& outs) const { + const auto gpu_place = + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()); + bool has_inf{false}; + memory::Copy(platform::CPUPlace(), &has_inf, gpu_place, found_inf_data, + sizeof(bool), dev_ctx.stream()); + if (has_inf) { + VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --"; + for (size_t i = 0; i < xs.size(); ++i) { + auto* out = outs[i]; + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + int num = out->numel(); + cudaMemset(out_data, 0, num * sizeof(T)); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using GPU = paddle::platform::CUDADeviceContext; + +REGISTER_OP_CUDA_KERNEL(update_loss_scaling, + ops::UpdateLossScalingKernel, + ops::UpdateLossScalingKernel); diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.h b/paddle/fluid/operators/amp/update_loss_scaling_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ca23b72eff0e85ab94c4d1f11e986f69b4e2d776 --- /dev/null +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.h @@ -0,0 +1,123 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +HOSTDEVICE void Update(const bool* found_inf_data, + const T* pre_loss_scaling_data, const int* good_in_data, + const int* bad_in_data, const int incr_every_n_steps, + const int decr_every_n_nan_or_inf, + const float incr_ratio, const float decr_ratio, + T* updated_loss_scaling_data, int* good_out_data, + int* bad_out_data) { + if (*found_inf_data) { + *good_out_data = 0; + *bad_out_data = *bad_in_data + 1; + if (*bad_out_data == decr_every_n_nan_or_inf) { + T new_loss_scaling = *pre_loss_scaling_data * decr_ratio; + *updated_loss_scaling_data = new_loss_scaling < static_cast(1) + ? static_cast(1) + : new_loss_scaling; + *bad_out_data = 0; + } + } else { + *bad_out_data = 0; + *good_out_data = *good_in_data + 1; + if (*good_out_data == incr_every_n_steps) { + T new_loss_scaling = *pre_loss_scaling_data * incr_ratio; + *updated_loss_scaling_data = std::isfinite(new_loss_scaling) + ? new_loss_scaling + : *pre_loss_scaling_data; + *good_out_data = 0; + } + } +} + +template +class UpdateLossScalingFunctor { + public: + void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data, + const T* pre_loss_scaling_data, const int* good_in_data, + const int* bad_in_data, const int incr_every_n_steps, + const int decr_every_n_nan_or_inf, const float incr_ratio, + const float decr_ratio, T* updated_loss_scaling_data, + int* good_out_data, int* bad_out_data) const; +}; + +template +class LazyZeroInputs { + public: + void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data, + const std::vector& xs, + const std::vector& outs) const; +}; + +template +class UpdateLossScalingKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto xs = ctx.MultiInput("X"); + const auto* found_inf = ctx.Input("FoundInfinite"); + const auto* pre_loss_scaling = ctx.Input("PrevLossScaling"); + const auto* good_in = ctx.Input("InGoodSteps"); + const auto* bad_in = ctx.Input("InBadSteps"); + auto outs = ctx.MultiOutput("Out"); + auto* updated_loss_scaling = ctx.Output("LossScaling"); + auto* good_out = ctx.Output("OutGoodSteps"); + auto* bad_out = ctx.Output("OutBadSteps"); + + PADDLE_ENFORCE_EQ(found_inf->numel(), 1, + platform::errors::InvalidArgument( + "FoundInfinite must has only one element.")); + + const bool* found_inf_data = found_inf->data(); + const T* pre_loss_scaling_data = pre_loss_scaling->data(); + const int* good_in_data = good_in->data(); + const int* bad_in_data = bad_in->data(); + + auto& dev_ctx = ctx.template device_context(); + T* updated_loss_scaling_data = + updated_loss_scaling->mutable_data(dev_ctx.GetPlace()); + int* good_out_data = good_out->mutable_data(dev_ctx.GetPlace()); + int* bad_out_data = bad_out->mutable_data(dev_ctx.GetPlace()); + + const int incr_every_n_steps = ctx.Attr("incr_every_n_steps"); + const int decr_every_n_nan_or_inf = + ctx.Attr("decr_every_n_nan_or_inf"); + const float incr_ratio = ctx.Attr("incr_ratio"); + const float decr_ratio = ctx.Attr("decr_ratio"); + UpdateLossScalingFunctor{}( + dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data, + bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, + 
decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data); + LazyZeroInputs{}(dev_ctx, found_inf_data, xs, outs); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 178ecaff7e8d2e575cd64927fe4e39c773b2cb99..f751136640caad6acd3230bc22cd0e3f0fafe9fb 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -111,7 +111,9 @@ std::map> op_passing_outs_map = { {"fake_quantize_dequantize_moving_average_abs_max", {"Out", "OutScale", "OutAccum", "OutState"}}, {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}}, - {"amp_check_finite_and_scale", {"Out", "FoundInfinite"}}, + {"check_finite_and_unscale", {"Out", "FoundInfinite"}}, + {"update_loss_scaling", + {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}}, }; // clang-format off diff --git a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py new file mode 100644 index 0000000000000000000000000000000000000000..d4dc968ca0de44b01741bf1f1fbaac7a9a65287e --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -0,0 +1,124 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import Variable + +__all__ = ['check_finite_and_unscale', 'update_loss_scaling'] + + +def check_finite_and_unscale(x, scale, name=None): + """ + Check if input X contains all finite data, if yes, scale it by input Scale. + + $$Out = X / scale$$ + + If any tensor in X contains Inf or Nan, the Out will generate a indicator. + FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of + Out should not be used, and its data may not be deterministic. + Otherwise, FoundInfinite will be 0 (False). + Args: + x(list|tuple): The input tensors of check_finite_and_unscale operator. + scale: The scale of check_finite_and_unscale operator. + """ + check_type(x, 'x', (tuple, list), 'check_finite_and_unscale') + for e in x: + check_variable_and_dtype(e, "x", ['float32', 'float64'], + 'check_finite_and_unscale') + + helper = LayerHelper("check_finite_and_unscale", **locals()) + found_inf = helper.create_variable_for_type_inference(dtype='bool') + + inputs = {'X': x, 'Scale': scale} + outputs = {'Out': x, 'FoundInfinite': found_inf} + helper.append_op( + type='check_finite_and_unscale', inputs=inputs, outputs=outputs) + + return x, found_inf + + +def update_loss_scaling(x, + found_inf, + prev_loss_scaling, + num_good_steps, + num_bad_steps, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + name=None): + """ + Update loss scaling according to overall gradients. If all gradients is + finite after incr_every_n_steps, loss scaling will increase by incr_ratio. 
+ Otherwise, loss scaling will decrease by decr_ratio after + decr_every_n_nan_or_inf steps and each step some gradients are infinite. + + Args: + x(list|tuple): The input tensors of update_loss_scaling operator. + found_inf (Variable): A boolean variable indicates whether + there is any infinite gradient. + prev_loss_scaling (Variable): Previous loss scaling. + num_good_steps (Variable): A variable accumulates good steps in which + all gradients are finite. + num_bad_steps (Variable): A variable accumulates bad steps in which + some gradients are infinite. + incr_every_n_steps (int): A variable represents increasing loss + scaling every n consecutive steps with + finite gradients. + decr_every_n_nan_or_inf (int): A variable represents decreasing + loss scaling every n accumulated + steps with nan or inf gradients. + incr_ratio(float): The multiplier to use when increasing the loss + scaling. + decr_ratio(float): The less-than-one-multiplier to use when decreasing + loss scaling. + """ + + check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling", + ['float32', 'float64'], "update_loss_scaling") + check_type(x, 'x', (tuple, list), 'update_loss_scaling') + for e in x: + check_variable_and_dtype(e, "x", ['float32', 'float64'], + 'update_loss_scaling') + assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." + + helper = LayerHelper("update_loss_scaling", **locals()) + + inputs = { + 'X': x, + 'FoundInfinite': found_inf, + 'PrevLossScaling': prev_loss_scaling, + 'InGoodSteps': num_good_steps, + 'InBadSteps': num_bad_steps + } + + outputs = { + 'Out': x, + 'LossScaling': prev_loss_scaling, + 'OutGoodSteps': num_good_steps, + 'OutBadSteps': num_bad_steps + } + + attrs = { + 'incr_every_n_steps': incr_every_n_steps, + 'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf, + 'incr_ratio': incr_ratio, + 'decr_ratio': decr_ratio, + } + + helper.append_op( + type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs) + + return x diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index bfbd2700ae10bac4ad37462b5d7844b90dd05bbe..c9112ac849ce0506b7afd941b2213710e06bd1c6 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -17,9 +17,11 @@ from ... import default_startup_program from ... import layers from ... import unique_name from . 
import fp16_utils -from .fp16_utils import update_loss_scaling, rewrite_program +from .fp16_utils import rewrite_program from .fp16_utils import update_role_var_grad from .fp16_lists import AutoMixedPrecisionLists +from .amp_nn import check_finite_and_unscale +from .amp_nn import update_loss_scaling __all__ = ["decorate"] @@ -67,10 +69,8 @@ class OptimizerWithMixedPrecision(object): persistable=True) self._use_dynamic_loss_scaling = use_dynamic_loss_scaling if self._use_dynamic_loss_scaling: - self._incr_every_n_steps = layers.fill_constant( - shape=[1], dtype='int32', value=incr_every_n_steps) - self._decr_every_n_nan_or_inf = layers.fill_constant( - shape=[1], dtype='int32', value=decr_every_n_nan_or_inf) + self._incr_every_n_steps = incr_every_n_steps + self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf self._incr_ratio = incr_ratio self._decr_ratio = decr_ratio self._num_good_steps = layers.create_global_var( @@ -139,49 +139,46 @@ class OptimizerWithMixedPrecision(object): # Change the op_role_var attr for some ops, so that gradients # transferred across GPUs can be FP16. update_role_var_grad(self._train_program, self._params_grads) - scaled_params_grads = [] - for p, g in self._params_grads: - with self._train_program._optimized_guard([p, g]): - scaled_g = g / self._loss_scaling - scaled_params_grads.append([p, scaled_g]) - return scaled_params_grads + return self._params_grads - def apply_gradients(self, scaled_params_grads): + def apply_gradients(self, params_grads): """ Check scaled gradients to determine whether to update loss scaling and update parameters by their scaled gradients, Args: - scaled_params_grads (list): A list of params and scaled grads. + params_grads (list): A list of params and scaled grads. Returns: A list of optimize operators. """ - if self._use_dynamic_loss_scaling: + grads = [g for _, g in params_grads] + with self._train_program._optimized_guard(grads): + grads, found_inf = check_finite_and_unscale( + grads, self._loss_scaling, name="find_infinite_scale") - grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads] - all_grads = layers.concat(grads) - all_grads_sum = layers.reduce_sum(all_grads) - is_overall_finite = layers.isfinite(all_grads_sum) - - update_loss_scaling(is_overall_finite, self._loss_scaling, - self._num_good_steps, self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, self._incr_ratio, - self._decr_ratio) - - # apply_gradient append all ops in global block, thus we shouldn't - # apply gradient in the switch branch. - with layers.Switch() as switch: - with switch.case(is_overall_finite): - pass - with switch.default(): - for _, g in scaled_params_grads: - layers.assign(layers.zeros_like(g), g) - - optimize_ops = self._optimizer.apply_gradients(scaled_params_grads) + if self._use_dynamic_loss_scaling: + with self._train_program._optimized_guard(grads): + grads = update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling") + + params_unscaled_grads = [] + for pg, new_g in zip(params_grads, grads): + params_unscaled_grads.append((pg[0], new_g)) + # apply_gradient append all ops in global block, thus we shouldn't + # apply gradient in the switch branch. 
+ optimize_ops = self._optimizer.apply_gradients(params_unscaled_grads) return optimize_ops diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 328dafe6219adb3c6355de0bafc430c52725024f..0b142ff33de55f36410eb9c23cb75210fc9d6321 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -328,77 +328,3 @@ def update_role_var_grad(main_prog, params_grads): raise ValueError("The op {0} is not in program".format(op)) block.desc._remove_op(op_idx, op_idx + 1) block._sync_with_cpp() - - -def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps, - num_bad_steps, incr_every_n_steps, - decr_every_n_nan_or_inf, incr_ratio, decr_ratio): - """ - Update loss scaling according to overall gradients. If all gradients is - finite after incr_every_n_steps, loss scaling will increase by incr_ratio. - Otherwise, loss scaling will decrease by decr_ratio after - decr_every_n_nan_or_inf steps and each step some gradients are infinite. - - Args: - is_overall_finite (Variable): A boolean variable indicates whether - all gradients are finite. - prev_loss_scaling (Variable): Previous loss scaling. - num_good_steps (Variable): A variable accumulates good steps in which - all gradients are finite. - num_bad_steps (Variable): A variable accumulates bad steps in which - some gradients are infinite. - incr_every_n_steps (Variable): A variable represents increasing loss - scaling every n consecutive steps with - finite gradients. - decr_every_n_nan_or_inf (Variable): A variable represents decreasing - loss scaling every n accumulated - steps with nan or inf gradients. - incr_ratio(float): The multiplier to use when increasing the loss - scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing - loss scaling. 
- """ - zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0) - with layers.Switch() as switch: - with switch.case(is_overall_finite): - should_incr_loss_scaling = layers.less_than(incr_every_n_steps, - num_good_steps + 1) - with layers.Switch() as switch1: - with switch1.case(should_incr_loss_scaling): - new_loss_scaling = prev_loss_scaling * incr_ratio - loss_scaling_is_finite = layers.isfinite(new_loss_scaling) - with layers.Switch() as switch2: - with switch2.case(loss_scaling_is_finite): - layers.assign(new_loss_scaling, prev_loss_scaling) - with switch2.default(): - pass - layers.assign(zero_steps, num_good_steps) - layers.assign(zero_steps, num_bad_steps) - - with switch1.default(): - layers.increment(num_good_steps) - layers.assign(zero_steps, num_bad_steps) - - with switch.default(): - should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf, - num_bad_steps + 1) - with layers.Switch() as switch3: - with switch3.case(should_decr_loss_scaling): - new_loss_scaling = prev_loss_scaling * decr_ratio - static_loss_scaling = \ - layers.fill_constant(shape=[1], - dtype='float32', - value=1.0) - less_than_one = layers.less_than(new_loss_scaling, - static_loss_scaling) - with layers.Switch() as switch4: - with switch4.case(less_than_one): - layers.assign(static_loss_scaling, - prev_loss_scaling) - with switch4.default(): - layers.assign(new_loss_scaling, prev_loss_scaling) - layers.assign(zero_steps, num_good_steps) - layers.assign(zero_steps, num_bad_steps) - with switch3.default(): - layers.assign(zero_steps, num_good_steps) - layers.increment(num_bad_steps) diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index 8f3ca9ec007ef5c1ab8769dde741a5d2b3697600..ff57f30dcd2ec73d55ff06e751767deea0a2eead 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -210,13 +210,12 @@ class AmpScaler(object): def _unscale(self, optimizer): if not self._enable: return - inv_scale = 1.0 / self._scale param_grads = [ param._grad_ivar() for param in optimizer._parameter_list if param._grad_ivar() is not None ] - core.ops.amp_check_finite_and_scale(param_grads, inv_scale, param_grads, - self._found_inf) + core.ops.check_finite_and_unscale(param_grads, self._scale, param_grads, + self._found_inf) def _update(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py b/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py index 70863d3857c43c84a583f0ccf7b9bd733fdb4fd0..fbacaa3d5ce10bdad6dd87fdfc04c1173aff18ff 100644 --- a/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py +++ b/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py @@ -18,9 +18,9 @@ from op_test import OpTest, skip_check_grad_ci import paddle.fluid as fluid -class TestAmpCheckFiniteAndScaleOp(OpTest): +class TestCheckFiniteAndUnscaleOp(OpTest): def setUp(self): - self.op_type = "amp_check_finite_and_scale" + self.op_type = "check_finite_and_unscale" self.init_dtype() x = np.random.random((1024, 1024)).astype(self.dtype) scale = np.random.random((1)).astype(self.dtype) @@ -28,7 +28,7 @@ class TestAmpCheckFiniteAndScaleOp(OpTest): self.inputs = {'X': [('x0', x)], 'Scale': scale} self.outputs = { 'FoundInfinite': np.array([0]), - 'Out': [('out0', x * scale)], + 'Out': [('out0', x / scale)], } def init_dtype(self): @@ -38,9 +38,9 @@ class TestAmpCheckFiniteAndScaleOp(OpTest): self.check_output() -class 
TestAmpCheckFiniteAndScaleOpWithNan(OpTest): +class TestCheckFiniteAndUnscaleOpWithNan(OpTest): def setUp(self): - self.op_type = "amp_check_finite_and_scale" + self.op_type = "check_finite_and_unscale" self.init_dtype() x = np.random.random((1024, 1024)).astype(self.dtype) x[128][128] = np.nan @@ -61,9 +61,9 @@ class TestAmpCheckFiniteAndScaleOpWithNan(OpTest): self.check_output(no_check_set=['Out']) -class TestAmpCheckFiniteAndScaleOpWithInf(OpTest): +class TestCheckFiniteAndUnscaleOpWithInf(OpTest): def setUp(self): - self.op_type = "amp_check_finite_and_scale" + self.op_type = "check_finite_and_unscale" self.init_dtype() x = np.random.random((1024, 1024)).astype(self.dtype) x[128][128] = np.inf diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py index 38c3903306e6e76188cdb50476d6797814c434e9..73e014b35008ff5a0539c6a338755b9dc2cf68d4 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py @@ -57,7 +57,7 @@ class TestFleetAMPOptimizer(unittest.TestCase): ops = [op.type for op in avg_cost.block.ops] self.assertIn('cast', ops) - self.assertIn('isfinite', ops) + self.assertIn('check_finite_and_unscale', ops) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_update_loss_scaling_op.py b/python/paddle/fluid/tests/unittests/test_update_loss_scaling_op.py new file mode 100644 index 0000000000000000000000000000000000000000..fb93334415c3046362090a143f6c15069793709a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_update_loss_scaling_op.py @@ -0,0 +1,250 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid +import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn + + +class TestUpdateLossScalingOp(OpTest): + def setUp(self): + self.op_type = "update_loss_scaling" + self.init() + found_inf = np.array([False], dtype=np.bool) + x = np.random.random((1024, 1024)).astype(self.dtype) + + self.inputs = { + 'X': [('x0', x)], + 'FoundInfinite': found_inf, + 'PrevLossScaling': self.prev_loss_scaling, + 'InGoodSteps': self.num_good_steps, + 'InBadSteps': self.num_bad_steps + } + + self.outputs = { + 'Out': [('out0', np.zeros_like(x))], + 'LossScaling': self.prev_loss_scaling * self.incr_ratio, + 'OutGoodSteps': self.zero_steps, + 'OutBadSteps': self.zero_steps + } + + def init(self): + self.incr_ratio = 2.0 + self.decr_ratio = 0.8 + self.dtype = np.float32 + self.prev_loss_scaling = np.array([2048]).astype(self.dtype) + self.num_good_steps = np.array([999], dtype=np.int32) + self.num_bad_steps = np.array([1], dtype=np.int32) + self.zero_steps = np.array([0], dtype=np.int32) + self.attrs = { + 'incr_every_n_steps': 1000, + 'decr_every_n_nan_or_inf': 2, + 'incr_ratio': self.incr_ratio, + 'decr_ratio': self.decr_ratio, + } + + def test_check_output(self): + self.check_output(no_check_set=['Out']) + + +class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): + def setUp(self): + self.op_type = "update_loss_scaling" + self.init() + found_inf = np.array([True], dtype=np.bool) + x = np.random.random((1024, 1024)).astype(self.dtype) + i = np.random.randint(0, 1024, 1) + j = np.random.randint(0, 1024, 1) + x[i[0]][j[0]] = np.inf + + self.inputs = { + 'X': [('x0', x)], + 'FoundInfinite': found_inf, + 'PrevLossScaling': self.prev_loss_scaling, + 'InGoodSteps': self.num_good_steps, + 'InBadSteps': self.num_bad_steps + } + + self.outputs = { + 'Out': [('out0', np.zeros_like(x))], + 'LossScaling': self.prev_loss_scaling * self.decr_ratio, + 'OutGoodSteps': self.zero_steps, + 'OutBadSteps': self.zero_steps + } + + def test_check_output(self): + self.check_output() + + +class TestUpdateLossScalingLayer(unittest.TestCase): + def loss_scaling_check(self, use_cuda=True, scope=fluid.Scope()): + a = fluid.data(name="a", shape=[1024, 1024], dtype='float32') + b = fluid.data(name="b", shape=[512, 128], dtype='float32') + x = [a, b] + found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool') + prev_loss_scaling = fluid.data( + name="prev_loss_scaling", shape=[1], dtype='float32') + num_good_steps = fluid.data( + name="num_good_steps", shape=[1], dtype='int32') + num_bad_steps = fluid.data( + name="num_bad_steps", shape=[1], dtype='int32') + + a_v = np.random.random([1024, 1024]).astype('float32') + b_v = np.random.random([512, 128]).astype('float32') + found_inf_v = np.array([False]).astype('bool') + prev_loss_scaling_v = np.array([2048]).astype('float32') + num_good_steps_v = np.array([999], dtype=np.int32) + num_bad_steps_v = np.array([1], dtype=np.int32) + + incr_every_n_steps = 1000 + decr_every_n_nan_or_inf = 2 + incr_ratio = 2 + decr_ratio = 0.8 + + result = amp_nn.update_loss_scaling( + x, + found_inf, + prev_loss_scaling, + num_good_steps, + num_bad_steps, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + name="update_loss_scaling") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + with fluid.scope_guard(scope): + exe.run(fluid.default_startup_program()) + result_v = exe.run(feed={ + 'a': a_v, + 'b': b_v, + 'found_inf': found_inf_v, + 
'prev_loss_scaling': prev_loss_scaling_v, + 'num_good_steps': num_good_steps_v, + 'num_bad_steps': num_bad_steps_v + }, + fetch_list=[ + result, x, found_inf, prev_loss_scaling, + num_good_steps, num_bad_steps + ]) + assert np.array_equal(result_v[0], a_v) + assert np.array_equal(result_v[1], b_v) + assert np.array_equal(result_v[0], result_v[2]) + assert np.array_equal(result_v[1], result_v[3]) + assert np.array_equal(result_v[4], found_inf_v) + assert np.array_equal(result_v[5], prev_loss_scaling_v * incr_ratio) + assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v)) + assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v)) + + def loss_scaling_check_inf(self, use_cuda=True, scope=fluid.Scope()): + a = fluid.data(name="a", shape=[1024, 1024], dtype='float32') + b = fluid.data(name="b", shape=[512, 128], dtype='float32') + x = [a, b] + found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool') + prev_loss_scaling = fluid.data( + name="prev_loss_scaling", shape=[1], dtype='float32') + num_good_steps = fluid.data( + name="num_good_steps", shape=[1], dtype='int32') + num_bad_steps = fluid.data( + name="num_bad_steps", shape=[1], dtype='int32') + + a_v = np.random.random([1024, 1024]).astype('float32') + b_v = np.random.random([512, 128]).astype('float32') + i = np.random.randint(0, 1024, 1) + j = np.random.randint(0, 1024, 1) + a_v[i[0]][j[0]] = np.inf + found_inf_v = np.array([True]).astype('bool') + prev_loss_scaling_v = np.array([2048]).astype('float32') + num_good_steps_v = np.array([999], dtype=np.int32) + num_bad_steps_v = np.array([1], dtype=np.int32) + + incr_every_n_steps = 1000 + decr_every_n_nan_or_inf = 2 + incr_ratio = 2 + decr_ratio = 0.8 + + result = amp_nn.update_loss_scaling( + x, + found_inf, + prev_loss_scaling, + num_good_steps, + num_bad_steps, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + name="update_loss_scaling") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + with fluid.scope_guard(scope): + exe.run(fluid.default_startup_program()) + result_v = exe.run(feed={ + 'a': a_v, + 'b': b_v, + 'found_inf': found_inf_v, + 'prev_loss_scaling': prev_loss_scaling_v, + 'num_good_steps': num_good_steps_v, + 'num_bad_steps': num_bad_steps_v + }, + fetch_list=[ + result, x, found_inf, prev_loss_scaling, + num_good_steps, num_bad_steps + ]) + assert np.array_equal(result_v[0], np.zeros_like(a_v)) + assert np.array_equal(result_v[1], np.zeros_like(b_v)) + assert np.array_equal(result_v[2], np.zeros_like(a_v)) + assert np.array_equal(result_v[3], np.zeros_like(b_v)) + assert np.array_equal(result_v[4], found_inf_v) + assert np.array_equal(result_v[5], prev_loss_scaling_v * decr_ratio) + assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v)) + assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v)) + + def test_loss_scaling_cpu(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.loss_scaling_check(use_cuda=False) + + def test_loss_scaling_cpu_inf(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.loss_scaling_check_inf(use_cuda=False) + + def test_loss_scaling_gpu(self): + if fluid.core.is_compiled_with_cuda(): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + 
self.loss_scaling_check(use_cuda=True) + + def test_loss_scaling_gpu_inf(self): + if fluid.core.is_compiled_with_cuda(): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.loss_scaling_check_inf(use_cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 0de0eeb464ad700abb2144e49a822582b8653589..afd3414943e9c94799aba5e5e747182623b0a095 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -25,6 +25,7 @@ no_check_set_white_list = [ 'unsqueeze2', 'cross_entropy2', 'seed', - 'amp_check_finite_and_scale', + 'check_finite_and_unscale', + 'update_loss_scaling', 'cudnn_lstm', ]
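Taken together, the two new operators replace the Switch/isfinite-based control flow in decorator.py. A simplified sketch of the rewritten apply_gradients path using the new amp_nn API (variable names such as params_grads, loss_scaling, and the counter variables stand in for the decorator's self._ attributes; see decorator.py above for the exact code):

    from paddle.fluid.contrib.mixed_precision.amp_nn import (
        check_finite_and_unscale, update_loss_scaling)

    grads = [g for _, g in params_grads]  # gradients still scaled by loss_scaling

    # one fused op checks every gradient for Inf/NaN and unscales them in place
    grads, found_inf = check_finite_and_unscale(
        grads, loss_scaling, name="find_infinite_scale")

    if use_dynamic_loss_scaling:
        # adjusts loss_scaling and the step counters, and zeroes the gradients
        # whenever found_inf is True
        grads = update_loss_scaling(
            grads, found_inf, loss_scaling, num_good_steps, num_bad_steps,
            incr_every_n_steps, decr_every_n_nan_or_inf,
            incr_ratio, decr_ratio, name="update_loss_scaling")

    # hand the (possibly zeroed) unscaled gradients to the inner optimizer
    optimize_ops = optimizer.apply_gradients(
        [(p, g) for (p, _), g in zip(params_grads, grads)])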