From 1e818158f54f3d071945292853a5e4bce0536f04 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 4 Jun 2020 11:12:59 +0800 Subject: [PATCH] Feature/add amp_checkout_finite_and_scale op (#24875) * add amp_check_finite_and_scale op, test=develop * add cpu kernel, test=develop * use bool, test=develop * follow comments, test=develop --- paddle/fluid/operators/CMakeLists.txt | 2 + paddle/fluid/operators/amp/CMakeLists.txt | 2 + .../amp/amp_check_finite_and_scale_op.cc | 103 ++++++++++++++++++ .../amp/amp_check_finite_and_scale_op.cu | 75 +++++++++++++ .../amp/amp_check_finite_and_scale_op.h | 66 +++++++++++ paddle/fluid/pybind/op_function_generator.cc | 1 + .../test_amp_check_finite_and_scale_op.py | 88 +++++++++++++++ .../white_list/no_check_set_white_list.py | 1 + 8 files changed, 338 insertions(+) create mode 100644 paddle/fluid/operators/amp/CMakeLists.txt create mode 100644 paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc create mode 100644 paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu create mode 100644 paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9adc6ce8da..d5f1b528da 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -23,6 +23,8 @@ if(WITH_DISTRIBUTE) add_subdirectory(collective) endif() +add_subdirectory(amp) + add_subdirectory(reader) if (NOT WIN32) diff --git a/paddle/fluid/operators/amp/CMakeLists.txt b/paddle/fluid/operators/amp/CMakeLists.txt new file mode 100644 index 0000000000..5d468316e8 --- /dev/null +++ b/paddle/fluid/operators/amp/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc new file mode 100644 index 0000000000..01b6ccedcd --- /dev/null +++ b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class AmpCheckFiniteAndScaleOp : public framework::OperatorWithKernel { + public: + AmpCheckFiniteAndScaleOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", + "amp_check_finite_and_unscale"); + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", + "amp_check_finite_and_unscale"); + PADDLE_ENFORCE_EQ( + ctx->Inputs("X").size(), ctx->Outputs("Out").size(), + platform::errors::InvalidArgument( + "The input(X) and output(Out) should have same size in " + "Operator(amp_check_finite_and_unscale), size of input(X) is %d " + "and size of output(Out) is %d.", + ctx->Inputs("X").size(), ctx->Outputs("Out").size())); + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->SetOutputDim("FoundInfinite", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class AmpCheckFiniteAndScaleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensors) The input tensors of amp_check_finite_and_scale operator.") + .AsDuplicable(); + AddInput("Scale", + "(Tensor) 1-dim tensor, the scale of amp_check_finite_and_scale " + "operator."); + AddOutput("Out", + "(Tensors) The scaled output tensor of " + "amp_check_finite_and_unscale operator.") + .AsDuplicable(); + AddOutput("FoundInfinite", + "(Tensor) 1-dim tensor, contains a int scalar, which indicates " + "if there there is infinite or nan item in input X."); + AddComment(R"DOC( +amp_check_finite_and_scale operator. +Check if input X contains all finite data, if yes, scale it by input Scale. + +$$Out = X * scale$$ + +If any tensor in X contains Inf or Nan, the Out will generate a indicator. +FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of +Out should not be used, and its data may not be deterministic. +Otherwise, FoundInfinite will be 0 (False). + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + amp_check_finite_and_scale, ops::AmpCheckFiniteAndScaleOp, + ops::AmpCheckFiniteAndScaleOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + amp_check_finite_and_scale, + ops::AmpCheckFiniteAndScaleKernel, + ops::AmpCheckFiniteAndScaleKernel); diff --git a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu new file mode 100644 index 0000000000..e691dd8bc8 --- /dev/null +++ b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.cu @@ -0,0 +1,75 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +__global__ void AmpCheckFiniteAndScale(const T* in, const T* scale, int num, + int* found_inf, T* out) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx < num) { + if (!std::isfinite(in[idx])) { + *found_inf = 1; + } + out[idx] = *found_inf ? in[idx] : in[idx] * scale[0]; + } +} + +template +class AmpCheckFiniteAndScaleKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = ctx.template device_context(); + const auto xs = ctx.MultiInput("X"); + const auto* scale = ctx.Input("Scale"); + auto outs = ctx.MultiOutput("Out"); + auto* found_inf = ctx.Output("FoundInfinite"); + + const T* scale_data = scale->data(); + int* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); + cudaMemset(found_inf_data, false, found_inf->numel() * sizeof(bool)); + + for (size_t i = 0; i < xs.size(); ++i) { + const auto* x = xs[i]; + auto* out = outs[i]; + const T* x_data = x->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + + int num = x->numel(); + int block = 512; + int grid = (num + block - 1) / block; + VLOG(3) << "launch kernel"; + AmpCheckFiniteAndScale<<>>( + x_data, scale_data, num, found_inf_data, out_data); + VLOG(3) << "finish kernel"; + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + amp_check_finite_and_scale, + ops::AmpCheckFiniteAndScaleKernel, + ops::AmpCheckFiniteAndScaleKernel); diff --git a/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h new file mode 100644 index 0000000000..6c2c4eb8a6 --- /dev/null +++ b/paddle/fluid/operators/amp/amp_check_finite_and_scale_op.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/isfinite_op.h" + +namespace paddle { +namespace operators { + +template +class AmpCheckFiniteAndScaleKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = ctx.template device_context(); + const auto xs = ctx.MultiInput("X"); + const auto* scale = ctx.Input("Scale"); + auto outs = ctx.MultiOutput("Out"); + auto* found_inf = ctx.Output("FoundInfinite"); + + const T* scale_data = scale->data(); + bool* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); + + *found_inf_data = false; + framework::Tensor is_finite = + ctx.AllocateTmpTensor({1}, dev_ctx); + bool* is_finite_data = is_finite.template data(); + + auto& dev = *ctx.template device_context().eigen_device(); + for (size_t i = 0; i < xs.size(); ++i) { + const auto* x = xs[i]; + auto* out = outs[i]; + out->mutable_data(dev_ctx.GetPlace()); + if (!(*found_inf_data)) { + framework::TensorIsfinite(*x, &is_finite); + if (*is_finite_data) { + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*x); + eigen_out.device(dev) = (*scale_data) * eigen_in; + } else { + *found_inf_data = true; + break; + } + } + } + return; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index e1a525f464..fc60b6302e 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -76,6 +76,7 @@ std::map> op_passing_outs_map = { {"matmul", {"Out"}}, {"fake_quantize_dequantize_moving_average_abs_max", {"Out", "OutScale", "OutAccum", "OutState"}}, + {"amp_check_finite_and_scale", {"Out", "FoundInfinite"}}, }; // clang-format off diff --git a/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py b/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py new file mode 100644 index 0000000000..70863d3857 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_amp_check_finite_and_scale_op.py @@ -0,0 +1,88 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid as fluid + + +class TestAmpCheckFiniteAndScaleOp(OpTest): + def setUp(self): + self.op_type = "amp_check_finite_and_scale" + self.init_dtype() + x = np.random.random((1024, 1024)).astype(self.dtype) + scale = np.random.random((1)).astype(self.dtype) + + self.inputs = {'X': [('x0', x)], 'Scale': scale} + self.outputs = { + 'FoundInfinite': np.array([0]), + 'Out': [('out0', x * scale)], + } + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output() + + +class TestAmpCheckFiniteAndScaleOpWithNan(OpTest): + def setUp(self): + self.op_type = "amp_check_finite_and_scale" + self.init_dtype() + x = np.random.random((1024, 1024)).astype(self.dtype) + x[128][128] = np.nan + scale = np.random.random((1)).astype(self.dtype) + + self.inputs = {'X': [('x0', x)], 'Scale': scale} + self.outputs = { + 'FoundInfinite': np.array([1]), + 'Out': [('out0', x)], + } + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + # When input contains nan, do not check the output, + # since the output may be nondeterministic and will be discarded. + self.check_output(no_check_set=['Out']) + + +class TestAmpCheckFiniteAndScaleOpWithInf(OpTest): + def setUp(self): + self.op_type = "amp_check_finite_and_scale" + self.init_dtype() + x = np.random.random((1024, 1024)).astype(self.dtype) + x[128][128] = np.inf + scale = np.random.random((1)).astype(self.dtype) + + self.inputs = {'X': [('x0', x)], 'Scale': scale} + self.outputs = { + 'FoundInfinite': np.array([1]), + 'Out': [('out0', x)], + } + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + # When input contains inf, do not check the output, + # since the output may be nondeterministic and will be discarded. + self.check_output(no_check_set=['Out']) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index e58c3834e3..816e7c6ea0 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -23,4 +23,5 @@ no_check_set_white_list = [ 'unsqueeze2', 'cross_entropy2', 'seed', + 'amp_check_finite_and_scale', ] -- GitLab