提交 98b5ee78 编写于 作者: M Megvii Engine Team

feat(mge/dnn): add lamb optimizer

GitOrigin-RevId: 5a27157456522e695aa44ec665bd28f06d6fe976
上级 a926878c
......@@ -1442,6 +1442,39 @@ protected:
void backward_check_exec(const TensorLayout& src, const TensorLayout& dst);
};
/*!
 * \brief Abstract operator interface for one LAMB optimizer update step.
 *
 * Inputs:  (m_t_1, v_t_1, lamb_param, grad)
 * Outputs: (m_t, v_t, new_param)
 *
 * Backends implement exec() and get_workspace_in_bytes().
 */
class LAMBUpdate : public OperatorBase {
    DEF_OPR_PARAM(LAMBUpdate);
    // 4 inputs (m_t_1, v_t_1, lamb_param, grad), 3 outputs (m_t, v_t, new_param)
    DEF_OPR_IMPL(LAMBUpdate, OperatorBase, 4, 3);

public:
    //! Run the update; \p workspace must hold at least the number of bytes
    //! reported by get_workspace_in_bytes() for the same layouts.
    virtual void exec(
            _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
            _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
            _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
            _megdnn_tensor_out new_param, _megdnn_workspace workspace) = 0;
    //! Scratch memory, in bytes, that exec() needs for the given layouts.
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
            const TensorLayout& lamb_param, const TensorLayout& grad,
            const TensorLayout& m_t, const TensorLayout& v_t,
            const TensorLayout& new_param) = 0;
    //! Fill in the three output layouts from the input layouts.
    void deduce_layout(
            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
            const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t,
            TensorLayout& v_t, TensorLayout& new_param);

protected:
    //! Validation helper for exec() implementations.
    void check_exec(
            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
            const TensorLayout& lamb_param, const TensorLayout& grad,
            const TensorLayout& m_t, const TensorLayout& v_t,
            const TensorLayout& new_param, size_t workspace_in_bytes);
};
using LAMB = LAMBUpdate;
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -36,13 +36,13 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
'NCHW44 = 7','NCHW44_DOT = 8',
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'),
Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'),
Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'),
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
......@@ -96,9 +96,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
'NCHW44 = 7','NCHW44_DOT = 8',
Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'),
......@@ -107,9 +107,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'),
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'),
Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilizing TensorCore '
'instructions for 4-bit integers on Nvidia platforms'),
'instructions for 4-bit integers on Nvidia platforms'),
Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')).
add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode')
)
......@@ -1038,10 +1038,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
'NCHW_NCHW4 = 24',
'NCHW4_NCHW = 25',
'NCHW_NCHW4_WEIGHT = 26',
'NCHW_NCHW64 = 27',
'NCHW64_NCHW = 28',
'NCHW_NHWC = 29',
'NHWC_NCHW = 30',
'NCHW_NCHW64 = 27',
'NCHW64_NCHW = 28',
'NCHW_NHWC = 29',
'NHWC_NCHW = 30',
'NHWCD4I_NHWC = 31',
)
)
......@@ -1264,3 +1264,14 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields('float32', Doc('dropout', 'If introduce a Dropout layer on the outputs of each LSTM layer'), '0.f').
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
)
# Parameter definition for the LAMBUpdate operator.
# Every float32 field below defaults to 1.f; bias_correction defaults to
# true and always_adapt to false.
# NOTE(review): callers presumably always override the float fields
# (beta_1, beta_2, step, lr, weight_decay, eps) -- confirm.
(pdef('LAMBUpdate').
add_fields('float32', Doc('beta_1', 'beta_1 paramter of lamb'), '1.f').
add_fields('float32', Doc('beta_2', 'beta_2 paramter of lamb'), '1.f').
add_fields('float32', Doc('step', 'training step'), '1.f').
add_fields('float32', Doc('lr', 'learning rate'), '1.f').
add_fields('float32', Doc('weight_decay', 'weight decay to adjust learning rate'), '1.f').
add_fields('float32', Doc('eps', 'eps to multi'), '1.f').
add_fields('bool', Doc('bias_correction', 'whether correct bias'), 'true').
add_fields('bool', Doc('always_adapt', 'apply adaptive lr to 0.0'), 'false')
)
......@@ -209,6 +209,7 @@ private:
cb(RNN) \
cb(RNNBackward) \
cb(LSTM) \
cb(LAMBUpdate) \
cb(LSTMBackward) \
cb(SoftmaxForward) \
cb(SoftmaxBackward)
......
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
/*!
 * \brief Deduce output layouts: each output takes the layout of the input
 *      it updates (m_t <- m_t_1, v_t <- v_t_1, new_param <- lamb_param).
 *
 * \p grad does not influence any output layout.
 */
