diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..237cfcc6f1172518097863158ca6dbd595af4186
--- /dev/null
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_mlu.cc
@@ -0,0 +1,88 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template  // NOTE(review): no <...> parameter list in this hunk - confirm
+class CheckFiniteAndUnscaleMLUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto& dev_ctx = ctx.template device_context();
+    const auto xs = ctx.MultiInput("X");
+    const auto* scale = ctx.Input("Scale");
+    auto outs = ctx.MultiOutput("Out");
+    auto* found_inf = ctx.Output("FoundInfinite");
+
+    found_inf->mutable_data(dev_ctx.GetPlace());
+
+    MLUCnnlTensorDesc scale_desc(*scale);
+    MLUCnnlTensorDesc found_inf_desc(*found_inf, CNNL_LAYOUT_ARRAY,
+                                     ToCnnlDataType());
+
+    for (size_t i = 0; i < xs.size(); ++i) {
+      const auto* x = xs[i];
+      auto* out = outs[i];
+      out->mutable_data(ctx.GetPlace());
+
+      // Detect NaN/Inf in x. Despite its name, is_finite receives that flag.
+      Tensor is_finite(found_inf->type());
+      if (i != 0) {
+        is_finite.Resize(phi::make_ddim({1}));
+        is_finite.mutable_data(ctx.GetPlace());
+      } else {  // first input: is_finite shares found_inf's data
+        is_finite.ShareDataWith(*found_inf);
+      }
+
+      MLUCnnlTensorDesc x_desc(*x);
+
+      MLUCnnl::IsNanInf(ctx, x_desc.get(), GetBasePtr(x),
+                        GetBasePtr(&is_finite));
+
+      // Fold this input's flag into found_inf with a logical OR.
+      if (i != 0) {
+        MLUCnnlTensorDesc is_finite_desc(is_finite, CNNL_LAYOUT_ARRAY,
+                                         ToCnnlDataType());
+        MLUCnnl::Logic(ctx, CNNL_LOGIC_OP_OR, found_inf_desc.get(),
+                       GetBasePtr(found_inf), is_finite_desc.get(),
+                       GetBasePtr(&is_finite), found_inf_desc.get(),
+                       GetBasePtr(found_inf));
+      }
+
+      // The normal logic is:
+      //   out = in,       if found_inf = true
+      //   out = in/scale, if found_inf = false
+      // But when found_inf is true, the data of Out should not be used.
+      // So, on MLU, we always compute out with in/scale.
+      MLUCnnlTensorDesc out_desc(*out);
+      MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, x_desc.get(),
+                   GetBasePtr(x), scale_desc.get(), GetBasePtr(scale),
+                   out_desc.get(), GetBasePtr(out));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_MLU_KERNEL(check_finite_and_unscale,
+                       ops::CheckFiniteAndUnscaleMLUKernel,
+                       ops::CheckFiniteAndUnscaleMLUKernel);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_amp_check_finite_and_scale_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_amp_check_finite_and_scale_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..57fa56acd687582fa67c1592a7d5c505ca6cce06
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_amp_check_finite_and_scale_op_mlu.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+
+paddle.enable_static()
+SEED = 2022  # NOTE(review): never passed to np.random.seed() in this file
+
+
+class TestCheckFiniteAndUnscaleOp(OpTest):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "check_finite_and_unscale"
+        self.init_dtype()
+        self.init_test_case()
+
+    def init_test_case(self):
+        x = np.random.random((129, 129)).astype(self.dtype)
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([0]),
+            'Out': [('out0', x / scale)],
+        }
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+class TestCheckFiniteAndUnscaleOpWithNan(TestCheckFiniteAndUnscaleOp):
+    def init_test_case(self):
+        x = np.random.random((129, 129)).astype(self.dtype)
+        x[128][128] = np.nan
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([1]),
+            'Out': [('out0', x)],  # placeholder; excluded by no_check_set
+        }
+
+    def test_check_output(self):
+        # When input contains nan, do not check the output,
+        # since the output may be nondeterministic and will be discarded.
+        self.check_output_with_place(self.place, no_check_set=['Out'])
+
+
+class TestCheckFiniteAndUnscaleOpWithInf(TestCheckFiniteAndUnscaleOp):
+    def init_test_case(self):
+        x = np.random.random((129, 129)).astype(self.dtype)
+        x[128][128] = np.inf
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([1]),
+            'Out': [('out0', x)],  # placeholder; excluded by no_check_set
+        }
+
+    def test_check_output(self):
+        # When input contains inf, do not check the output,
+        # since the output may be nondeterministic and will be discarded.
+        self.check_output_with_place(self.place, no_check_set=['Out'])
+
+
+class TestCheckFiniteAndUnscaleOpMultiInput(TestCheckFiniteAndUnscaleOp):
+    def init_test_case(self):
+        x0 = np.random.random((129, 129)).astype(self.dtype)
+        x1 = np.random.random((129, 129)).astype(self.dtype)
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x0), ('x1', x1)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([0]),
+            'Out': [('out0', x0 / scale), ('out1', x1 / scale)],
+        }
+
+
+class TestCheckFiniteAndUnscaleOpMultiInputWithNan(TestCheckFiniteAndUnscaleOp):
+    def init_test_case(self):
+        x0 = np.random.random((129, 129)).astype(self.dtype)
+        x0[128][128] = np.nan
+        x1 = np.random.random((129, 129)).astype(self.dtype)
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x0), ('x1', x1)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([1]),
+            'Out': [('out0', x0 / scale), ('out1', x1 / scale)],
+        }
+
+    def test_check_output(self):
+        # When input contains nan, do not check the output,
+        # since the output may be nondeterministic and will be discarded.
+        self.check_output_with_place(self.place, no_check_set=['Out'])
+
+
+class TestCheckFiniteAndUnscaleOpMultiInputWithInf(TestCheckFiniteAndUnscaleOp):
+    def init_test_case(self):
+        x0 = np.random.random((129, 129)).astype(self.dtype)
+        x0[128][128] = np.nan  # nan in x0, combined with inf in x1 below
+        x1 = np.random.random((129, 129)).astype(self.dtype)
+        x1[128][128] = np.inf
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x0), ('x1', x1)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([1]),
+            'Out': [('out0', x0 / scale), ('out1', x1 / scale)],
+        }
+
+    def test_check_output(self):
+        # When inputs contain nan and inf, do not check the output,
+        # since the output may be nondeterministic and will be discarded.
+        self.check_output_with_place(self.place, no_check_set=['Out'])
+
+
+if __name__ == '__main__':
+    unittest.main()