未验证 提交 7f1e3126 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #13456 from sneaxiy/refine_sparse_adam

Fix sparse Adam and Gradient clip of SelectedRows
...@@ -174,12 +174,13 @@ struct SparseAdamFunctor { ...@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
const int64_t* rows_; const int64_t* rows_;
int64_t row_numel_; int64_t row_numel_;
int64_t row_count_;
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad, const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows, const T* param, T* param_out, const int64_t* rows,
int64_t row_numel) int64_t row_numel, int64_t row_count)
: beta1_(beta1), : beta1_(beta1),
beta2_(beta2), beta2_(beta2),
epsilon_(epsilon), epsilon_(epsilon),
...@@ -194,28 +195,47 @@ struct SparseAdamFunctor { ...@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
param_(param), param_(param),
param_out_(param_out), param_out_(param_out),
rows_(rows), rows_(rows),
row_numel_(row_numel) {} row_numel_(row_numel),
row_count_(row_count) {}
// Binary-search the merged sparse row index table for `row`.
// Assumes `rows_` is sorted ascending with `row_count_` entries — the
// surrounding change notes the rows are sorted inside the MergeAdd functor.
// Returns the position of `row` in `rows_`, or -1 if the row has no
// gradient entry (its gradient is treated as zero by the caller).
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
  int64_t beg = 0, end = row_count_ - 1;
  while (beg <= end) {
    // Overflow-safe midpoint: `(beg + end) >> 1` can overflow int64_t
    // for pathologically large ranges; `beg + (end - beg) / 2` cannot.
    auto mid = beg + ((end - beg) >> 1);
    if (rows_[mid] == row)
      return mid;
    else if (rows_[mid] < row)
      beg = mid + 1;
    else
      end = mid - 1;
  }
  return -1;
}
inline HOSTDEVICE void operator()(size_t i) const { inline HOSTDEVICE void operator()(size_t i) const {
int64_t row = i / row_numel_;
auto row_idx = BinarySearchInRows(row);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
// The following code is the same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_; T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_; T beta2_pow = *beta2_pow_;
for (int64_t j = 0; j < row_numel_; ++j) { T p = param_[i];
T g = grad_[i * row_numel_ + j];
T mom1 = moment1_[rows_[i] * row_numel_ + j]; // Calculation
T mom2 = moment2_[rows_[i] * row_numel_ + j]; lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
T lr = *lr_;
T p = param_[rows_[i] * row_numel_ + j]; mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
mom1 = beta1_ * mom1 + (1 - beta1_) * g; // Write back to global memory
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; moment1_out_[i] = mom1;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); moment2_out_[i] = mom2;
param_out_[i] = p;
moment1_out_[rows_[i] * row_numel_ + j] = mom1;
moment2_out_[rows_[i] * row_numel_ + j] = mom2;
param_out_[rows_[i] * row_numel_ + j] = p;
} // for col id
} }
}; };
...@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
return; return;
} }
// merge duplicated rows if any. // merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func; scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge = auto& grad_merge = *(ctx.scope()
merge_func(ctx.template device_context<DeviceContext>(), grad); .NewScope()
.Var("sparse_adam_grad_merge")
->GetMutable<framework::SelectedRows>());
merge_func(ctx.template device_context<DeviceContext>(), grad,
&grad_merge);
auto& grad_tensor = grad_merge.value(); auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>(); const T* grad_data = grad_tensor.template data<T>();
int64_t* rows = nullptr; int64_t* rows = nullptr;
...@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
mom2.template data<T>(), mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(), lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel); param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size());
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), static_cast<const DeviceContext&>(ctx.device_context()),
grad_merge.rows().size()); param.numel());
for_range(functor); for_range(functor);
} else { } else {
PADDLE_THROW("Variable type not supported by adam_op"); PADDLE_THROW("Variable type not supported by adam_op");
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min"); auto min = context.Attr<T>("min");
auto* x = context.Input<Tensor>("X"); auto* x_var = context.InputVar("X");
auto* out = context.Output<Tensor>("Out"); if (x_var->IsType<framework::LoDTensor>()) {
T* out_data = out->mutable_data<T>(context.GetPlace()); auto* x = context.Input<framework::LoDTensor>("X");
const T* x_data = x->data<T>(); auto* out = context.Output<framework::LoDTensor>("Out");
int64_t numel = x->numel(); T* out_data = out->mutable_data<T>(context.GetPlace());
Transform<DeviceContext> trans; const T* x_data = x->data<T>();
trans(context.template device_context<DeviceContext>(), x_data, int64_t numel = x->numel();
x_data + numel, out_data, ClipFunctor<T>(min, max)); Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max));
} else if (x_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out");
PADDLE_ENFORCE_NE(x, out,
"Inplace clip is not allowed when x is SelectedRows");
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *x, out);
auto* out_tensor = out->mutable_value();
auto* out_data = out_tensor->data<T>();
int64_t numel = out_tensor->numel();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), out_data,
out_data + numel, out_data, ClipFunctor<T>(min, max));
} else {
PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
}
} }
}; };
...@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> { ...@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min"); auto min = context.Attr<T>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_out =
auto* d_x = context.Output<Tensor>(framework::GradVarName("X")); context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) { if (d_x != nullptr) {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<framework::LoDTensor>("X");
int64_t numel = d_out->numel(); int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace()); auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>(); const T* d_out_data = d_out->data<T>();
......
...@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context, framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input) {
framework::SelectedRows out; framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
auto input_rows = input.rows(); auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
...@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out_data[out_i * input_width + j] += input_data[i * input_width + j]; out_data[out_i * input_width + j] += input_data[i * input_width + j];
} }
} }
return out;
} }
}; };
......
...@@ -234,7 +234,7 @@ template <typename T, int block_size> ...@@ -234,7 +234,7 @@ template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows, __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
T* out, const int64_t* out_rows, T* out, const int64_t* out_rows,
size_t out_rows_size, int64_t row_numel) { size_t out_rows_size, int64_t row_numel) {
const int ty = blockIdx.y; const int ty = blockIdx.x;
int tid = threadIdx.x; int tid = threadIdx.x;
__shared__ size_t out_idx; __shared__ size_t out_idx;
...@@ -260,6 +260,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -260,6 +260,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context, framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input) {
framework::SelectedRows out; framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
framework::Vector<int64_t> input_rows(input.rows()); framework::Vector<int64_t> input_rows(input.rows());
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
...@@ -281,16 +289,12 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -281,16 +289,12 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
const int block_size = 256; const int block_size = 256;
dim3 threads(block_size, 1); dim3 threads(block_size, 1);
dim3 grid1(1, input_rows.size()); dim3 grid1(input_rows.size(), 1);
MergeAddKernel< MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
input_data, input_rows.CUDAData(context.GetPlace()), out_data, input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width); out.rows().size(), input_width);
return out;
} }
}; };
......
...@@ -65,6 +65,9 @@ struct MergeAdd { ...@@ -65,6 +65,9 @@ struct MergeAdd {
// the input SelectedRows object. // the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context, framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input); const framework::SelectedRows& input);
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册