void LAMBUpdate::deduce_layout(
        const TensorLayout& m_t_1, const TensorLayout& v_t_1,
        const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t,
        TensorLayout& v_t, TensorLayout& new_param) {
    MEGDNN_MARK_USED_VAR(grad);
    new_param = lamb_param;
    v_t = v_t_1;
    m_t = m_t_1;
}
void LAMBUpdate::check_exec(
const TensorLayout& m_t_1, const TensorLayout& v_t_1,
const TensorLayout& lamb_param, const TensorLayout& grad,
const TensorLayout& m_t, const TensorLayout& v_t, const TensorLayout& new_param,
size_t workspace_in_bytes) {
auto required_workspace_in_bytes =
get_workspace_in_bytes(m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn
......@@ -127,6 +127,7 @@ DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
DEF(LayerNormForward, 6, true, true);
DEF(LayerNormBackward, 8, true, true);
DEF(LAMBUpdate, 7, true, true);
DEF(DropoutForward, 3, true, true);
DEF(DropoutBackward, 3, true, true);
DEF(RNNCellForward, 7, true, true);
......
......@@ -35,6 +35,7 @@
#include "src/cuda/images2neibs/opr_impl.h"
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h"
#include "src/cuda/indexing_one_hot/opr_impl.h"
#include "src/cuda/lamb/opr_impl.h"
#include "src/cuda/layer_norm/opr_impl.h"
#include "src/cuda/linspace/opr_impl.h"
#include "src/cuda/local/opr_impl.h"
......@@ -210,6 +211,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LAMBUpdate);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
......
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/pair.h>
#include <thrust/transform_reduce.h>
#include <thrust/tuple.h>
#include <cfloat>
#include "megdnn/arch.h"
#include "megdnn/dtype.h"
#include "src/cuda/cuda_shfl_compat.cuh"
#include "src/cuda/lamb/lamb_cuda.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace lamb {
//! Unary functor returning its argument multiplied by itself; used as the
//! transform step of the norm reductions.
template <typename T>
struct square {
    __host__ __device__ T operator()(const T& value) const {
        const T squared = value * value;
        return squared;
    }
};
/*!
 * \brief First LAMB pass: one thread per element.
 *
 * Writes the new first/second moments into \p m_t / \p v_t and the
 * (bias-corrected) update direction, plus the weight-decay term when
 * \p weight_decay is non-zero, into \p rt.
 * \p new_param, \p lr and \p always_adapt are accepted but not used here.
 */
template <typename T, typename T_ACC>
__global__ void update_kernal_1(
        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
        float weight_decay, float eps, bool bias_correction, bool always_adapt,
        size_t total_nr_elem) {
    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    // threads past the end of the tensors have nothing to do
    if (idx >= total_nr_elem) {
        return;
    }

    // bias-correction denominators; both stay 1 when correction is disabled
    T_ACC bc_1 = 1, bc_2 = 1;
    if (bias_correction) {
        bc_1 = 1 - pow(beta_1, step);
        bc_2 = 1 - pow(beta_2, step);
    }

    const T_ACC g = static_cast<T_ACC>(grad[idx]);
    const T_ACC m_new = beta_1 * m_t_1[idx] + (1 - beta_1) * g;
    m_t[idx] = m_new;
    const T_ACC v_new = beta_2 * v_t_1[idx] + (1 - beta_2) * std::pow(g, 2);
    v_t[idx] = v_new;

    T_ACC direction = (m_new / bc_1) / (std::sqrt(v_new / bc_2) + eps);
    if (weight_decay != 0) {
        direction += lamb_param[idx] * weight_decay;
    }
    rt[idx] = direction;
}
/*!
 * \brief Second LAMB pass: one thread per element.
 *
 * Recomputes the update direction from \p m_t / \p v_t (storing it in \p rt)
 * and writes new_param = lamb_param - lr * trust_ratio * direction.
 * \p m_t_1, \p v_t_1, \p grad and \p always_adapt are accepted but not used here.
 */
template <typename T, typename T_ACC>
__global__ void update_kernal_2(
        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
        float weight_decay, float eps, bool bias_correction, bool always_adapt,
        size_t total_nr_elem, T_ACC trust_ratio) {
    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    // threads past the end of the tensors have nothing to do
    if (idx >= total_nr_elem) {
        return;
    }

    // bias-correction denominators; both stay 1 when correction is disabled
    T_ACC bc_1 = 1, bc_2 = 1;
    if (bias_correction) {
        bc_1 = 1 - pow(beta_1, step);
        bc_2 = 1 - pow(beta_2, step);
    }

    T_ACC direction = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps);
    if (weight_decay != 0) {
        direction += lamb_param[idx] * weight_decay;
    }
    rt[idx] = direction;
    new_param[idx] = lamb_param[idx] - lr * trust_ratio * direction;
}
/*!
 * \brief Host-side driver of one LAMB step on \p stream.
 *
 * 1. update_kernal_1 writes the new moments (\p m_t, \p v_t) and the update
 *    direction (\p rt, a scratch buffer of \p total_nr_elem elements).
 * 2. The L2 norms of \p lamb_param and \p rt are reduced to the host and
 *    combined into the trust ratio (kept at 1 unless adaptation is enabled
 *    and both norms are positive).
 * 3. update_kernal_2 writes \p new_param using that trust ratio.
 *
 * All pointers are device pointers holding \p total_nr_elem elements.
 */
