提交 1fdf8853 编写于 作者: Y Yang Yu

Optimize adam_op

上级 39ef5736
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <math.h> // for sqrt in CPU and CUDA #include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h" #include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/transform.h" #include "paddle/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,10 +36,12 @@ struct AdamFunctor { ...@@ -36,10 +36,12 @@ struct AdamFunctor {
const T* lr_; const T* lr_;
const T* grad_; const T* grad_;
const T* param_; const T* param_;
T* param_out_;
AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2, const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* lr, const T* grad, const T* param) T* mom2_out, const T* lr, const T* grad, const T* param,
T* param_out)
: beta1_(beta1), : beta1_(beta1),
beta2_(beta2), beta2_(beta2),
epsilon_(epsilon), epsilon_(epsilon),
...@@ -51,11 +53,10 @@ struct AdamFunctor { ...@@ -51,11 +53,10 @@ struct AdamFunctor {
moment2_out_(mom2_out), moment2_out_(mom2_out),
lr_(lr), lr_(lr),
grad_(grad), grad_(grad),
param_(param) {} param_(param),
param_out_(param_out) {}
// From param[i] --> param_out[i]; inline HOSTDEVICE void operator()(size_t i) const {
inline HOSTDEVICE T operator()(const T& p) const {
size_t i = &p - param_;
// Merge all memory access together. // Merge all memory access together.
T g = grad_[i]; T g = grad_[i];
T mom1 = moment1_[i]; T mom1 = moment1_[i];
...@@ -63,17 +64,18 @@ struct AdamFunctor { ...@@ -63,17 +64,18 @@ struct AdamFunctor {
T lr = *lr_; T lr = *lr_;
T beta1_pow = *beta1_pow_; T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_; T beta2_pow = *beta2_pow_;
T p = param_[i];
// Calculation // Calculation
lr = lr * sqrt(1 - beta2_pow) / (1 - beta1_pow); lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
T new_p = p - lr * (mom1 / (sqrt(mom2) + epsilon_)); p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory // Write back to global memory
moment1_out_[i] = mom1; moment1_out_[i] = mom1;
moment2_out_[i] = mom2; moment2_out_[i] = mom2;
return new_p; param_out_[i] = p;
} }
}; };
...@@ -113,13 +115,11 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -113,13 +115,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
mom2.template data<T>(), mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(), lr.template data<T>(), grad.template data<T>(),
param.template data<T>()); param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
const T* in_ptr = param.template data<T>(); platform::ForRange<DeviceContext> for_range(
T* out_ptr = param_out.template mutable_data<T>(ctx.GetPlace()); static_cast<const DeviceContext&>(ctx.device_context()), param.numel());
platform::Transform<DeviceContext> trans; for_range(functor);
trans(static_cast<const DeviceContext&>(ctx.device_context()), in_ptr,
in_ptr + param_out.numel(), out_ptr, functor);
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
template <typename DeviceContext>
struct ForRange {
ForRange(const DeviceContext& dev_ctx, size_t limit);
template <typename Function>
void operator()(Function func) const;
};
template <>
struct ForRange<CPUDeviceContext> {
ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}
template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}
size_t limit_;
};
#ifdef __NVCC__
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
func(idx);
}
template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, int limit) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) {
func(idx);
}
}
template <>
struct ForRange<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
template <typename Function>
inline void operator()(Function func) const {
constexpr size_t num_threads = 1024;
int block_size = limit_ <= num_threads ? limit_ : num_threads;
int grid_size = (limit_ + num_threads - 1) / num_threads;
if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
func);
} else {
ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, limit_);
}
}
const CUDADeviceContext& dev_ctx_;
int limit_;
};
#endif
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册