提交 7cd97069 编写于 作者: Li Xinqi 提交者: GitHub

Dev profiling adam (#1592)

* profiling

* all_reduce_* option for performance optimization

* faster adam kernel


Former-commit-id: 5885d1ff7eb09cbd97ca13c22dabe3835af528a6
上级 0e43539e
......@@ -6,26 +6,61 @@ namespace oneflow {
namespace {
template<typename T>
__global__ void UpdateMomentEstimateGpu(int64_t n, bool do_bias_correction, T beta, int32_t p,
const T* model_diff, const T* beta_t, T* moment) {
CUDA_1D_KERNEL_LOOP(i, n) {
// Update biased moment estimate
moment[i] = beta * moment[i] + (1 - beta) * std::pow(model_diff[i], p);
if (do_bias_correction) {
// Correct deviation of moment estimate
moment[i] = moment[i] / (1 - *beta_t);
}
// Compile-time integer-power helper used to share one moment-update routine
// between the first moment (power == 1) and second moment (power == 2).
// Only the specializations below are defined; other powers fail to compile.
template<int32_t power>
struct PowUtil;

template<>
struct PowUtil<1> final {
  // Identity: v^1 == v.
  template<typename T>
  __device__ static T pow(const T v) {
    return v;
  }
};

template<>
struct PowUtil<2> final {
  // Square: v^2 == v * v (avoids a generic pow() call in device code).
  template<typename T>
  __device__ static T pow(const T v) {
    const T squared = v * v;
    return squared;
  }
};
// Bias-correction step for a moment estimate, selected at compile time via
// SFINAE so the disabled case costs nothing. beta_t is presumably the
// precomputed beta^t for the current step — confirm against the caller.
template<bool do_bias_correction, typename T>
__device__ typename std::enable_if<do_bias_correction>::type ScaleMomentum(const T beta_t,
                                                                           T* moment) {
  // Divide out the initialization bias: m_hat = m / (1 - beta^t).
  const T denom = 1 - beta_t;
  *moment = *moment / denom;
}

// No-op overload chosen when bias correction is disabled.
template<bool do_bias_correction, typename T>
__device__ typename std::enable_if<!do_bias_correction>::type ScaleMomentum(const T beta_t,
                                                                            T* moment) {}
// Updates one element of an Adam moment buffer in place:
// moment = beta * moment + (1 - beta) * model_diff^power,
// then (optionally, at compile time) applies bias correction by beta_t.
// power is 1 for the first moment (m) and 2 for the second moment (v).
template<int32_t power, bool do_bias_correction, typename T>
__device__ void UpdateMomentEstimate(T beta, const T* model_diff, const T* beta_t, T* moment) {
  // Exponential moving average of model_diff^power (the biased estimate).
  const T blended = beta * (*moment) + (1 - beta) * PowUtil<power>::pow(*model_diff);
  *moment = blended;
  // Compile-time-dispatched bias correction (no-op when disabled).
  ScaleMomentum<do_bias_correction>(*beta_t, moment);
}
// Applies the Adam parameter update for a single element, given its already
// updated moments m and v. Overwrites *model_diff with the Adam step
// m / (sqrt(v) + epsilon), regularizes it (l1/l2 via the project helper
// RegularizeDiff), and descends the model by learning_rate.
template<typename T>
__device__ void UpdateModel(const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, T epsilon,
                            T* model_diff, T* model, T* m, T* v) {
  // Adam step direction; epsilon guards against division by zero.
  const T adam_diff = *m / (sqrt(*v) + epsilon);
  *model_diff = adam_diff;
  const T reg_diff = RegularizeDiff(adam_diff, *batch_instance_num_ptr, l1, l2, *model);
  *model -= learning_rate * reg_diff;
}
// Fused Adam update kernel: refreshes both moment estimates (m, v) and applies
// the model update in one pass over the n parameters, instead of three
// separate kernel launches. do_bias_correction is a compile-time flag so the
// disabled branch costs nothing per element.
//
// NOTE(review): the scraped diff interleaved the OLD signature/body with the
// NEW one, producing invalid code (two parameter-list tails, two bodies);
// this is the reconstructed new-version kernel from the added lines.
//
// Launch layout: 1-D grid; CUDA_1D_KERNEL_LOOP is a grid-stride loop, so any
// grid size covers all n elements.
template<bool do_bias_correction, typename T>
__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T learning_rate, T l1,
                               T l2, T beta1, T beta2, T epsilon, const T* beta1_t,
                               const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    // First moment (power 1) and second moment (power 2), optionally
    // bias-corrected by beta1_t / beta2_t.
    UpdateMomentEstimate<1, do_bias_correction>(beta1, model_diff + i, beta1_t, m + i);
    UpdateMomentEstimate<2, do_bias_correction>(beta2, model_diff + i, beta2_t, v + i);
    // Consume the freshly updated moments to step the parameter.
    UpdateModel(batch_instance_num_ptr, learning_rate, l1, l2, epsilon, model_diff + i, model + i,
                m + i, v + i);
  }
}
......@@ -38,16 +73,17 @@ class AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final {
T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon,
bool do_bias_correction, int64_t next_model_vid, const T* beta1_t,
const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
// first-order moment
UpdateMomentEstimateGpu<T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, do_bias_correction, beta1, 1, model_diff, beta1_t, m);
// second-order moment
UpdateMomentEstimateGpu<T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, do_bias_correction, beta2, 2, model_diff, beta2_t, v);
UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, batch_instance_num_ptr, learning_rate, l1, l2, epsilon, model_diff, model, m, v);
if (do_bias_correction) {
UpdateModelGpu<true, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, batch_instance_num_ptr, learning_rate, l1, l2, beta1, beta2, epsilon,
beta1_t, beta2_t, model_diff, model, m, v);
} else {
UpdateModelGpu<false, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, batch_instance_num_ptr, learning_rate, l1, l2, beta1, beta2, epsilon,
beta1_t, beta2_t, model_diff, model, m, v);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册