template <typename T, typename T_ACC>
void update(
        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
        float weight_decay, float eps, bool bias_correction, bool always_adapt,
        size_t total_nr_elem, cudaStream_t stream) {
    // A launch with zero blocks is an invalid configuration; there is also
    // nothing to update.
    if (total_nr_elem == 0) {
        return;
    }
    size_t NR_BLOCKS = DIVUP(total_nr_elem, NR_THREADS);
    update_kernal_1<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
            m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2,
            step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem);
    after_kernel_launch();

    // Run the reductions on the same stream as the kernels. Without an
    // explicit policy thrust uses the default stream, which is not ordered
    // after update_kernal_1 when `stream` is a non-blocking stream, so `rt`
    // could be read before it has been written.
    auto policy = thrust::cuda::par.on(stream);
    thrust::device_ptr<T_ACC> lamb_param_ptr(lamb_param);
    thrust::device_ptr<T_ACC> rt_ptr(rt);
    square<T_ACC> unary_op;
    thrust::plus<T_ACC> binary_op;
    const T_ACC zero = static_cast<T_ACC>(0);
    T_ACC p_norm = std::sqrt(thrust::transform_reduce(
            policy, lamb_param_ptr, lamb_param_ptr + total_nr_elem, unary_op, zero,
            binary_op));
    T_ACC d_norm = std::sqrt(thrust::transform_reduce(
            policy, rt_ptr, rt_ptr + total_nr_elem, unary_op, zero, binary_op));

    T_ACC trust_ratio = 1;
    if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) {
        trust_ratio = p_norm / d_norm;
    }

    update_kernal_2<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
            m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2,
            step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem,
            trust_ratio);
    after_kernel_launch();
}
// Explicit instantiations of update<T, T_ACC> for the supported gradient
// dtypes (float32, float16, bfloat16); the accumulation type is always
// float32. The template is only declared in lamb_cuda.cuh, so these are the
// definitions other translation units link against.
#define INST(T, T_ACC)                                                              \
    template void update<T, T_ACC>(                                                 \
            T_ACC*, T_ACC*, T_ACC*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC*, float, float, \
            float, float, float, float, bool, bool, size_t, cudaStream_t);
