未验证 提交 aba1f9b0 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #14891 from velconia/accelerate_adam

Remove BinarySearch from Adam Op (CPU part)
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <set>
#include <unordered_map>
......@@ -252,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out;
(*this)(context, input, &out);
(*this)(context, input, &out, sorted_result);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows* output,
const bool sorted_result = false) {
std::vector<const framework::SelectedRows*> inputs;
inputs.push_back(&input);
(*this)(context, inputs, output);
(*this)(context, inputs, output, sorted_result);
}
void operator()(const platform::CPUDeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) {
framework::SelectedRows* output,
const bool sorted_result = false) {
if (inputs.size() == 0) {
VLOG(3) << "no input! return";
return;
......@@ -301,6 +305,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
if (sorted_result) {
std::sort(merge_rows.begin(), merge_rows.end());
}
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
......
......@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) {
const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
......@@ -274,7 +275,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows* output,
const bool sorted_result = false) {
framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
......@@ -312,7 +314,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) {
framework::SelectedRows* output,
const bool sorted_result = false) {
if (inputs.size() == 0) {
VLOG(3) << "no input! return";
return;
......
......@@ -81,13 +81,16 @@ struct MergeAdd {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input);
const framework::SelectedRows& input,
const bool sorted_result = false);
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output);
framework::SelectedRows* output,
const bool sorted_result = false);
void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output);
framework::SelectedRows* output,
const bool sorted_result = false);
};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
......
......@@ -157,8 +157,11 @@ struct AdamFunctor<T, CPUAdam> {
}
};
template <typename T, typename Flavour>
struct SparseAdamFunctor;
template <typename T>
struct SparseAdamFunctor {
struct SparseAdamFunctor<T, GPUAdam> {
T beta1_;
T beta2_;
T epsilon_;
......@@ -236,6 +239,106 @@ struct SparseAdamFunctor {
}
};
template <typename T>
struct SparseAdamFunctor<T, CPUAdam> {
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const int64_t* rows_;
int64_t row_numel_;
int64_t row_count_;
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows,
int64_t row_numel, int64_t row_count, bool lazy_mode)
: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
inline HOSTDEVICE void adam_update(size_t i, T g) const {
// The following code is the same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
}
inline void operator()(size_t numel) const {
// lr could be reuse
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
size_t row_count = numel / row_numel_;
for (size_t i = 0U, j = 0U; i != row_count; ++i) {
if (i == *(rows_ + j)) {
for (size_t k = 0U; k != row_numel_; ++k) {
T g = grad_[j * row_numel_ + k];
adam_update(i * row_numel_ + k, g);
}
++j;
} else {
for (size_t k = 0U; k != row_numel_; ++k) {
T mom1 = moment1_[i * row_numel_ + k];
T mom2 = moment2_[i * row_numel_ + k];
T p = param_[i * row_numel_ + k];
mom1 = beta1_ * mom1;
mom2 = beta2_ * mom2;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i * row_numel_ + k] = mom1;
moment2_out_[i * row_numel_ + k] = mom2;
param_out_[i * row_numel_ + k] = p;
}
}
}
}
};
template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
......@@ -331,7 +434,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
.Var()
->GetMutable<framework::SelectedRows>();
merge_func(ctx.template device_context<DeviceContext>(), grad,
grad_merge_var);
grad_merge_var, true);
grad_merge_ptr = grad_merge_var;
}
......@@ -347,13 +450,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else {
#endif
rows = grad_merge.rows().data();
#if defined(PADDLE_WITH_CUDA)
}
#endif
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor(
if (platform::is_cpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
......@@ -362,8 +465,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
VLOG(3) << "lazy_mode :" << lazy_mode;
if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) {
if (lazy_mode) {
size_t row_count = grad_merge.rows().size();
std::vector<int64_t> cpu_rows(grad_merge.rows());
for (size_t row_index = 0; row_index < row_count; ++row_index) {
......@@ -373,6 +476,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
}
}
} else {
functor(param.numel());
}
} else if (platform::is_gpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册