提交 5bd1e73f 编写于 作者: D dangqingqing

Refine and speedup momentum operator.

上级 8ed8a935
......@@ -101,5 +101,5 @@ $$
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
REGISTER_OP_CPU_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
ops::MomentumOpKernel<double>);
......@@ -12,9 +12,67 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/momentum_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MomentumKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, bool use_nesterov, T* p_out,
T* v_out) {
T lr = learning_rate[0];
if (use_nesterov) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T g_val = g[i];
T v_new = v[i] * mu + g_val;
v_out[i] = v_new;
p_out[i] = p[i] - g_val * lr + v_new * mu * lr;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T v_new = v[i] * mu + g[i];
v_out[i] = v_new;
p_out[i] = p[i] - lr * v_new;
}
}
}
template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
ops::MomentumOpCUDAKernel<double>);
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -33,7 +33,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
float mu = ctx.Attr<float>("mu");
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
......@@ -42,18 +42,17 @@ class MomentumOpKernel : public framework::OpKernel<T> {
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto* lr = learning_rate->data<T>();
auto place = ctx.GetEigenDevice<Place>();
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
v_out.device(place) = v * mu + g;
if (use_nesterov) {
p_out.device(place) = p - g * lr.broadcast(grad_dsize) +
v_out * mu * lr.broadcast(grad_dsize);
p_out.device(place) = p - (g - v_out * mu) * lr[0];
} else {
p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
p_out.device(place) = p - lr[0] * v_out;
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册