INST(dt_float32, dt_float32)
INST(dt_float16, dt_float32)
INST(dt_bfloat16, dt_float32)
#undef INST
} // namespace lamb
} // namespace cuda
} // namespace megdnn
#pragma once
#include <cuda_runtime_api.h>
namespace megdnn {
namespace cuda {
namespace lamb {
/*!
 * \brief Run one LAMB update on \p stream.
 *
 * \tparam T element type of \p grad
 * \tparam T_ACC element type of every other buffer
 * \param m_t_1,v_t_1 previous moments (inputs)
 * \param lamb_param current parameter values (input)
 * \param grad gradient (input)
 * \param m_t,v_t,new_param outputs
 * \param rt scratch buffer
 * \param total_nr_elem number of elements processed in each buffer
 */
template <typename T, typename T_ACC>
void update(
        T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t,
        T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr,
        float weight_decay, float eps, bool bias_correction, bool always_adapt,
        size_t total_nr_elem, cudaStream_t stream);
} // namespace lamb
} // namespace cuda
} // namespace megdnn
#include "src/cuda/lamb/opr_impl.h"
#include "./lamb_cuda.cuh"
#include "src/cuda/utils.h"
#include <cmath>
#include <functional>
#include <numeric>
namespace megdnn {
namespace cuda {
/*!
 * \brief CUDA implementation of the LAMB update.
 *
 * Dispatches on the dtype of \p grad; all other tensors are accessed as
 * float. \p workspace is used as the scratch buffer for the per-element
 * update direction. Throws "bad dtype" if \p grad has no matching
 * floating-point computing dtype.
 */
void LAMBUpdateImpl::exec(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
    // Validate before launching anything, as the naive backend does; an
    // undersized workspace would otherwise be overrun by the kernels.
    check_exec(
            m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout,
            v_t.layout, new_param.layout, workspace.size);
    auto p = param();
    float beta_1 = p.beta_1;
    float beta_2 = p.beta_2;
    float step = p.step;
    float lr = p.lr;
    float weight_decay = p.weight_decay;
    float eps = p.eps;
    bool bias_correction = p.bias_correction;
    bool always_adapt = p.always_adapt;
    size_t total_elem = lamb_param.layout.total_nr_elems();
    auto stream = cuda_stream(handle());
    using namespace ::megdnn::cuda::lamb;
#define cb(DType)                                                                      \
    if (grad.layout.dtype == DType()) {                                                \
        using T = typename DTypeTrait<DType>::ctype;                                   \
        using T_ACC = float;                                                           \
        update<T, T_ACC>(                                                              \
                m_t_1.ptr<T_ACC>(), v_t_1.ptr<T_ACC>(), lamb_param.ptr<T_ACC>(),       \
                grad.ptr<T>(), m_t.ptr<T_ACC>(), v_t.ptr<T_ACC>(),                     \
                new_param.ptr<T_ACC>(), workspace.ptr<T_ACC>(), beta_1, beta_2, step,  \
                lr, weight_decay, eps, bias_correction, always_adapt, total_elem,      \
                stream);                                                               \
        return;                                                                        \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
    megdnn_throw("bad dtype");
}
} // namespace cuda
} // namespace megdnn
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
namespace cuda {
//! CUDA backend of the LAMBUpdate operator.
class LAMBUpdateImpl final : public LAMBUpdate {
public:
    using LAMBUpdate::LAMBUpdate;
    void exec(
            _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
            _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
            _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
            _megdnn_tensor_out new_param, _megdnn_workspace workspace) override;
    //! The workspace size is taken from the m_t layout only; the other
    //! layouts are ignored.
    size_t get_workspace_in_bytes(
            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
            const TensorLayout& lamb_param, const TensorLayout& grad,
            const TensorLayout& m_t, const TensorLayout& v_t,
            const TensorLayout& new_param) override {
        return m_t.access_bytes();
    };
};
} // namespace cuda
} // namespace megdnn
......@@ -37,6 +37,7 @@
#include "src/naive/images2neibs/opr_impl.h"
#include "src/naive/indexing_multi_axis_vec/opr_impl.h"
#include "src/naive/indexing_one_hot/opr_impl.h"
#include "src/naive/lamb/opr_impl.h"
#include "src/naive/layer_norm/opr_impl.h"
#include "src/naive/linspace/opr_impl.h"
#include "src/naive/local/opr_impl.h"
......
#include "src/naive/lamb/opr_impl.h"
#include <cmath>
#include <functional>
#include <numeric>
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
namespace {
using Param = megdnn::LAMBUpdate::Param;
/*!
 * \brief Reference (CPU) implementation of one LAMB update.
 *
 * Pass 1 computes the new moments and accumulates the squared norm of the
 * update direction; the trust ratio is then derived from the parameter norm
 * and the direction norm; pass 2 recomputes the direction per element and
 * writes the updated parameters.
 *
 * \tparam T element type of \p grad
 * \tparam T_ACC element type of all other tensors and of the arithmetic
 */
template <typename T, typename T_ACC = float>
void update(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, const Param& param) {
    float beta_1 = param.beta_1;
    float beta_2 = param.beta_2;
    float step = param.step;
    float lr = param.lr;
    float weight_decay = param.weight_decay;
    float eps = param.eps;
    bool bias_correction = param.bias_correction;
    bool always_adapt = param.always_adapt;
    size_t total_elem = lamb_param.layout.total_nr_elems();
    T_ACC mt, vt, bc_1, bc_2, rt, d_norm = 0;
    // bias-correction denominators; 1 when correction is disabled
    bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1;
    bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1;
    for (size_t i = 0; i < total_elem; i++) {
        mt = m_t.ptr<T_ACC>()[i] = beta_1 * m_t_1.ptr<T_ACC>()[i] +
                                   (1 - beta_1) * static_cast<T_ACC>(grad.ptr<T>()[i]);
        vt = v_t.ptr<T_ACC>()[i] =
                beta_2 * v_t_1.ptr<T_ACC>()[i] +
                (1 - beta_2) * std::pow(static_cast<T_ACC>(grad.ptr<T>()[i]), 2);
        rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps);
        if (weight_decay != 0) {
            rt += lamb_param.ptr<T_ACC>()[i] * weight_decay;
        }
        d_norm += rt * rt;
    }
    d_norm = sqrt(d_norm);
    auto get_norm = [=](_megdnn_tensor_in norm) -> T_ACC {
        // The initial value fixes std::accumulate's accumulator type. An int
        // literal 0 would truncate every partial sum to an integer, so seed
        // with a T_ACC zero instead.
        return sqrt(std::accumulate(
                norm.ptr<T_ACC>(), norm.ptr<T_ACC>() + total_elem,
                static_cast<T_ACC>(0),
                [](T_ACC t1, T_ACC t2) -> T_ACC { return t1 + t2 * t2; }));
    };
    T_ACC p_norm = get_norm(lamb_param), trust_ratio = 1;
    // the trust ratio only applies when adaptation is enabled and both norms
    // are positive
    if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) {
        trust_ratio = p_norm / d_norm;
    }
    for (size_t i = 0; i < total_elem; i++) {
        mt = m_t.ptr<T_ACC>()[i];
        vt = v_t.ptr<T_ACC>()[i];
        rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps);
        if (weight_decay != 0) {
            rt += lamb_param.ptr<T_ACC>()[i] * weight_decay;
        }
        new_param.ptr<T_ACC>()[i] = lamb_param.ptr<T_ACC>()[i] - lr * trust_ratio * rt;
    }
}
} // namespace
namespace megdnn {
namespace naive {
/*!
 * \brief Naive (CPU) LAMB update entry point.
 *
 * Validates the arguments, then dispatches the reference kernel according to
 * the dtype of \p grad. Throws "bad dtype" when no floating-point computing
 * dtype matches.
 */
void LAMBUpdateImpl::exec(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
    check_exec(
            m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout,
            v_t.layout, new_param.layout, workspace.size);
// One dtype-specific branch per floating-point computing dtype.
#define LAMB_NAIVE_DISPATCH(DType)                                              \
    if (grad.layout.dtype == DType()) {                                         \
        MEGDNN_DISPATCH_CPU_KERN_OPR(update<typename DTypeTrait<DType>::ctype>( \
                m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, param())); \
        return;                                                                 \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(LAMB_NAIVE_DISPATCH)
#undef LAMB_NAIVE_DISPATCH
    megdnn_throw("bad dtype");
}
} // namespace naive
} // namespace megdnn
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
namespace naive {
//! Naive (reference CPU) backend of the LAMBUpdate operator.
class LAMBUpdateImpl final : public LAMBUpdate {
public:
    using LAMBUpdate::LAMBUpdate;
    void exec(
            _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
            _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
            _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
            _megdnn_tensor_out new_param, _megdnn_workspace workspace) override;
    //! Always reports 0 bytes of workspace; all layouts are ignored.
    size_t get_workspace_in_bytes(
            const TensorLayout& m_t_1, const TensorLayout& v_t_1,
            const TensorLayout& lamb_param, const TensorLayout& grad,
            const TensorLayout& m_t, const TensorLayout& v_t,
            const TensorLayout& new_param) override {
        MEGDNN_MARK_USED_VAR(m_t_1);
        MEGDNN_MARK_USED_VAR(v_t_1);
        MEGDNN_MARK_USED_VAR(lamb_param);
        MEGDNN_MARK_USED_VAR(grad);
        MEGDNN_MARK_USED_VAR(m_t);
        MEGDNN_MARK_USED_VAR(v_t);
        MEGDNN_MARK_USED_VAR(new_param);
        return 0;
    };
};
} // namespace naive
} // namespace megdnn
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
namespace megdnn {
namespace test {
namespace lamb {
struct TestArg {
param::LAMBUpdate param;
TensorShape src;
TestArg(param::LAMBUpdate param, TensorShape src) : param(param), src(src) {}
};
inline std::vector<TestArg> get_args() {
std::vector<TestArg> args;
param::LAMBUpdate cur_param;
cur_param.beta_1 = 0.9;
cur_param.beta_2 = 0.999;
cur_param.eps = 1e-8;
cur_param.weight_decay = 0;
cur_param.lr = 6.25e-5;
cur_param.bias_correction = true;
cur_param.always_adapt = false;
args.emplace_back(
cur_param, TensorShape{
1280,
});
args.emplace_back(cur_param, TensorShape{1280, 1280});
args.emplace_back(cur_param, TensorShape{1280, 3, 224, 224});
return args;
}
} // namespace lamb
} // namespace test
} // namespace megdnn
#include "test/cuda/fixture.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
namespace megdnn {
namespace test {
// Runs the CUDA LAMBUpdate through the Checker for float32, float16 and
// bfloat16 gradients; every other tensor is float32.
TEST_F(CUDA, LAMBUpdate) {
    LAMBUpdate::Param param;
    param.step = 1;
    param.lr = 1e-3;
    param.beta_1 = 0.9;
    param.beta_2 = 0.999;
    param.eps = 1e-5;
    param.weight_decay = 0.4;
    param.bias_correction = true;
    param.always_adapt = false;

    Checker<LAMBUpdate> checker(handle_cuda());
    checker.set_epsilon(1e-3);
    UniformFloatRNG rng0(0, 1);

    auto run = [&](DType d) {
        const DType f32 = dtype::Float32();
        checker.set_param(param).set_rng(0, &rng0).set_rng(1, &rng0);
        // index 3 is the gradient; all other tensors are float32
        for (size_t idx = 0; idx < 7; ++idx) {
            checker.set_dtype(idx, idx == 3 ? d : f32);
        }
        checker.execs({{2}, {2}, {2}, {2}, {}, {}, {}});
    };

    run(dtype::Float32());
    run(dtype::Float16());
    run(dtype::BFloat16());
}
} // namespace test
} // namespace megdnn
#include "test/common/lamb.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
using namespace megdnn;
using namespace test;
// Exact-value check of the naive LAMBUpdate on a two-element tensor.
// With beta_1 = beta_2 = 0, eps = 0, weight_decay = 0, lr = 1, step = 1 and
// grad = 1 the reference kernel yields m_t = 1, v_t = 1, an update direction
// of 1 and a trust ratio of 1, hence new_param = 1 - 1 = 0.
TEST_F(NAIVE, LAMBUpdate) {
    // NOTE(review): the meaning of the second Checker argument is not visible
    // here -- confirm against the Checker declaration.
    Checker<LAMBUpdate> checker(handle(), false);
    LAMBUpdate::Param param;
    param.beta_1 = 0;
    param.beta_2 = 0;
    param.eps = 0;
    param.weight_decay = 0;
    param.lr = 1;
    param.step = 1;
    param.bias_correction = true;
    param.always_adapt = false;
    // inputs
    TensorND m_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND v_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND param_lamb = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND grad = TensorValue({2}, dtype::Float16(), {1, 1});
    // expected outputs
    TensorND m_t = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND v_t = TensorValue({2}, dtype::Float32(), {1, 1});
    TensorND new_param = TensorValue({2}, dtype::Float32(), {0, 0});
    checker.set_param(param).exect(
            Testcase{m_t_1, v_t_1, param_lamb, grad, {}, {}, {}},
            Testcase{{}, {}, {}, {}, m_t, v_t, new_param});
}
......@@ -4,6 +4,7 @@ from .adagrad import Adagrad
from .adam import Adam
from .adamw import AdamW
from .clip_grad import *
from .lamb import LAMB, LAMBFp16
from .lr_scheduler import LRScheduler
from .multi_step_lr import MultiStepLR
from .optimizer import Optimizer
......
# Copyright (c) 2020 Ross Wightman
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""LAMB optimizer
References: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
"""
import os
from typing import Iterable, Tuple, Union
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import LAMBUpdate
from .. import Parameter, tensor
from ..functional import sum
from ..functional.inplace import _inplace_add_
from .optimizer import Optimizer
class LAMB(Optimizer):
    r"""LAMB optimizer.

    The algorithm is described in `"Large Batch Optimization for Deep Learning:
    Training BERT in 76 minutes" <https://arxiv.org/abs/1904.00962>`_.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr: learning rate.
        betas: coefficients used for computing running averages of gradient and its
            square. Default: ``(0.9, 0.999)``
        eps: term added to the denominator to improve numerical stability.
            Default: ``1e-8``
        bias_correction: enables bias correction by ``1 - beta ** step``.
            Default: ``True``
        weight_decay: weight decay (L2 penalty). Default: ``0.0``
        always_adapt: apply adaptive lr to ``0.0`` weight decay parameter.
            Default: ``False``
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        bias_correction: bool = True,
        weight_decay: float = 0.0,
        always_adapt: bool = False,
    ):
        # Validate hyper-parameters before anything is registered.
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        first_beta, second_beta = betas[0], betas[1]
        if not 0.0 <= first_beta < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(first_beta))
        if not 0.0 <= second_beta < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(second_beta))

        super().__init__(
            params, dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
        )
        self.bias_correction = bias_correction
        self.always_adapt = always_adapt
        self._disable_type_convert = True

    def _create_state(self, param_group):
        """Register the moment estimates and the step counter of each parameter."""
        for param in param_group["params"]:
            self._add_state(param, "exp_avg")
            self._add_state(param, "exp_avg_sq")
            self._add_state(param, "step", initializer=0.0, dtype="float32")

    def _updates(self, param_group):
        """Apply one LAMB step to every parameter of the group that has a gradient."""
        beta0, beta1 = param_group["betas"]
        lr = param_group["lr"]
        eps = param_group["eps"]
        weight_decay = param_group["weight_decay"]

        # Since `convert_inputs` is disabled for param updates, scalars have
        # to be explicitly transformed to tensors.
        one = tensor(1.0)

        for param in param_group["params"]:
            if param.grad is None:
                continue
            grad = param.grad

            states = self._state[param]
            step = states["step"]
            exp_avg = states["exp_avg"]
            exp_avg_sq = states["exp_avg_sq"]
            step += one

            op = LAMBUpdate(
                beta0,
                beta1,
                int(step),
                lr,
                weight_decay,
                eps,
                self.bias_correction,
                self.always_adapt,
            )
            new_exp_avg, new_exp_avg_sq, new_param = apply(
                op, exp_avg, exp_avg_sq, param, grad
            )
            param._reset(new_param)
            exp_avg._reset(new_exp_avg)
            exp_avg_sq._reset(new_exp_avg_sq)
