diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h
index d6a73d1e18a5f3f12dd04bfd24e3065940f96a04..cc1c5f416ed5617c0329311c88789e0b72755e96 100644
--- a/dnn/include/megdnn/oprs/general.h
+++ b/dnn/include/megdnn/oprs/general.h
@@ -1442,6 +1442,39 @@ protected:
     void backward_check_exec(const TensorLayout& src, const TensorLayout& dst);
 };
 
+class LAMBUpdate : public OperatorBase {
+    DEF_OPR_PARAM(LAMBUpdate);
+    // input = (m_t-1, v_t-1, lamb_param, grad), output = (m_t, v_t, new_param)
+    DEF_OPR_IMPL(LAMBUpdate, OperatorBase, 4, 3);
+
+public:
+    virtual void exec(
+            _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
+            _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
+            _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
+            _megdnn_tensor_out new_param, _megdnn_workspace workspace) = 0;
+
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+            const TensorLayout& lamb_param, const TensorLayout& grad,
+            const TensorLayout& m_t, const TensorLayout& v_t,
+            const TensorLayout& new_param) = 0;
+
+    void deduce_layout(
+            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+            const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t,
+            TensorLayout& v_t, TensorLayout& new_param);
+
+protected:
+    void check_exec(
+            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+            const TensorLayout& lamb_param, const TensorLayout& grad,
+            const TensorLayout& m_t, const TensorLayout& v_t,
+            const TensorLayout& new_param, size_t workspace_in_bytes);
+};
+
+using LAMB = LAMBUpdate;
+
 } // namespace megdnn
 
 #include "megdnn/internal/opr_header_epilogue.h"
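
For orientation, this is the update rule the new operator computes, written as a NumPy sketch. It is illustrative only and mirrors the naive CPU implementation added later in this patch; all names here are local to the sketch:

    # Illustrative NumPy sketch of one LAMBUpdate step (not the shipped code).
    import numpy as np

    def lamb_step(m, v, param, grad, *, beta_1, beta_2, step, lr,
                  weight_decay, eps, bias_correction, always_adapt):
        m_t = beta_1 * m + (1 - beta_1) * grad        # first moment
        v_t = beta_2 * v + (1 - beta_2) * grad ** 2   # second moment
        bc_1 = 1 - beta_1 ** step if bias_correction else 1.0
        bc_2 = 1 - beta_2 ** step if bias_correction else 1.0
        rt = (m_t / bc_1) / (np.sqrt(v_t / bc_2) + eps)  # Adam-style direction
        if weight_decay != 0:
            rt += weight_decay * param
        p_norm, d_norm = np.linalg.norm(param), np.linalg.norm(rt)
        trust_ratio = 1.0
        if (always_adapt or weight_decay > 0) and p_norm > 0 and d_norm > 0:
            trust_ratio = p_norm / d_norm             # layer-wise adaptation
        return m_t, v_t, param - lr * trust_ratio * rt
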
diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index 72940d61a2ce7f13f7ec391085d8b6695e6c0430..a49e65b9efe59645d92c78cf0b2bbc4aedf64f2d 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -36,13 +36,13 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum(Doc('Format', 'convolution data/filter/output format; see '
                  ':class:`RelayoutFormat` for more details'),
              'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
-             'NCHW44 = 7','NCHW44_DOT = 8',
+             'NCHW44 = 7','NCHW44_DOT = 8',
     Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'),
     Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'),
     Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'),
-    Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
-    Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
-    Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
+    Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
+    Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
+    Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
     Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, '
         'output tensor is nchw layout'),
     Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
@@ -96,9 +96,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum(Doc('Format', 'convolution data/filter/output format; see '
                  ':class:`RelayoutFormat` for more details'),
              'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
-             'NCHW44 = 7','NCHW44_DOT = 8',
-    Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
-    Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
+             'NCHW44 = 7','NCHW44_DOT = 8',
+    Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
+    Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
     Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
     Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, '
        'output tensor is nchw layout'),
@@ -107,9 +107,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
         'output tensor is nchw4 layout, padding c=4'),
     Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
-        'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'),
+        'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'),
     Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilizing TensorCore '
-        'instructions for 4-bit integers on Nvidia platforms'),
+        'instructions for 4-bit integers on Nvidia platforms'),
     Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')).
     add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode')
     )
@@ -1038,10 +1038,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
         'NCHW_NCHW4 = 24',
         'NCHW4_NCHW = 25',
         'NCHW_NCHW4_WEIGHT = 26',
-        'NCHW_NCHW64 = 27',
-        'NCHW64_NCHW = 28',
-        'NCHW_NHWC = 29',
-        'NHWC_NCHW = 30',
+        'NCHW_NCHW64 = 27',
+        'NCHW64_NCHW = 28',
+        'NCHW_NHWC = 29',
+        'NHWC_NCHW = 30',
         'NHWCD4I_NHWC = 31',
     )
 )
@@ -1264,3 +1264,14 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
  add_fields('float32', Doc('dropout', 'If introduce a Dropout layer on the outputs of each LSTM layer'), '0.f').
  add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
 )
+
+(pdef('LAMBUpdate').
+ add_fields('float32', Doc('beta_1', 'beta_1 parameter of lamb'), '1.f').
+ add_fields('float32', Doc('beta_2', 'beta_2 parameter of lamb'), '1.f').
+ add_fields('float32', Doc('step', 'training step'), '1.f').
+ add_fields('float32', Doc('lr', 'learning rate'), '1.f').
+ add_fields('float32', Doc('weight_decay', 'weight decay to adjust learning rate'), '1.f').
+ add_fields('float32', Doc('eps', 'term added to the denominator to improve numerical stability'), '1.f').
+ add_fields('bool', Doc('bias_correction', 'whether to apply bias correction'), 'true').
+ add_fields('bool', Doc('always_adapt', 'apply adaptive lr to 0.0 weight decay parameter'), 'false')
+)
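
Note that `step` is carried as a float32 field so the kernels can evaluate the bias-correction factors `1 - beta ** step` directly. A quick illustration of what `bias_correction` does early in training (numbers are hypothetical):

    beta_1, step = 0.9, 1.0
    bc_1 = 1.0 - beta_1 ** step  # 0.1 at step 1; approaches 1.0 as step grows
    m_hat = 0.05 / bc_1          # a raw first moment of 0.05 is rescaled to 0.5
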
diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h
index d59dbb7805128a8c30bf0ccfcf778afd4c0fbd55..709199f006169e6827a68b2bf5c03e13ee866ce5 100644
--- a/dnn/src/common/handle_impl.h
+++ b/dnn/src/common/handle_impl.h
@@ -209,6 +209,7 @@ private:
     cb(RNN) \
     cb(RNNBackward) \
     cb(LSTM) \
+    cb(LAMBUpdate) \
     cb(LSTMBackward) \
     cb(SoftmaxForward) \
     cb(SoftmaxBackward)
diff --git a/dnn/src/common/lamb.cpp b/dnn/src/common/lamb.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f837e80062e02173ecfc778ff25321dcd08b7b48
--- /dev/null
+++ b/dnn/src/common/lamb.cpp
@@ -0,0 +1,25 @@
+#include "megdnn/oprs.h"
+#include "src/common/utils.h"
+
+namespace megdnn {
+
+void LAMBUpdate::deduce_layout(
+        const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+        const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t,
+        TensorLayout& v_t, TensorLayout& new_param) {
+    m_t = TensorLayout(m_t_1);
+    v_t = TensorLayout(v_t_1);
+    new_param = TensorLayout(lamb_param);
+    MEGDNN_MARK_USED_VAR(grad);
+}
+
+void LAMBUpdate::check_exec(
+        const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+        const TensorLayout& lamb_param, const TensorLayout& grad,
+        const TensorLayout& m_t, const TensorLayout& v_t, const TensorLayout& new_param,
+        size_t workspace_in_bytes) {
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+} // namespace megdnn
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h
index 671d45163dbe14665f20923362633bdeafc2b24e..1d875a0c1f9101ecfb82a42ff5386ddaec973404 100644
--- a/dnn/src/common/opr_trait.h
+++ b/dnn/src/common/opr_trait.h
@@ -127,6 +127,7 @@ DEF(LSQBackward, 7, true, false);
 DEF(Fill, 1, true, false);
 DEF(LayerNormForward, 6, true, true);
 DEF(LayerNormBackward, 8, true, true);
+DEF(LAMBUpdate, 7, true, true);
 DEF(DropoutForward, 3, true, true);
 DEF(DropoutBackward, 3, true, true);
 DEF(RNNCellForward, 7, true, true);
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index 7fbd005fb3642a3cede676fd8c2f2763cb7ebe5a..4cc6b0855e151ddc74e17a6cd63be01b73a0c3b6 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -35,6 +35,7 @@
 #include "src/cuda/images2neibs/opr_impl.h"
 #include "src/cuda/indexing_multi_axis_vec/opr_impl.h"
 #include "src/cuda/indexing_one_hot/opr_impl.h"
+#include "src/cuda/lamb/opr_impl.h"
 #include "src/cuda/layer_norm/opr_impl.h"
 #include "src/cuda/linspace/opr_impl.h"
 #include "src/cuda/local/opr_impl.h"
@@ -210,6 +211,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(LAMBUpdate);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
diff --git a/dnn/src/cuda/lamb/lamb_cuda.cu b/dnn/src/cuda/lamb/lamb_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0b0407e65d9f312919ad34576e767ca2ba7a0533
--- /dev/null
+++ b/dnn/src/cuda/lamb/lamb_cuda.cu
@@ -0,0 +1,102 @@
+#include <thrust/device_ptr.h>
+#include <thrust/functional.h>
+#include <thrust/reduce.h>
+#include <thrust/transform_reduce.h>
+#include <cmath>
+#include "megdnn/arch.h"
+#include "megdnn/dtype.h"
+#include "src/cuda/cuda_shfl_compat.cuh"
+#include "src/cuda/lamb/lamb_cuda.cuh"
+#include "src/cuda/utils.cuh"
+
+namespace megdnn {
+namespace cuda {
+namespace lamb {
+
+template <typename T>
+struct square {
+    __host__ __device__ T operator()(const T& x) const { return x * x; }
+};
+
+template <typename T, typename T_ACC>
+__global__ void update_kernel_1(
+        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
+        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
+        float weight_decay, float eps, bool bias_correction, bool always_adapt,
+        size_t total_nr_elem) {
+    size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+    T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1,
+          bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1;
+    if (idx < total_nr_elem) {
+        m_t[idx] = beta_1 * m_t_1[idx] + (1 - beta_1) * static_cast<T_ACC>(grad[idx]);
+        v_t[idx] = beta_2 * v_t_1[idx] +
+                   (1 - beta_2) * std::pow(static_cast<T_ACC>(grad[idx]), 2);
+        rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps);
+        if (weight_decay != 0) {
+            rt[idx] += lamb_param[idx] * weight_decay;
+        }
+    }
+}
+
+template <typename T, typename T_ACC>
+__global__ void update_kernel_2(
+        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
+        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
+        float weight_decay, float eps, bool bias_correction, bool always_adapt,
+        size_t total_nr_elem, T_ACC trust_ratio) {
+    size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+    T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1,
+          bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1;
+    if (idx < total_nr_elem) {
+        rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps);
+        if (weight_decay != 0) {
+            rt[idx] += lamb_param[idx] * weight_decay;
+        }
+        new_param[idx] = lamb_param[idx] - lr * trust_ratio * rt[idx];
+    }
+}
+
+template <typename T, typename T_ACC>
+void update(
+        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
+        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
+        float weight_decay, float eps, bool bias_correction, bool always_adapt,
+        size_t total_nr_elem, cudaStream_t stream) {
+    size_t NR_BLOCKS = DIVUP(total_nr_elem, NR_THREADS);
+    update_kernel_1<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
+            m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2,
+            step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem);
+    after_kernel_launch();
+    thrust::device_ptr<T_ACC> lamb_param_ptr(lamb_param);
+    thrust::device_ptr<T_ACC> rt_ptr(rt);
+    square<T_ACC> unary_op;
+    thrust::plus<T_ACC> binary_op;
+    T_ACC p_norm = std::sqrt(thrust::transform_reduce(
+            lamb_param_ptr, lamb_param_ptr + total_nr_elem, unary_op, 0.f, binary_op));
+    T_ACC d_norm = std::sqrt(thrust::transform_reduce(
+            rt_ptr, rt_ptr + total_nr_elem, unary_op, 0.f, binary_op));
+    T_ACC trust_ratio = 1;
+    if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) {
+        trust_ratio = p_norm / d_norm;
+    }
+
+    update_kernel_2<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
+            m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2,
+            step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem,
+            trust_ratio);
+    after_kernel_launch();
+}
+
+#define INST(T, T_ACC) \
+    template void update<T, T_ACC>( \
+            T_ACC*, T_ACC*, T_ACC*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC*, float, float, \
+            float, float, float, float, bool, bool, size_t, cudaStream_t);
+
+INST(dt_float32, dt_float32)
+INST(dt_float16, dt_float32)
+INST(dt_bfloat16, dt_float32)
+#undef INST
+
+} // namespace lamb
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/cuda/lamb/lamb_cuda.cuh b/dnn/src/cuda/lamb/lamb_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..223605f8ad088da27f25418c9980f682e278c2fe
--- /dev/null
+++ b/dnn/src/cuda/lamb/lamb_cuda.cuh
@@ -0,0 +1,17 @@
+#pragma once
+#include <cuda_runtime_api.h>
+
+namespace megdnn {
+namespace cuda {
+namespace lamb {
+
+template <typename T, typename T_ACC>
+void update(
+        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
+        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
+        float weight_decay, float eps, bool bias_correction, bool always_adapt,
+        size_t total_nr_elem, cudaStream_t stream);
+
+} // namespace lamb
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/cuda/lamb/opr_impl.cpp b/dnn/src/cuda/lamb/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ed5bdea90fc3fde126c9f96434e0f3d5570a2d02
--- /dev/null
+++ b/dnn/src/cuda/lamb/opr_impl.cpp
@@ -0,0 +1,45 @@
+#include "src/cuda/lamb/opr_impl.h"
+#include "./lamb_cuda.cuh"
+#include "src/cuda/utils.h"
+
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+namespace megdnn {
+namespace cuda {
+void LAMBUpdateImpl::exec(
+        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
+        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
+        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
+    auto p = param();
+    float beta_1 = p.beta_1;
+    float beta_2 = p.beta_2;
+    float step = p.step;
+    float lr = p.lr;
+    float weight_decay = p.weight_decay;
+    float eps = p.eps;
+    bool bias_correction = p.bias_correction;
+    bool always_adapt = p.always_adapt;
+    size_t total_elem = lamb_param.layout.total_nr_elems();
+    auto stream = cuda_stream(handle());
+    using namespace ::megdnn::cuda::lamb;
+
+#define cb(DType) \
+    if (grad.layout.dtype == DType()) { \
+        using T = typename DTypeTrait<DType>::ctype; \
+        using T_ACC = float; \
+        update<T, T_ACC>( \
+                m_t_1.ptr<T_ACC>(), v_t_1.ptr<T_ACC>(), lamb_param.ptr<T_ACC>(), \
+                grad.ptr<T>(), m_t.ptr<T_ACC>(), v_t.ptr<T_ACC>(), \
+                new_param.ptr<T_ACC>(), workspace.ptr<T_ACC>(), beta_1, beta_2, step, \
+                lr, weight_decay, eps, bias_correction, always_adapt, total_elem, \
+                stream); \
+        return; \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw("bad dtype");
+}
+
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/cuda/lamb/opr_impl.h b/dnn/src/cuda/lamb/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..608d9fe6aeb02121f78b474bf7bda979c194bfe6
--- /dev/null
+++ b/dnn/src/cuda/lamb/opr_impl.h
@@ -0,0 +1,25 @@
+#pragma once
+#include "megdnn/oprs.h"
+
+#include "src/cuda/cudnn_wrapper.h"
+namespace megdnn {
+namespace cuda {
+class LAMBUpdateImpl final : public LAMBUpdate {
+public:
+    using LAMBUpdate::LAMBUpdate;
+    void exec(
+            _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
+            _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
+            _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
+            _megdnn_tensor_out new_param, _megdnn_workspace workspace) override;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+            const TensorLayout& lamb_param, const TensorLayout& grad,
+            const TensorLayout& m_t, const TensorLayout& v_t,
+            const TensorLayout& new_param) override {
+        return m_t.access_bytes();
+    };
+};
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp
index 98309f01ccb82cfe4a62b6248ce4a194207358f3..5bedcd3a9cbbe1445aa401968448ceb038f1a5f4 100644
--- a/dnn/src/naive/handle.cpp
+++ b/dnn/src/naive/handle.cpp
@@ -37,6 +37,7 @@
 #include "src/naive/images2neibs/opr_impl.h"
 #include "src/naive/indexing_multi_axis_vec/opr_impl.h"
 #include "src/naive/indexing_one_hot/opr_impl.h"
+#include "src/naive/lamb/opr_impl.h"
 #include "src/naive/layer_norm/opr_impl.h"
 #include "src/naive/linspace/opr_impl.h"
 #include "src/naive/local/opr_impl.h"
diff --git a/dnn/src/naive/lamb/opr_impl.cpp b/dnn/src/naive/lamb/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b07b83fde41973e1cb02a59b56d35ddd2c03c694
--- /dev/null
+++ b/dnn/src/naive/lamb/opr_impl.cpp
@@ -0,0 +1,89 @@
+#include "src/naive/lamb/opr_impl.h"
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+
+using namespace megdnn;
+using namespace naive;
+
+namespace {
+using Param = megdnn::LAMBUpdate::Param;
+
+template <typename T, typename T_ACC = float>
+void update(
+        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
+        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
+        _megdnn_tensor_out new_param, const Param& param) {
+    float beta_1 = param.beta_1;
+    float beta_2 = param.beta_2;
+    float step = param.step;
+    float lr = param.lr;
+    float weight_decay = param.weight_decay;
+    float eps = param.eps;
+    bool bias_correction = param.bias_correction;
+    bool always_adapt = param.always_adapt;
+
+    size_t total_elem = lamb_param.layout.total_nr_elems();
+    T_ACC mt, vt, bc_1, bc_2, rt, d_norm = 0;
+    bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1;
+    bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1;
+
+    for (size_t i = 0; i < total_elem; i++) {
+        mt = m_t.ptr<T_ACC>()[i] = beta_1 * m_t_1.ptr<T_ACC>()[i] +
+                                   (1 - beta_1) * static_cast<T_ACC>(grad.ptr<T>()[i]);
+        vt = v_t.ptr<T_ACC>()[i] =
+                beta_2 * v_t_1.ptr<T_ACC>()[i] +
+                (1 - beta_2) * std::pow(static_cast<T_ACC>(grad.ptr<T>()[i]), 2);
+        rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps);
+        if (weight_decay != 0) {
+            rt += lamb_param.ptr<T_ACC>()[i] * weight_decay;
+        }
+        d_norm += rt * rt;
+    }
+    d_norm = sqrt(d_norm);
+    auto get_norm = [=](_megdnn_tensor_in norm) -> T_ACC {
+        return sqrt(std::accumulate(
+                norm.ptr<T_ACC>(), norm.ptr<T_ACC>() + total_elem, T_ACC(0),
+                [](T_ACC t1, T_ACC t2) -> T_ACC { return t1 + t2 * t2; }));
+    };
+    T_ACC p_norm = get_norm(lamb_param), trust_ratio = 1;
+    if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) {
+        trust_ratio = p_norm / d_norm;
+    }
+    for (size_t i = 0; i < total_elem; i++) {
+        mt = m_t.ptr<T_ACC>()[i];
+        vt = v_t.ptr<T_ACC>()[i];
+        rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps);
+        if (weight_decay != 0) {
+            rt += lamb_param.ptr<T_ACC>()[i] * weight_decay;
+        }
+        new_param.ptr<T_ACC>()[i] = lamb_param.ptr<T_ACC>()[i] - lr * trust_ratio * rt;
+    }
+}
+
+} // namespace
+
+namespace megdnn {
+namespace naive {
+void LAMBUpdateImpl::exec(
+        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
+        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
+        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
+    check_exec(
+            m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout,
+            v_t.layout, new_param.layout, workspace.size);
+#define cb(DType) \
+    if (grad.layout.dtype == DType()) { \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(update<typename DTypeTrait<DType>::ctype>( \
+                m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, param())); \
+        return; \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw("bad dtype");
+}
+} // namespace naive
+
+} // namespace megdnn
diff --git a/dnn/src/naive/lamb/opr_impl.h b/dnn/src/naive/lamb/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..3f1b137fd8f09bc5201ff69925e6953aadb2634c
--- /dev/null
+++ b/dnn/src/naive/lamb/opr_impl.h
@@ -0,0 +1,34 @@
+#pragma once
+#include "megdnn/oprs.h"
+#include "src/common/utils.h"
+
+namespace megdnn {
+namespace naive {
+
+class LAMBUpdateImpl final : public LAMBUpdate {
+public:
+    using LAMBUpdate::LAMBUpdate;
+    void exec(
+            _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
+            _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
+            _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
+            _megdnn_tensor_out new_param, _megdnn_workspace workspace) override;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
+            const TensorLayout& lamb_param, const TensorLayout& grad,
+            const TensorLayout& m_t, const TensorLayout& v_t,
+            const TensorLayout& new_param) override {
+        MEGDNN_MARK_USED_VAR(m_t_1);
+        MEGDNN_MARK_USED_VAR(v_t_1);
+        MEGDNN_MARK_USED_VAR(lamb_param);
+        MEGDNN_MARK_USED_VAR(grad);
+        MEGDNN_MARK_USED_VAR(m_t);
+        MEGDNN_MARK_USED_VAR(v_t);
+        MEGDNN_MARK_USED_VAR(new_param);
+        return 0;
+    };
+};
+} // namespace naive
+
+} // namespace megdnn
diff --git a/dnn/test/common/lamb.h b/dnn/test/common/lamb.h
new file mode 100644
index 0000000000000000000000000000000000000000..5b53630d545a2625e636dde444b5da4950685b46
--- /dev/null
+++ b/dnn/test/common/lamb.h
@@ -0,0 +1,36 @@
+#pragma once
+#include "megdnn/basic_types.h"
+#include "megdnn/opr_param_defs.h"
+
+namespace megdnn {
+namespace test {
+namespace lamb {
+
+struct TestArg {
+    param::LAMBUpdate param;
+    TensorShape src;
+    TestArg(param::LAMBUpdate param, TensorShape src) : param(param), src(src) {}
+};
+
+inline std::vector<TestArg> get_args() {
+    std::vector<TestArg> args;
+    param::LAMBUpdate cur_param;
+    cur_param.beta_1 = 0.9;
+    cur_param.beta_2 = 0.999;
+    cur_param.eps = 1e-8;
+    cur_param.weight_decay = 0;
+    cur_param.lr = 6.25e-5;
+    cur_param.bias_correction = true;
+    cur_param.always_adapt = false;
+    args.emplace_back(
+            cur_param, TensorShape{
+                               1280,
+                       });
+    args.emplace_back(cur_param, TensorShape{1280, 1280});
+    args.emplace_back(cur_param, TensorShape{1280, 3, 224, 224});
+    return args;
+}
+
+} // namespace lamb
+} // namespace test
+} // namespace megdnn
diff --git a/dnn/test/cuda/lamb.cpp b/dnn/test/cuda/lamb.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b1c106b80b78f93b6d3d7b9cb95f248d227067a1
--- /dev/null
+++ b/dnn/test/cuda/lamb.cpp
@@ -0,0 +1,44 @@
+#include "test/cuda/fixture.h"
+
+#include "test/common/checker.h"
+#include "test/common/rng.h"
+
+namespace megdnn {
+namespace test {
+
+TEST_F(CUDA, LAMBUpdate) {
+    LAMBUpdate::Param param;
+    param.beta_1 = 0.9;
+    param.beta_2 = 0.999;
+    param.eps = 1e-5;
+    param.weight_decay = 0.4;
+    param.lr = 1e-3;
+    param.step = 1;
+    param.bias_correction = true;
+    param.always_adapt = false;
+
+    Checker<LAMBUpdate> checker(handle_cuda());
+    checker.set_epsilon(1e-3);
+    UniformFloatRNG rng0(0, 1);
+
+    auto run = [&](DType d) {
+        checker.set_param(param)
+                .set_rng(0, &rng0)
+                .set_rng(1, &rng0)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(3, d)
+                .set_dtype(4, dtype::Float32())
+                .set_dtype(5, dtype::Float32())
+                .set_dtype(6, dtype::Float32())
+                .execs({{2}, {2}, {2}, {2}, {}, {}, {}});
+    };
+
+    run(dtype::Float32());
+    run(dtype::Float16());
+    run(dtype::BFloat16());
+}
+
+} // namespace test
+} // namespace megdnn
diff --git a/dnn/test/naive/lamb.cpp b/dnn/test/naive/lamb.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..23ce8497be0c0b44cd31bb113f8848a151127f7f
--- /dev/null
+++ b/dnn/test/naive/lamb.cpp
@@ -0,0 +1,33 @@
+#include "test/common/lamb.h"
+#include "megdnn/dtype.h"
+#include "megdnn/oprs.h"
+#include "test/common/checker.h"
+#include "test/naive/fixture.h"
+
+using namespace megdnn;
+using namespace test;
+
+TEST_F(NAIVE, LAMBUpdate) {
+    Checker<LAMBUpdate> checker(handle(), false);
+    LAMBUpdate::Param param;
+    param.beta_1 = 0;
+    param.beta_2 = 0;
+    param.eps = 0;
+    param.weight_decay = 0;
+    param.lr = 1;
+    param.step = 1;
+    param.bias_correction = true;
+    param.always_adapt = false;
+
+    TensorND m_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
+    TensorND v_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
+    TensorND param_lamb = TensorValue({2}, dtype::Float32(), {1, 1});
+    TensorND grad = TensorValue({2}, dtype::Float16(), {1, 1});
+
+    TensorND m_t = TensorValue({2}, dtype::Float32(), {1, 1});
+    TensorND v_t = TensorValue({2}, dtype::Float32(), {1, 1});
+    TensorND new_param = TensorValue({2}, dtype::Float32(), {0, 0});
+    checker.set_param(param).exect(
+            Testcase{m_t_1, v_t_1, param_lamb, grad, {}, {}, {}},
+            Testcase{{}, {}, {}, {}, m_t, v_t, new_param});
+}
diff --git a/imperative/python/megengine/optimizer/__init__.py b/imperative/python/megengine/optimizer/__init__.py
index 12539e6d35b1213d460fd455b4a25c1442cab9b2..1c69867152cdb01acd29ab983010eace25e34f8b 100644
--- a/imperative/python/megengine/optimizer/__init__.py
+++ b/imperative/python/megengine/optimizer/__init__.py
@@ -4,6 +4,7 @@ from .adagrad import Adagrad
 from .adam import Adam
 from .adamw import AdamW
 from .clip_grad import *
+from .lamb import LAMB, LAMBFp16
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
 from .optimizer import Optimizer
diff --git a/imperative/python/megengine/optimizer/lamb.py b/imperative/python/megengine/optimizer/lamb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9054a9e0414443ebcd2998119b5fd8dbc26fa675
--- /dev/null
+++ b/imperative/python/megengine/optimizer/lamb.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2020 Ross Wightman
+# This file has been modified by Megvii ("Megvii Modifications").
+# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+"""LAMB optimizer
+
+References: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
+"""
+import os
+from typing import Iterable, Tuple, Union
+
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core.ops.builtin import LAMBUpdate
+
+from .. import Parameter, tensor
+from ..functional import sum
+from ..functional.inplace import _inplace_add_
+from .optimizer import Optimizer
+
+
+class LAMB(Optimizer):
+    r"""Implements LAMB algorithm.
+
+    LAMB is proposed in `"Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"
+    <https://arxiv.org/abs/1904.00962>`_.
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lr: learning rate.
+        betas: coefficients used for computing running averages of gradient and its square.
+            Default: ``(0.9, 0.999)``
+        eps: term added to the denominator to improve numerical stability. Default: ``1e-8``
+        bias_correction: enables bias correction by ``1 - beta ** step``. Default: ``True``
+        weight_decay: weight decay (L2 penalty). Default: ``0.0``
+        always_adapt: apply adaptive lr to ``0.0`` weight decay parameter. Default: ``False``
+    """
+
+    def __init__(
+        self,
+        params: Union[Iterable[Parameter], dict],
+        lr: float,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        bias_correction: bool = True,
+        weight_decay: float = 0.0,
+        always_adapt: bool = False,
+    ):
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
+        super().__init__(params, defaults)
+        self.bias_correction = bias_correction
+        self.always_adapt = always_adapt
+        self._disable_type_convert = True
+
+    def _create_state(self, param_group):
+        for param in param_group["params"]:
+            self._add_state(param, "exp_avg")
+            self._add_state(param, "exp_avg_sq")
+            self._add_state(param, "step", initializer=0.0, dtype="float32")
+
+    def _updates(self, param_group):
+        lr = param_group["lr"]
+        weight_decay = param_group["weight_decay"]
+        eps = param_group["eps"]
+        beta0, beta1 = param_group["betas"]
+
+        # since `convert_inputs` is disabled for param updates,
+        # scalars should be explicitly transferred to tensors
+        c1 = tensor(1.0)
+
+        for param in param_group["params"]:
+
+            if param.grad is None:
+                continue
+
+            grad = param.grad
+
+            states = self._state[param]
+
+            step, exp_avg, exp_avg_sq = (
+                states["step"],
+                states["exp_avg"],
+                states["exp_avg_sq"],
+            )
+            step += c1
+
+            op = LAMBUpdate(
+                beta0,
+                beta1,
+                int(step),
+                lr,
+                weight_decay,
+                eps,
+                self.bias_correction,
+                self.always_adapt,
+            )
+
+            new_exp_avg, new_exp_avg_sq, new_param = apply(
+                op, exp_avg, exp_avg_sq, param, grad
+            )
+            param._reset(new_param)
+            exp_avg._reset(new_exp_avg)
+            exp_avg_sq._reset(new_exp_avg_sq)
+
+
+class LAMBFp16(LAMB):
+    def _create_state(self, param_group):
+        for param in param_group["params"]:
+            self._add_state(param, "exp_avg", dtype="float32")
+            self._add_state(param, "exp_avg_sq", dtype="float32")
+            self._add_state(param, "step", initializer=0.0, dtype="float32")
+            self._state[param]["param_fp32"] = param.astype("float32")
+
+    def _updates(self, param_group):
+        lr = param_group["lr"]
+        weight_decay = param_group["weight_decay"]
+        eps = param_group["eps"]
+        beta0, beta1 = param_group["betas"]
+        c1 = tensor(1.0)
+        for param in param_group["params"]:
+
+            if param.grad is None:
+                continue
+
+            grad = param.grad
+
+            states = self._state[param]
+
+            step, exp_avg, exp_avg_sq = (
+                states["step"],
+                states["exp_avg"],
+                states["exp_avg_sq"],
+            )
+            step += c1
+            fp32_param = states["param_fp32"]
+            op = LAMBUpdate(
+                beta0,
+                beta1,
+                step,
+                lr,
+                weight_decay,
+                eps,
+                self.bias_correction,
+                self.always_adapt,
+            )
+
+            new_exp_avg, new_exp_avg_sq, new_param = apply(
+                op, exp_avg, exp_avg_sq, fp32_param, grad
+            )
+            fp32_param._reset(new_param)
+            param._reset(new_param.astype("float16"))
+            exp_avg._reset(new_exp_avg)
+            exp_avg_sq._reset(new_exp_avg_sq)
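
`LAMBFp16` differs from `LAMB` only in keeping an fp32 master copy (`param_fp32`) that the update runs on before the result is cast back to fp16. A minimal sketch of the intended use, with illustrative values:

    import numpy as np
    import megengine as mge
    import megengine.optimizer as optim

    w = mge.Parameter(np.zeros((16,), dtype="float16"))  # fp16 model parameter
    opt = optim.LAMBFp16([w], lr=1e-3)  # states and master weights stay float32
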
diff --git a/imperative/python/test/unit/optimizer/test_lamb.py b/imperative/python/test/unit/optimizer/test_lamb.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cfb861d033c4c8b3df67168681e6f5882c06b4b
--- /dev/null
+++ b/imperative/python/test/unit/optimizer/test_lamb.py
@@ -0,0 +1,85 @@
+import numpy as np
+
+import megengine as mge
+import megengine.autodiff as ad
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+from megengine import tensor
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core.ops.builtin import LAMBUpdate
+
+
+def lamb_update(
+    param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt
+):
+    lr = param_group["lr"]
+    weight_decay = param_group["weight_decay"]
+    eps = param_group["eps"]
+    beta0, beta1 = param_group["betas"]
+
+    # since `convert_inputs` is disabled for param updates,
+    # scalars should be explicitly transferred to tensors
+
+    _lr, _neg_lr = map(tensor, (lr, -lr))
+    _weight_decay = tensor(weight_decay)
+    _eps = tensor(eps)
+    _beta0, _beta1 = map(tensor, (beta0, beta1))
+
+    c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0))
+
+    def norm(vec):
+        return F.sum(vec * vec) ** c05
+
+    p_norm = norm(param.flatten())
+
+    # step = step + c1
+    step += c1
+
+    # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
+    exp_avg *= _beta0
+    exp_avg += grad * (c1 - _beta0)
+
+    # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
+    exp_avg_sq *= _beta1
+    exp_avg_sq += (c1 - _beta1) * (grad * grad)
+
+    bias_correction1 = c1 - _beta0 ** step if bias_correction else c1
+    bias_correction2 = c1 - _beta1 ** step if bias_correction else c1
+    delta = (exp_avg / bias_correction1) / (
+        (exp_avg_sq / bias_correction2) ** c05 + _eps
+    )
+    if weight_decay != 0.0:
+        delta += param * _weight_decay
+
+    d_norm = norm(delta.flatten())
+    trust_ratio = (
+        p_norm / d_norm
+        if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0
+        else c1
+    )
+    new_param = param - _lr * trust_ratio * delta
+    return exp_avg, exp_avg_sq, new_param
+
+
+def test_lamb():
+    op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
+    m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
+    v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
+    params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
+    grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16)
+    (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad)
+
+    param_group = {
+        "betas": (0.9, 0.999),
+        "step": 1,
+        "lr": 1e-3,
+        "weight_decay": 0.4,
+        "eps": 1e-8,
+    }
+    gt_m_t, gt_v_t, gt_new_param = lamb_update(
+        param_group, 1, m_t_1, v_t_1, params, grad, True, False
+    )
+    np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2)
+    np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2)
+    np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2)
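
The reference `lamb_update` above makes the operator's two-phase structure explicit: moments and the Adam-style direction first, then a norm-based rescaling. This is also what forces the two-kernel split in `lamb_cuda.cu`, with a `thrust::transform_reduce` between the kernels supplying the global norms. A runnable NumPy sketch of that middle phase (array values are illustrative stand-ins for the device buffers):

    import numpy as np

    rng = np.random.default_rng(0)
    param = rng.random(1024, dtype=np.float32)  # current parameter tensor
    rt = rng.random(1024, dtype=np.float32)     # direction from update_kernel_1
    lr, weight_decay, always_adapt = 1e-3, 0.01, False

    # device-wide reduction between the two elementwise kernels
    p_norm = np.sqrt(np.sum(param * param))
    d_norm = np.sqrt(np.sum(rt * rt))
    trust_ratio = 1.0
    if (always_adapt or weight_decay > 0) and p_norm > 0 and d_norm > 0:
        trust_ratio = p_norm / d_norm
    # update_kernel_2 then applies the scalar ratio elementwise
    new_param = param - lr * trust_ratio * rt
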
diff --git a/imperative/src/impl/ops/lamb.cpp b/imperative/src/impl/ops/lamb.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..35dc22b6bfd232db6e2782b69d3943778d59f94f
--- /dev/null
+++ b/imperative/src/impl/ops/lamb.cpp
@@ -0,0 +1,82 @@
+#include "megbrain/imperative/opr_utility.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/utility.h"
+
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+namespace mgb {
+namespace imperative {
+
+namespace {
+namespace lamb {
+
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    return layout_checker;
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
+    mgb_assert(input_descs.size() == 4, "LAMBUpdate expects 4 inputs");
+
+    auto comp_node = input_descs[0].comp_node;
+    auto comp_node1 = input_descs[1].comp_node;
+    auto comp_node2 = input_descs[2].comp_node;
+    TensorLayout m_t_1 = input_descs[0].layout, v_t_1 = input_descs[1].layout,
+                 lamb_param = input_descs[2].layout, grad = input_descs[3].layout;
+
+    TensorLayout new_param = lamb_param, m_t = m_t_1, v_t = v_t_1;
+    return {{{m_t, comp_node}, {v_t, comp_node1}, {new_param, comp_node2}}, true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = def.cast_final_safe<LAMBUpdate>();
+    auto&& m_t_1 = inputs[0];
+    auto&& v_t_1 = inputs[1];
+    auto&& lamb_param = inputs[2];
+    auto&& grad = inputs[3];
+
+    TensorLayout m_t_1_layout{m_t_1->layout()};
+    TensorLayout v_t_1_layout{v_t_1->layout()};
+    TensorLayout lamb_param_layout{lamb_param->layout()};
+
+    DeviceTensorND m_t = BlobManager::inst()->alloc_workspace_with_defrag(
+            m_t_1->comp_node(), m_t_1_layout);
+
+    DeviceTensorND v_t = BlobManager::inst()->alloc_workspace_with_defrag(
+            v_t_1->comp_node(), v_t_1_layout);
+
+    DeviceTensorND new_param = BlobManager::inst()->alloc_workspace_with_defrag(
+            lamb_param->comp_node(), lamb_param_layout);
+
+    DnnOprCaller<megdnn::LAMBUpdate> caller{lamb_param->comp_node()};
+    TensorLayout m_layout(
+            {caller.op->get_workspace_in_bytes(
+                    m_t_1->layout(), v_t_1->layout(), lamb_param->layout(),
+                    grad->layout(), m_t.layout(), v_t.layout(), new_param.layout())},
+            dtype::Byte());
+
+    auto dnn_workspace = caller.create_workspace(m_layout);
+    caller.op->param() = op.param();
+    caller.op->exec(
+            m_t_1->dev_tensor().as_megdnn(), v_t_1->dev_tensor().as_megdnn(),
+            lamb_param->dev_tensor().as_megdnn(), grad->dev_tensor().as_megdnn(),
+            m_t.as_megdnn(), v_t.as_megdnn(), new_param.as_megdnn(), dnn_workspace);
+    return {Tensor::make(m_t), Tensor::make(v_t), Tensor::make(new_param)};
+}
+
+OP_TRAIT_REG(LAMBUpdate, LAMBUpdate)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+
+} // namespace lamb
+} // namespace
+} // namespace imperative
+} // namespace mgb
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 7e540c93995a4a91ff69d8d94659a1012a7dc589..3100df70ee57cffd49b355cebbf65e9588bf87d7 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -477,6 +477,9 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
 def LRN: MgbHashableOp<"LRN", [LRNParam]>;
 
 def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>;
+
+def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>;
+
 def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>;
 
 def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>;
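
End to end, the new optimizer is used like any other MegEngine optimizer. A hypothetical training-loop fragment (`model`, `batch`, and `loss_fn` are placeholders, not part of this patch):

    import megengine.optimizer as optim
    from megengine.autodiff import GradManager

    opt = optim.LAMB(model.parameters(), lr=6.25e-5, betas=(0.9, 0.999),
                     eps=1e-8, weight_decay=0.01)
    gm = GradManager().attach(model.parameters())

    with gm:
        loss = loss_fn(model(batch))  # placeholder forward pass
        gm.backward(loss)
    opt.step().clear_grad()
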