Commit a9f5f822 authored by D dzhwinter

use binary search. test=develop

Parent 38612695
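The key change in this commit: the sparse (SelectedRows) momentum update no longer scatters row by row. The gradient is first deduplicated with math::scatter::MergeAdd, then a single ForRange functor walks the whole dense parameter and uses math::BinarySearch over the merged row index (assumed sorted, as binary search requires) to decide whether a given row carries a gradient. Below is a minimal standalone sketch of that lookup; FindRowIndex and the demo main are illustrative stand-ins, not Paddle's actual math::BinarySearch implementation.

#include <cstdint>
#include <cstdio>

// Returns the position of `value` in the sorted array `data` of length `len`,
// or -1 if absent. Rows absent from the sparse gradient contribute g = 0,
// so the velocity term still updates every parameter row.
int64_t FindRowIndex(const int64_t* data, int64_t len, int64_t value) {
  int64_t lo = 0, hi = len - 1;
  while (lo <= hi) {
    int64_t mid = lo + (hi - lo) / 2;
    if (data[mid] == value) return mid;
    if (data[mid] < value) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return -1;
}

int main() {
  const int64_t rows[] = {2, 5, 9};  // sorted non-empty rows after MergeAdd
  // Element i of the dense parameter lives in row i / row_numel; a hit gives
  // the offset into the compact gradient value tensor.
  std::printf("%lld\n", static_cast<long long>(FindRowIndex(rows, 3, 5)));  // 1
  std::printf("%lld\n", static_cast<long long>(FindRowIndex(rows, 3, 4)));  // -1
  return 0;
}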
paddle/fluid/operators/momentum_op.cc
@@ -74,9 +74,13 @@ class MomentumOpInferVarType : public framework::VarTypeInference {
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else {
} else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::LOD_TENSOR) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
} else {
PADDLE_THROW(
"Only support LodTensor and SelectedRows, Unexpected Input Type.");
}
}
}
@@ -135,5 +139,6 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::MomentumOpInferVarType);
REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
ops::MomentumOpKernel<double>);
REGISTER_OP_CPU_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
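Note the registration change: the CPU kernel was previously templated only on the element type, while the unified MomentumOpKernel in momentum_op.h is now templated on the device context as well, so one implementation backs both this CPU registration and the CUDA one below. A minimal sketch of that pattern, with illustrative names rather than Paddle's real API:

#include <iostream>

struct CPUDeviceContext {};
struct CUDADeviceContext {};

// One kernel template, instantiated per (device, type) pair at registration.
template <typename DeviceContext, typename T>
struct MomentumKernelSketch {
  void Compute() const {
    // A real kernel would dispatch device-specific code via DeviceContext;
    // this only demonstrates the instantiation pattern.
    std::cout << "momentum kernel instantiated\n";
  }
};

int main() {
  // Mirrors the REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL pairs:
  MomentumKernelSketch<CPUDeviceContext, float>().Compute();
  MomentumKernelSketch<CUDADeviceContext, double>().Compute();
  return 0;
}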
paddle/fluid/operators/momentum_op.cu
@@ -15,125 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MomentumKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, bool use_nesterov, T* p_out,
T* v_out) {
T lr = learning_rate[0];
if (use_nesterov) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T g_val = g[i];
T v_new = v[i] * mu + g_val;
v_out[i] = v_new;
p_out[i] = p[i] - (g_val + v_new * mu) * lr;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T v_new = v[i] * mu + g[i];
v_out[i] = v_new;
p_out[i] = p[i] - lr * v_new;
}
}
}
template <typename T>
__global__ void SparseMomentumKernel(const T* p, const T* g, const T* v,
const T* lr, const T mu,
const int64_t* grad_rows,
const size_t grad_row_numel,
const size_t grad_row_size,
const T use_nesterov, T* p_out, T* v_out) {
for (int i = blockIdx.x; i < grad_row_size; i += gridDim.x) {
for (int j = threadIdx.x; j < grad_row_numel; j += blockDim.x) {
size_t p_i = grad_rows[i] * grad_row_numel + j;
size_t g_i = i * grad_row_numel + j;
v_out[g_i] = v[g_i] * mu + g[g_i];
if (use_nesterov) {
p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
} else {
p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
}
}
}
}
template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* velocity_var = ctx.InputVar("Velocity");
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
"Unmatched Type of Param and Grad");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
const int kThreadPerBlock = 256;
int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
MomentumKernel<
T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// sparse update embedding with selectedrows
PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
"Unmatched Type of Param and Grad");
auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
auto grad = ctx.Input<framework::SelectedRows>("Grad");
auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");
// sparse update may be empty.
if (grad->rows().size() == 0) {
return;
}
PADDLE_ENFORCE(grad->height() == velocity->height(),
"Unmatched gradient and velocity.");
auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
auto* v_out =
velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
auto* lr = learning_rate->data<T>();
auto* p = param->data<T>();
auto* g = grad->value().data<T>();
auto* v = velocity->value().data<T>();
size_t grad_row_numel = grad->value().numel() / grad->rows().size();
size_t grad_row_size = grad->rows().size();
framework::Vector<int64_t> rows(grad->rows());
const int kThreadPerBlock = 256;
int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
SparseMomentumKernel<
T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, rows.CUDAData(ctx.GetPlace()), grad_row_numel,
grad->rows().size(), use_nesterov, p_out, v_out);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
ops::MomentumOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
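For reference, the deleted CUDA kernels and the functors in momentum_op.h below implement the same update rule; writing p for the parameter, v for the velocity, g for the gradient, lr for the learning rate, and mu for the momentum coefficient:

\[
v' = \mu v + g, \qquad
p' =
\begin{cases}
p - \mathit{lr} \cdot v' & \text{plain momentum} \\
p - \left(g + \mu v'\right) \cdot \mathit{lr} & \text{Nesterov momentum (use\_nesterov)}
\end{cases}
\]

In the sparse path, rows absent from the gradient take g = 0, so the velocity decay still applies to every row.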
paddle/fluid/operators/momentum_op.h
@@ -15,30 +15,45 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;
template <typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
class CPUDenseMomentumFunctor {
private:
const Tensor* param;
const Tensor* grad;
const Tensor* velocity;
const Tensor* learning_rate;
const T mu;
const bool use_nesterov;
Tensor* param_out;
Tensor* velocity_out;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad,
const Tensor* velocity, const Tensor* learning_rate,
const T mu, const bool use_nesterov,
Tensor* param_out, Tensor* velocity_out)
: param(param),
grad(grad),
velocity(velocity),
learning_rate(learning_rate),
mu(mu),
use_nesterov(use_nesterov),
param_out(param_out),
velocity_out(velocity_out) {}
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* velocity_var = ctx.InputVar("Velocity");
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
"Unmatched Type of Param and Grad");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
inline void operator()() {
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
@@ -53,41 +68,303 @@ class MomentumOpKernel : public framework::OpKernel<T> {
} else {
p_out = p - lr[0] * v_out;
}
}
};
template <typename T, typename UpdateMethod>
class DenseMomentumFunctor;
// NOTE(dzh): for performance, avoid if/else inside the kernel and implement
// the GPU UseNesterov/NoNesterov variants as two separate functors.
template <typename T>
class DenseMomentumFunctor<T, UseNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
T* v_out_;
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(learning_rate),
mu_(mu),
num_(num),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// load values from memory into registers
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
// write registers back to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T>
class DenseMomentumFunctor<T, NoNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
T* v_out_;
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(learning_rate),
mu_(mu),
num_(num),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// load values from memory into registers
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - lr * v_out;
// write registers back to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
// TODO(dzh): improve speed by using Eigen
// template<typename T>
// class CPUSparseMomentumFunctor {
// private:
// const T* p_;
// const T* g_;
// const T* v_;
// const T* lr_;
// const T mu_;
// const bool use_nesterov_;
// const int64_t* rows_;
// const int64_t row_numel_;
// const int64_t row_height_;
// T* p_out_;
// T* v_out_;
// public:
// CPUSparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
// const T mu, const bool use_nesterov, const int64_t* rows, const int64_t
// row_numel, const int64_t row_height, T* p_out, T* v_out) :p_(p), g_(g),
// v_(v), lr_(lr), mu_(mu), rows_(rows), row_numel_(row_numel),
// row_height_(row_height), p_out_(p_out), v_out_(v_out) {}
// inline void operator()() {
// }
// };
template <typename T, typename UpdateMethod>
class SparseMomentumFunctor;
template <typename T>
class SparseMomentumFunctor<T, UseNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* p_out_;
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(lr),
mu_(mu),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
// load values from memory into registers
const T p = p_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
// write registers back to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T>
class SparseMomentumFunctor<T, NoNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* p_out_;
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(lr),
mu_(mu),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
// load values from memory into registers
const T p = p_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - v_out * lr;
// write registers back to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* velocity = ctx.Input<framework::Tensor>("Velocity");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto grad = ctx.Input<framework::Tensor>("Grad");
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<T> functor(param, grad, velocity, learning_rate,
mu, use_nesterov, param_out,
velocity_out);
functor();
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
DenseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
} else {
DenseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
// sparse update embedding with selectedrows
PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
"Unmatched Type of Param and Grad");
auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
auto grad = ctx.Input<framework::SelectedRows>("Grad");
auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");
// sparse update may be empty.
if (grad->rows().size() == 0) {
VLOG(3) << "Grad SelectedRows contains no data!";
return;
}
PADDLE_ENFORCE(grad->height() == velocity->height(),
"Unmatched gradient and velocity.");
auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
auto* v_out =
velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
auto* lr = learning_rate->data<T>();
auto* p = param->data<T>();
auto* g = grad->value().data<T>();
auto* v = velocity->value().data<T>();
size_t grad_row_numel = grad->value().numel() / grad->rows().size();
for (size_t i = 0; i < grad->rows().size(); ++i) {
size_t grad_row_index = grad->rows()[i];
for (size_t j = 0; j < grad_row_numel; ++j) {
size_t p_i = grad_row_index * grad_row_numel + j;
size_t g_i = i * grad_row_numel + j;
v_out[g_i] = v[g_i] * mu + g[g_i];
if (use_nesterov) {
p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(ctx.template device_context<DeviceContext>(), *grad,
merged_grad);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
const int64_t* rows = nullptr;
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = merged_grad->rows().CUDAData(ctx.GetPlace());
} else {
p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
}
rows = merged_grad->rows().data();
}
// number of elements in each row of the dense parameter; the functor needs
// this to map a flat element index to (row, column), and it needs the number
// of merged rows as the binary-search length.
int64_t row_numel =
merged_grad->value().numel() / merged_grad->rows().size();
if (use_nesterov) {
SparseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
} else {
SparseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
}
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
......