class LAMBFp16(LAMB):
    r"""LAMB variant that updates a float32 copy of each parameter.

    The moment estimates and a ``param_fp32`` copy of every parameter are kept
    in float32. Each step updates the float32 copy and writes the result back
    to the parameter cast to float16.
    """

    def _create_state(self, param_group):
        """Register float32 moments, the step counter and the float32 parameter copy."""
        for param in param_group["params"]:
            self._add_state(param, "exp_avg", dtype="float32")
            self._add_state(param, "exp_avg_sq", dtype="float32")
            self._add_state(param, "step", initializer=0.0, dtype="float32")
            self._state[param]["param_fp32"] = param.astype("float32")

    def _updates(self, param_group):
        """Apply one LAMB step to every parameter of the group that has a gradient."""
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]
        beta0, beta1 = param_group["betas"]
        # scalars have to be explicitly converted to tensors for in-place updates
        c1 = tensor(1.0)
        for param in param_group["params"]:
            if param.grad is None:
                continue
            grad = param.grad
            states = self._state[param]
            step, exp_avg, exp_avg_sq = (
                states["step"],
                states["exp_avg"],
                states["exp_avg_sq"],
            )
            step += c1
            fp32_param = states["param_fp32"]
            op = LAMBUpdate(
                beta0,
                beta1,
                # pass a plain Python number rather than the state tensor,
                # exactly as LAMB._updates does
                int(step),
                lr,
                weight_decay,
                eps,
                self.bias_correction,
                self.always_adapt,
            )
            new_exp_avg, new_exp_avg_sq, new_param = apply(
                op, exp_avg, exp_avg_sq, fp32_param, grad
            )
            # keep the float32 copy authoritative; the visible parameter is
            # its float16 cast
            fp32_param._reset(new_param)
            param._reset(new_param.astype("float16"))
            exp_avg._reset(new_exp_avg)
            exp_avg_sq._reset(new_exp_avg_sq)
