diff --git a/dnn/src/common/reduce_helper_device.h b/dnn/src/common/reduce_helper_device.h
index d0edade7a25fee20c19dab4ecb776c7f46be50ab..3799bcfeb9f09f2a0e54fe1815ee95f266a2accc 100644
--- a/dnn/src/common/reduce_helper_device.h
+++ b/dnn/src/common/reduce_helper_device.h
@@ -175,13 +175,13 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait::min())), src(src), dst(dst), B(B) {}
 };
 
-template
+template
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
     const wtype INIT;
 
     src_ctype** srcs;
-    index_ctype* srcs_total_nr_elems;
+    size_t* srcs_total_nr_elems;
     dst_ctype* dst;
     const size_t B;
     const src_ctype scale;
@@ -206,7 +206,7 @@ struct CheckNonFiniteOp {
         return lhs | rhs;
     }
     MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
-            src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst,
+            src_ctype** srcs, size_t* srcs_total_nr_elems, dst_ctype* dst,
             size_t B, src_ctype scale)
             : INIT(wtype(0)),
               srcs(srcs),
diff --git a/dnn/src/cuda/check_non_finite/kern.cu b/dnn/src/cuda/check_non_finite/kern.cu
index b22e84cfabdacc595e0110e8444fa8b9d63a7a0e..0a46daa1a4fbc5bb69003d2a2d6e965286fdd658 100644
--- a/dnn/src/cuda/check_non_finite/kern.cu
+++ b/dnn/src/cuda/check_non_finite/kern.cu
@@ -8,10 +8,10 @@ namespace cuda {
 
 #define COMMA ,
 
-#define cb(_dtype)                                                        \
-    INST_REDUCE(                                                          \
-            device_reduce::CheckNonFiniteOp<                              \
-                    _dtype COMMA size_t COMMA dt_int32 COMMA dt_int32>,   \
+#define cb(_dtype)                                                          \
+    INST_REDUCE(                                                            \
+            device_reduce::CheckNonFiniteOp<                                \
+                    _dtype COMMA dt_float32 COMMA dt_int32 COMMA dt_int32>, \
             false);
 
 cb(dt_float32);
diff --git a/dnn/src/cuda/check_non_finite/opr_impl.cpp b/dnn/src/cuda/check_non_finite/opr_impl.cpp
index 20dfda3462bb08151d1a32c0272c736e3ee3a77c..3292854725c8157d751b9b54ed61d5487ad1ef1c 100644
--- a/dnn/src/cuda/check_non_finite/opr_impl.cpp
+++ b/dnn/src/cuda/check_non_finite/opr_impl.cpp
@@ -10,11 +10,11 @@ namespace megdnn {
 namespace cuda {
 
 using device_reduce::CheckNonFiniteOp;
-#define total_nr_elems_max 2048
+#define total_nr_elems_max 8192
 template <typename T>
 size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
     // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes
-    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op;
     megdnn_assert(m_size > 0);
     WorkspaceBundle bundle(
             nullptr, {
@@ -59,7 +59,7 @@ void CheckNonFiniteImpl::_exec(
         _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
         _megdnn_workspace workspace) {
     check_exec(srcs, dst, workspace.size);
-    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
     SmallVector<size_t> workspace_sizes{
             sizeof(T*) * m_size,
@@ -102,7 +102,7 @@ void CheckNonFiniteImpl::_exec(
     cuda_check(cudaStreamAddCallback(
             stream, callback_free, static_cast<void*>(workspace_cpu_raw), 0));
 
-    return run_reduce<Op, false>(
+    run_reduce<Op, false>(
             static_cast(
                     (void*)((char*)workspace_gpu_raw +
                             workspace_gpu.total_size_in_bytes())),
diff --git a/imperative/python/megengine/amp/grad_scaler.py b/imperative/python/megengine/amp/grad_scaler.py
index f4ca6cce819de4b098fa90f84b6914e3d4cbb29e..337c8f6432f111193361bcd3294b66028802be8d 100644
--- a/imperative/python/megengine/amp/grad_scaler.py
+++ b/imperative/python/megengine/amp/grad_scaler.py
@@ -141,8 +141,10 @@ class GradScaler:
                 tensor.grad = None
         return self
 
-    def _check_gradients(self, grad, scale):
-        return _check_non_finite(grad, scale)
+    def _check_gradients(self, grads, scale):
+        if len(grads) == 0:
+            return False
+        return _check_non_finite(grads, scale)
 
     def update(self, new_scale: float = None):
         r"""Update the scale factor according to whether encountered overflow grad.
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index c414b9b075f4ee485f450c14563bcb9bd3388928..b8c384d0286d35c1ed49ad58a04d43e18ce275c9 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -691,11 +691,13 @@ def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
     r"""Check whether input contains infinite or nan value.
 
     Args:
-        inp: a tensor to be checked.
+        inps: tensors to be checked.
 
     Returns:
         a int32 scalar tensor, 0 for False and 1 for True.
     """
+    if isinstance(inps, Tensor):
+        inps = [inps]
     op = builtin.CheckNonFinite(scale=scale)
     oups = apply(op, *inps)
     out = oups[-1]
diff --git a/imperative/python/test/unit/amp/test_grad_scaler.py b/imperative/python/test/unit/amp/test_grad_scaler.py
index 4eb7a4070e361e25779b95bc2bcb24b98f9455dd..d7c334149faace4527fa421b4d95b5016ede1f2e 100644
--- a/imperative/python/test/unit/amp/test_grad_scaler.py
+++ b/imperative/python/test/unit/amp/test_grad_scaler.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest
 
 import megengine as mge
 from megengine.amp import GradScaler
@@ -6,23 +7,46 @@ from megengine.autodiff import GradManager
 from megengine.jit import trace
 
 
-def test_grad_scaler():
-    def f():
-        gm = GradManager()
-        scaler = GradScaler()
-
-        x = mge.tensor(1.0)
-        for _ in range(3):
-            with gm:
-                y = x + 1
-                gm.attach(y)
-                loss = y + 1
-                scaler.backward(gm, loss, unscale_grad=False)
-            np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
-            scaler.unscale(gm.attached_tensors())
-            np.testing.assert_equal(y.grad.numpy(), 1)
-        # test handle None elements
-        scaler.unscale(gm.attached_tensors())
-
-    f()
-    trace(f)()
+@pytest.mark.parametrize(
+    "is_trace", [False, True],
+)
+def test_grad_scaler(is_trace):
+    gm = GradManager()
+    scaler = GradScaler()
+
+    def f(idx, data, calc):
+        x = mge.tensor(data, no_cache=True)
+        y = mge.tensor(data, no_cache=True)
+
+        if is_trace:
+            calc = trace(calc)
+
+        gm.attach([x, y])
+        with gm:
+            loss = calc(x, y)
+            scaler.backward(gm, loss, unscale_grad=False)
+        np.testing.assert_equal(x.grad.numpy(), 2 * scaler.scale_factor)
+        scaler.unscale(filter(lambda t: t.grad is not None, gm.attached_tensors()))
+        # scaler.unscale(gm.attached_tensors())
+        np.testing.assert_equal(x.grad.numpy(), 2)
+
+    def double_variables(x, y):
+        z = x + 2 * y
+        loss = 2 * z + 1
+        return loss
+
+    def single_variable(x, y):
+        z = x + 1
+        loss = 2 * z + 1
+        return loss
+
+    # need grad being unique storage or not inplace modifying grad
+    def double_variables_with_same_grad(x, y):
+        z = x + y
+        loss = 2 * z + 1
+        return loss
+
+    for data in [np.random.random((1, 2, 3, 4)), 1.0]:
+        for calc in [double_variables, single_variable, double_variables_with_same_grad]:
+            for idx in range(3):
+                f(idx, data, calc)
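
Note (not part of the patch): a minimal usage sketch of the Python-level behavior the hunks above suggest, assuming a MegEngine build that already contains these changes. After the change, _check_non_finite accepts either a single Tensor or an iterable of Tensors and returns an int32 scalar flag, and GradScaler._check_gradients short-circuits on an empty gradient list. Module path and exact outputs below are assumptions for illustration only.

import numpy as np
import megengine as mge
from megengine.functional.math import _check_non_finite

# One finite tensor and one containing inf, checked in a single multi-tensor call.
grads = [
    mge.tensor(np.array([1.0, 2.0], dtype=np.float32)),
    mge.tensor(np.array([np.inf, 0.0], dtype=np.float32)),
]
print(_check_non_finite(grads, scale=1.0).numpy())  # expected 1: a non-finite value is present
print(_check_non_finite(mge.tensor(1.0)).numpy())   # expected 0: single finite tensor, wrapped into a list internally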