提交 b6f61faf 编写于 作者: S sneaxiy

fix adam

上级 2d898491
...@@ -174,12 +174,13 @@ struct SparseAdamFunctor { ...@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
const int64_t* rows_; const int64_t* rows_;
int64_t row_numel_; int64_t row_numel_;
int64_t row_count_;
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad, const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows, const T* param, T* param_out, const int64_t* rows,
int64_t row_numel) int64_t row_numel, int64_t row_count)
: beta1_(beta1), : beta1_(beta1),
beta2_(beta2), beta2_(beta2),
epsilon_(epsilon), epsilon_(epsilon),
...@@ -194,28 +195,47 @@ struct SparseAdamFunctor { ...@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
param_(param), param_(param),
param_out_(param_out), param_out_(param_out),
rows_(rows), rows_(rows),
row_numel_(row_numel) {} row_numel_(row_numel),
row_count_(row_count) {}
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
int64_t beg = 0, end = row_count_ - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (rows_[mid] == row)
return mid;
else if (rows_[mid] < row)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
inline HOSTDEVICE void operator()(size_t i) const { inline HOSTDEVICE void operator()(size_t i) const {
int64_t row = i / row_numel_;
auto row_idx = BinarySearchInRows(row);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
// The following code is the same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_; T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_; T beta2_pow = *beta2_pow_;
for (int64_t j = 0; j < row_numel_; ++j) { T p = param_[i];
T g = grad_[i * row_numel_ + j];
T mom1 = moment1_[rows_[i] * row_numel_ + j];
T mom2 = moment2_[rows_[i] * row_numel_ + j];
T lr = *lr_;
T p = param_[rows_[i] * row_numel_ + j];
// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
moment1_out_[rows_[i] * row_numel_ + j] = mom1; // Write back to global memory
moment2_out_[rows_[i] * row_numel_ + j] = mom2; moment1_out_[i] = mom1;
param_out_[rows_[i] * row_numel_ + j] = p; moment2_out_[i] = mom2;
} // for col id param_out_[i] = p;
} }
}; };
...@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
return; return;
} }
// merge duplicated rows if any. // merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func; scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge = auto& grad_merge = *(ctx.scope()
merge_func(ctx.template device_context<DeviceContext>(), grad); .NewScope()
.Var("sparse_adam_grad_merge")
->GetMutable<framework::SelectedRows>());
merge_func(ctx.template device_context<DeviceContext>(), grad,
&grad_merge);
auto& grad_tensor = grad_merge.value(); auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>(); const T* grad_data = grad_tensor.template data<T>();
int64_t* rows = nullptr; int64_t* rows = nullptr;
...@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
mom2.template data<T>(), mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(), lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel); param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size());
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), static_cast<const DeviceContext&>(ctx.device_context()),
grad_merge.rows().size()); param.numel());
for_range(functor); for_range(functor);
} else { } else {
PADDLE_THROW("Variable type not supported by adam_op"); PADDLE_THROW("Variable type not supported by adam_op");
......
...@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context, framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input) {
framework::SelectedRows out; framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
auto input_rows = input.rows(); auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
...@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out_data[out_i * input_width + j] += input_data[i * input_width + j]; out_data[out_i * input_width + j] += input_data[i * input_width + j];
} }
} }
return out;
} }
}; };
......
...@@ -262,6 +262,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -262,6 +262,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context, framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input) {
framework::SelectedRows out; framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
framework::Vector<int64_t> input_rows(input.rows()); framework::Vector<int64_t> input_rows(input.rows());
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
...@@ -292,7 +300,6 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -292,7 +300,6 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
input_data, input_rows.CUDAData(context.GetPlace()), out_data, input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width); out.rows().size(), input_width);
return out;
} }
}; };
......
...@@ -65,6 +65,9 @@ struct MergeAdd { ...@@ -65,6 +65,9 @@ struct MergeAdd {
// the input SelectedRows object. // the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context, framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input); const framework::SelectedRows& input);
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册