import numpy as np
import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import LAMBUpdate
def lamb_update(
    param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt
):
    """Reference LAMB step written with elementwise tensor operations.

    ``step`` is incremented by one before it is used. ``exp_avg`` and
    ``exp_avg_sq`` are updated in place and returned together with the new
    parameter tensor.
    """
    lr = param_group["lr"]
    weight_decay = param_group["weight_decay"]
    eps = param_group["eps"]
    beta0, beta1 = param_group["betas"]

    # scalars are explicitly converted to tensors before being mixed with
    # tensor operands
    _lr, _neg_lr = map(tensor, (lr, -lr))
    _weight_decay = tensor(weight_decay)
    _eps = tensor(eps)
    _beta0, _beta1 = map(tensor, (beta0, beta1))

    c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0))

    def norm(vec):
        # Use the tensor reduction: the builtin `sum` would iterate over the
        # flattened tensor one element at a time.
        return F.sum(vec * vec) ** c05

    p_norm = norm(param.flatten())

    # step = step + c1
    step += c1

    # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
    exp_avg *= _beta0
    exp_avg += grad * (c1 - _beta0)

    # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
    exp_avg_sq *= _beta1
    exp_avg_sq += (c1 - _beta1) * (grad * grad)

    bias_correction1 = c1 - _beta0 ** step if bias_correction else c1
    bias_correction2 = c1 - _beta1 ** step if bias_correction else c1
    delta = (exp_avg / bias_correction1) / (
        (exp_avg_sq / bias_correction2) ** c05 + _eps
    )
    if weight_decay != 0.0:
        delta += param * _weight_decay

    d_norm = norm(delta.flatten())
    # the trust ratio only applies when adaptation is enabled and both norms
    # are positive
    trust_ratio = (
        p_norm / d_norm
        if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0
        else c1
    )
    new_param = param - _lr * trust_ratio * delta
    return exp_avg, exp_avg_sq, new_param
