Unverified commit 39a59dcf, authored by Leo Chen, committed by GitHub

[NPU] refactor check_finite_and_scale npu kernel (#32407)

* refactor_check_finite_and_scale_npu_kernel

* fix compile

* add alloc_float_status op

* add alloc_float_status op

* add FloatStatus for check_finite_and_unscale

* refine code

* remove unnecessary logic

* refine for fleet
Parent commit a01b5109
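
For orientation, here is a minimal sketch of how the pieces introduced by this commit fit together: alloc_float_status allocates the 8-element float-status tensor once per program, and check_finite_and_unscale consumes it through the new dispensable FloatStatus input. The sketch is illustrative only (not part of the commit), assumes a PaddlePaddle build compiled with Ascend NPU support, and mirrors the unit test added below.

# Illustrative sketch only; assumes a PaddlePaddle build compiled with
# NPU (Ascend CL) support, mirroring the new unit test in this commit.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import program_guard
from paddle.fluid.contrib.mixed_precision.amp_nn import check_finite_and_unscale

paddle.enable_static()
main_program = paddle.static.Program()
with program_guard(main_program):
    a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
    b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
    scale = paddle.static.data(name="scale", shape=[1], dtype='float32')
    # FloatStatus is an 8-element fp32 tensor maintained by the NPU runtime.
    float_status = paddle.static.data(
        name="float_status", shape=[8], dtype='float32')
    main_program.global_block().append_op(
        type="alloc_float_status", outputs={"FloatStatus": float_status})
    c = paddle.fluid.layers.elementwise_div(a, b)
    out, found_inf = check_finite_and_unscale(
        [c], scale, float_status=float_status)

exe = fluid.Executor(fluid.NPUPlace(0))
out_np, found_inf_np = exe.run(
    main_program,
    feed={"a": np.ones((32, 32), dtype='float32'),
          "b": np.zeros((32, 32), dtype='float32'),
          "scale": np.array([2.0], dtype='float32')},
    fetch_list=[out, found_inf])
print(found_inf_np)  # the division by zero above should report an overflow
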
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class AllocFloatStatusOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("FloatStatus"), "Output", "FloatStatus",
"alloc_float_status");
ctx->SetOutputDim("FloatStatus", {8});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};
class AllocFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("FloatStatus",
"(Tensor) of shape {8} that holds the float status.");
AddComment(R"DOC(
Produces a float Tensor that holds the float status
)DOC");
}
};
template <typename DeviceContext, typename T>
class AllocFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Operator alloc_float_status is not supported on CPU"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
alloc_float_status, ops::AllocFloatStatusOp, ops::AllocFloatStatusMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(alloc_float_status,
ops::AllocFloatStatusKernel<CPU, float>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class AllocFloatStatusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* float_status = ctx.Output<framework::Tensor>("FloatStatus");
float_status->mutable_data<T>(ctx.GetPlace());
auto runner = NpuOpRunner("NPUAllocFloatStatus", {}, {*float_status});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
alloc_float_status,
ops::AllocFloatStatusKernel<paddle::platform::NPUDeviceContext, float>);
......@@ -60,6 +60,12 @@ class CheckFiniteAndUnscaleOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Scale",
"(Tensor) 1-dim tensor, the scale of check_finite_and_unscale "
"operator.");
#ifdef PADDLE_WITH_ASCEND_CL
AddInput("FloatStatus",
"(Tensor) 1-dim tensor of shape [8], allocated by "
"alloc_float_status op")
.AsDispensable();
#endif
AddOutput("Out",
"(Tensors) The scaled output tensor of "
"check_finite_and_unscale operator.")
......
......@@ -24,12 +24,19 @@ namespace operators {
using Tensor = framework::Tensor;
// NOTE(zhiqiu): CheckFiniteAndUnscaleNPUKernel differs from the CUDA kernel.
// On NPU, we do not actually inspect the data of the input tensors;
// instead we use NPUGetFloatStatus to check whether nan/inf occurred on the
// device, and clear the float status after this op.
// This may lead to wrong results if the input tensors were not computed on
// the NPU device but obtained some other way, for example, by feeding.
template <typename T>
class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const auto xs = ctx.MultiInput<framework::Tensor>("X");
const auto* scale = ctx.Input<framework::Tensor>("Scale");
const auto* float_status = ctx.Input<framework::Tensor>("FloatStatus");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
......@@ -56,58 +63,60 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
runner_inverse.Run(stream);
tmp_inverse_out = &inverse_out;
size_t x_size = xs.size();
for (size_t i = 0; i < x_size; ++i) {
// NOTE(zhiqiu):
Tensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());
// NOTE(zhiqiu): NPUGetFloatStatus updates data on input in-place.
// tmp is only placeholder.
auto runner_float_status =
NpuOpRunner("NPUGetFloatStatus", {*float_status}, {tmp},
{{"message", std::string("check_nan_and_inf")}});
runner_float_status.Run(stream);
Tensor sum;
sum.mutable_data<float>({1}, ctx.GetPlace());
auto runner_reduce_sum =
NpuOpRunner("ReduceSumD", {*float_status}, {sum},
{{"axes", std::vector<int>{0}}, {"keep_dims", true}});
runner_reduce_sum.Run(stream);
std::vector<float> sum_vec;
TensorToVector(
sum, ctx.template device_context<paddle::platform::NPUDeviceContext>(),
&sum_vec);
found_inf_data = (sum_vec[0] > 1);
VLOG(4) << "found_inf_data:" << found_inf_data;
for (size_t i = 0; i < xs.size(); ++i) {
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(ctx.GetPlace());
// step2: CheckNumerics
// CheckNumerics runs on the Ascend AI CPU, which delivers poor
// performance.
Tensor check_xout(x->type());
check_xout.Resize(x->dims());
check_xout.mutable_data<T>(ctx.GetPlace());
try {
auto runner_checknumerics =
NpuOpRunner("CheckNumerics", {*x}, {check_xout},
{{"message", std::string("check_nan_and_inf")}});
runner_checknumerics.Run(stream);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
} catch (platform::EnforceNotMet& exception) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
} catch (...) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
}
if (!found_inf_data) {
// Mul: elementwise multiply x by 1/scale
auto runner_matmul =
NpuOpRunner("Mul", {*x, *tmp_inverse_out}, {*out}, {});
runner_matmul.Run(stream);
} else {
// ZerosLike
auto runner_zeroslike = NpuOpRunner("ZerosLike", {*x}, {*out}, {});
runner_zeroslike.Run(stream);
} // end if
} // end for
}
}
// set found_inf to true
if (found_inf_data) {
Tensor found_inf_tensor;
found_inf_tensor.Resize({1});
bool* is_found_inf =
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
*is_found_inf = true;
framework::TensorCopy(
found_inf_tensor, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), found_inf);
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
}
VLOG(4) << "found overflow:" << found_inf_data;
Tensor found_inf_tensor;
found_inf_tensor.Resize({1});
bool* is_found_inf =
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
*is_found_inf = found_inf_data;
framework::TensorCopy(
found_inf_tensor, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), found_inf);
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
auto runner_clear_status =
NpuOpRunner("NPUClearFloatStatus", {*float_status}, {tmp});
runner_clear_status.Run(stream);
}
};
......
......@@ -34,6 +34,8 @@ class ScaleNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
float _power = 1.0;
VLOG(4) << "scale:" << scale << ", bias:" << bias
<< " ,bias_after_scale:" << bias_after_scale;
if (bias_after_scale) {
out->mutable_data<T>(ctx.GetPlace());
auto runner =
......
......@@ -20,7 +20,7 @@ from paddle.fluid import core
__all__ = ['check_finite_and_unscale', 'update_loss_scaling']
def check_finite_and_unscale(x, scale, name=None):
def check_finite_and_unscale(x, scale, name=None, float_status=None):
"""
Check if input X contains all finite data, if yes, scale it by input Scale.
......@@ -30,9 +30,11 @@ def check_finite_and_unscale(x, scale, name=None):
FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
Out should not be used, and its data may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).
Args:
x(list|tuple): The input tensors of check_finite_and_unscale operator.
scale: The scale of check_finite_and_unscale operator.
float_status(Tensor): (Only used on NPU) The float status tensor used to detect overflow.
"""
check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
for e in x:
......@@ -43,6 +45,11 @@ def check_finite_and_unscale(x, scale, name=None):
found_inf = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'X': x, 'Scale': scale}
if core.is_compiled_with_npu():
check_variable_and_dtype(float_status, "float_status",
['float16', 'float32'],
'check_finite_and_unscale')
inputs['FloatStatus'] = float_status
outputs = {'Out': x, 'FoundInfinite': found_inf}
helper.append_op(
type='check_finite_and_unscale', inputs=inputs, outputs=outputs)
......
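
For comparison, a hedged sketch of the unchanged call path: on builds without NPU support, core.is_compiled_with_npu() is false, the FloatStatus input is never wired in, and existing callers that omit float_status keep working.

# Backward-compatibility sketch (illustrative, not part of this commit):
# on a CPU/GPU build the float_status argument defaults to None and the
# FloatStatus input is simply not added to the op.
import paddle
from paddle.fluid import program_guard
from paddle.fluid.contrib.mixed_precision.amp_nn import check_finite_and_unscale

paddle.enable_static()
prog = paddle.static.Program()
with program_guard(prog):
    g = paddle.static.data(name="g", shape=[32, 32], dtype='float32')
    loss_scaling = paddle.static.data(
        name="loss_scaling", shape=[1], dtype='float32')
    out, found_inf = check_finite_and_unscale([g], loss_scaling)
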
......@@ -29,6 +29,7 @@ from .amp_nn import check_finite_and_unscale
from .amp_nn import update_loss_scaling
import types
import warnings
import paddle
__all__ = ["decorate"]
......@@ -165,6 +166,17 @@ class OptimizerWithMixedPrecision(object):
train_program = loss.block.program
self._train_program = train_program
# NOTE(zhiqiu): _float_status is only used for NPU.
if core.is_compiled_with_npu():
float_status = paddle.static.data(
name="float_status", shape=[8], dtype='float32')
self._train_program.global_block().append_op(
type="alloc_float_status",
outputs={"FloatStatus": float_status}, )
self._float_status = float_status
else:
self._float_status = None
with program_guard(self._train_program, startup_program):
self._init_amp_var()
......@@ -294,7 +306,10 @@ class OptimizerWithMixedPrecision(object):
for p, g in params_grads:
with self._train_program._optimized_guard([p, g]):
_, found_inf = check_finite_and_unscale(
[g, ], self._loss_scaling, name="find_infinite_scale")
[g, ],
self._loss_scaling,
name="find_infinite_scale",
float_status=self._float_status)
found_infs.append(found_inf)
elif self._use_pure_fp16:
if fp32_grads:
......@@ -302,19 +317,24 @@ class OptimizerWithMixedPrecision(object):
_, fp32_found_inf = check_finite_and_unscale(
fp32_grads,
self._loss_scaling,
name="find_infinite_scale_fp32")
name="find_infinite_scale_fp32",
float_status=self._float_status)
found_infs.append(fp32_found_inf)
if fp16_grads:
with self._train_program._optimized_guard(fp16_grads):
_, fp16_found_inf = check_finite_and_unscale(
fp16_grads,
self._loss_scaling,
name="find_infinite_scale_fp16")
name="find_infinite_scale_fp16",
float_status=self._float_status)
found_infs.append(fp16_found_inf)
else:
with self._train_program._optimized_guard(grads):
_, found_inf = check_finite_and_unscale(
grads, self._loss_scaling, name="find_infinite_scale")
grads,
self._loss_scaling,
name="find_infinite_scale",
float_status=self._float_status)
if self._use_dynamic_loss_scaling:
if self._is_distributed or self._use_pure_fp16:
......@@ -394,6 +414,7 @@ class OptimizerWithMixedPrecision(object):
The scaled loss by scaling factor, the list of optimize ops, and a
list of scaled parameters and gradients.
"""
opt_dict = self._optimizer.__class__.__dict__
if 'minimize' in opt_dict and isinstance(opt_dict['minimize'],
types.FunctionType):
......
......@@ -19,106 +19,128 @@ sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.contrib.mixed_precision.amp_nn import check_finite_and_unscale
paddle.enable_static()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestCheckFiniteAndUnscaleOp(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "check_finite_and_unscale"
self.place = paddle.NPUPlace(0)
self.init_dtype()
x = np.random.random((1024, 1024)).astype(self.dtype)
scale = np.random.random((1)).astype(self.dtype)
self.inputs = {'X': [('x0', x)], 'Scale': scale}
self.outputs = {
'FoundInfinite': np.array([0]),
'Out': [('out0', x / scale)],
}
def set_npu(self):
self.__class__.use_npu = True
def init_kernel_type(self):
self.use_mkldnn = False
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "check_finite_and_unscale"
self.place = paddle.NPUPlace(0)
self.init_dtype()
x = np.random.random((1024, 1024)).astype(self.dtype)
x[128][128] = np.nan
scale = np.random.random((1)).astype(self.dtype)
self.inputs = {'X': [('x0', x)], 'Scale': scale}
self.outputs = {
'FoundInfinite': np.array([1]),
'Out': [('out0', x)],
}
def set_npu(self):
self.__class__.use_npu = True
def init_kernel_type(self):
self.use_mkldnn = False
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
# When input contains nan, do not check the output,
# since the output may be nondeterministic and will be discarded.
self.check_output_with_place(
self.place, check_dygraph=False, no_check_set=['Out'])
class TestCheckFiniteAndUnscale(unittest.TestCase):
def get_prog(self):
paddle.enable_static()
main_program = paddle.static.Program()
with program_guard(main_program):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
scale = paddle.static.data(name="scale", shape=[1], dtype='float32')
float_status = paddle.static.data(
name="status", shape=[8], dtype='float32')
main_program.global_block().append_op(
type="alloc_float_status",
outputs={"FloatStatus": float_status}, )
c = paddle.fluid.layers.elementwise_div(a, b)
out, found_inf = check_finite_and_unscale(
[c], scale, float_status=float_status)
return main_program, out, found_inf, float_status
def run_prog(self, a, b, scale):
main_program, out, found_inf, float_status = self.get_prog()
place = fluid.NPUPlace(0)
exe = fluid.Executor(place)
out_, found_inf_, float_status_ = exe.run(
main_program,
feed={"a": a,
"b": b,
"scale": scale},
fetch_list=[out, found_inf, float_status])
print(float_status_)
return out_, found_inf_
def test_contains_nan(self):
a = np.zeros((32, 32)).astype('float32')
b = np.zeros((32, 32)).astype('float32')
scale = np.array([2.0]).astype('float32')
out, found_inf = self.run_prog(a, b, scale)
print(out, found_inf)
self.assertTrue(found_inf[0])
def test_contains_inf(self):
a = np.ones((32, 32)).astype('float32')
b = np.zeros((32, 32)).astype('float32')
scale = np.array([2.0]).astype('float32')
out, found_inf = self.run_prog(a, b, scale)
print(out, found_inf)
self.assertTrue(found_inf[0])
def test_not_contains_nan_inf(self):
a = np.ones((32, 32)).astype('float32')
b = np.ones((32, 32)).astype('float32')
scale = np.array([2.0]).astype('float32')
out, found_inf = self.run_prog(a, b, scale)
print(out, found_inf)
self.assertTrue(np.allclose(out, (a / b) / scale[0]))
self.assertFalse(found_inf[0])
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "check_finite_and_unscale"
self.place = paddle.NPUPlace(0)
self.init_dtype()
x = np.random.random((1024, 1024)).astype(self.dtype)
x[128][128] = np.inf
scale = np.random.random((1)).astype(self.dtype)
self.inputs = {'X': [('x0', x)], 'Scale': scale}
self.outputs = {
'FoundInfinite': np.array([1]),
'Out': [('out0', x)],
}
def set_npu(self):
self.__class__.use_npu = True
def init_kernel_type(self):
self.use_mkldnn = False
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
# When input contains inf, do not check the output,
# since the output may be nondeterministic and will be discarded.
self.check_output_with_place(
self.place, check_dygraph=False, no_check_set=['Out'])
class TestCheckFiniteAndUnscaleClearFloatStatus(unittest.TestCase):
def get_prog(self):
paddle.enable_static()
main_program = paddle.static.Program()
with program_guard(main_program):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
scale = paddle.static.data(name="scale", shape=[1], dtype='float32')
float_status = paddle.static.data(
name="status", shape=[8], dtype='float32')
main_program.global_block().append_op(
type="alloc_float_status",
outputs={"FloatStatus": float_status}, )
c = paddle.fluid.layers.elementwise_div(a, b)
out, found_inf = check_finite_and_unscale(
[c], scale, float_status=float_status)
main_program.global_block().append_op(
type="alloc_float_status",
outputs={"FloatStatus": float_status}, )
d = paddle.fluid.layers.elementwise_add(a, b)
out, found_inf = check_finite_and_unscale(
[d], scale, float_status=float_status)
return main_program, out, found_inf, float_status
def run_prog(self, a, b, scale):
main_program, out, found_inf, float_status = self.get_prog()
place = fluid.NPUPlace(0)
exe = fluid.Executor(place)
out_, found_inf_, float_status_ = exe.run(
main_program,
feed={"a": a,
"b": b,
"scale": scale},
fetch_list=[out, found_inf, float_status])
print(float_status_)
return out_, found_inf_
def test_not_contains_nan_inf(self):
a = np.ones((32, 32)).astype('float32')
b = np.zeros((32, 32)).astype('float32')
scale = np.array([2.0]).astype('float32')
out, found_inf = self.run_prog(a, b, scale)
print(out, found_inf)
self.assertTrue(np.allclose(out, (a + b) / scale[0]))
self.assertFalse(found_inf[0])
if __name__ == '__main__':
......