def test_lamb():
    """Compare the LAMBUpdate operator with the tensor-level reference."""
    op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
    m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16)
    # The operator must run first: lamb_update modifies m_t_1 / v_t_1 in place.
    (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad)

    param_group = {
        "betas": (0.9, 0.999),
        "step": 1,
        "lr": 1e-3,
        "weight_decay": 0.4,
        "eps": 1e-8,
    }
    # lamb_update increments the step before using it, so start from 0 to
    # evaluate the reference at step 1, the same step the operator was built
    # with.
    gt_m_t, gt_v_t, gt_new_param = lamb_update(
        param_group, 0, m_t_1, v_t_1, params, grad, True, False
    )
    np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2)
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
namespace lamb {
//! Return one default-constructed (empty) layout-constraint callback per input.
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    const size_t nr_inputs = inputs.size();
    SmallVector<VarNode::LayoutConstraintCallback> no_constraints(nr_inputs);
    return no_constraints;
}
/*!
 * \brief Infer the output descriptors of LAMBUpdate.
 *
 * Each output copies the layout and comp node of the input it updates:
 * m_t <- input 0, v_t <- input 1, new_param <- input 2. The gradient
 * (input 3) does not influence the outputs. The returned flag is always true.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    mgb_assert(input_descs.size() == 4, "LAMBUpdate expects 4 inputs");
    auto comp_node = input_descs[0].comp_node;
    auto comp_node1 = input_descs[1].comp_node;
    auto comp_node2 = input_descs[2].comp_node;
    TensorLayout m_t_1 = input_descs[0].layout, v_t_1 = input_descs[1].layout,
                 lamb_param = input_descs[2].layout;

    TensorLayout new_param = lamb_param, m_t = m_t_1, v_t = v_t_1;
    return {{{m_t, comp_node}, {v_t, comp_node1}, {new_param, comp_node2}}, true};
}
/*!
 * \brief Execute LAMBUpdate on concrete tensors.
 *
 * Allocates the three outputs with the layouts of inputs 0..2, queries and
 * allocates the megdnn workspace, then runs the megdnn operator on the comp
 * node of the parameter tensor. \p output_descs and \p validated are unused.
 *
 * \return {m_t, v_t, new_param}
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& lamb_op = def.cast_final_safe<LAMBUpdate>();
    const TensorPtr& prev_m = inputs[0];
    const TensorPtr& prev_v = inputs[1];
    const TensorPtr& cur_param = inputs[2];
    const TensorPtr& gradient = inputs[3];

    // the allocator is handed a mutable copy of each layout
    TensorLayout prev_m_layout{prev_m->layout()};
    TensorLayout prev_v_layout{prev_v->layout()};
    TensorLayout cur_param_layout{cur_param->layout()};

    DeviceTensorND next_m = BlobManager::inst()->alloc_workspace_with_defrag(
            prev_m->comp_node(), prev_m_layout);
    DeviceTensorND next_v = BlobManager::inst()->alloc_workspace_with_defrag(
            prev_v->comp_node(), prev_v_layout);
    DeviceTensorND next_param = BlobManager::inst()->alloc_workspace_with_defrag(
            cur_param->comp_node(), cur_param_layout);

    DnnOprCaller<megdnn::LAMBUpdate> caller{cur_param->comp_node()};
    // workspace is a flat byte buffer of the size the megdnn operator asks for
    const size_t workspace_bytes = caller.op->get_workspace_in_bytes(
            prev_m->layout(), prev_v->layout(), cur_param->layout(),
            gradient->layout(), next_m.layout(), next_v.layout(),
            next_param.layout());
    TensorLayout workspace_layout({workspace_bytes}, dtype::Byte());
    auto dnn_workspace = caller.create_workspace(workspace_layout);

    caller.op->param() = lamb_op.param();
    caller.op->exec(
            prev_m->dev_tensor().as_megdnn(), prev_v->dev_tensor().as_megdnn(),
            cur_param->dev_tensor().as_megdnn(), gradient->dev_tensor().as_megdnn(),
            next_m.as_megdnn(), next_v.as_megdnn(), next_param.as_megdnn(),
            dnn_workspace);
    return {Tensor::make(next_m), Tensor::make(next_v), Tensor::make(next_param)};
}
// Register the imperative-runtime hooks of LAMBUpdate: output attribute
// inference, physical execution and input layout constraints; everything
// else uses the fallback implementation.
OP_TRAIT_REG(LAMBUpdate, LAMBUpdate)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
} // namespace lamb
} // namespace
} // namespace imperative
} // namespace mgb
......@@ -477,6 +477,9 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
def LRN: MgbHashableOp<"LRN", [LRNParam]>;
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>;
def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>;
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>;
def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册