diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index c8079a99fb8c8e05144c5390794db7757e74f6ae..12b916fcebd425bd4a03d920f947829098a924a1 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(param) of Momentum should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -45,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel { "Output(VelocityOut) of Momentum should not be null."); auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - "Param and Grad input of MomentumOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Velocity"), - "Param and Velocity of MomentumOp should have the same dimension."); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad input of MomentumOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Velocity"), + "Param and Velocity of MomentumOp should have the same dimension."); + } PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, "Learning_rate should be a scalar"); @@ -58,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel { ctx->SetOutputDim("VelocityOut", param_dim); } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); + const framework::ExecutionContext& ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; +class MomentumOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto input_var = op_desc.Input("Param")[0]; + for (auto& out_var : op_desc.Output("ParamOut")) { + if (block->FindRecursiveOrCreateVar(input_var).GetType() == + framework::proto::VarType::SELECTED_ROWS) { + block->FindRecursiveOrCreateVar(out_var).SetType( + framework::proto::VarType::SELECTED_ROWS); + } else if (block->FindRecursiveOrCreateVar(input_var).GetType() == + framework::proto::VarType::LOD_TENSOR) { + block->FindRecursiveOrCreateVar(out_var).SetType( + framework::proto::VarType::LOD_TENSOR); + } else { + PADDLE_THROW( + "Only support LodTensor and SelectedRows, Unexpected Input Type."); + } + } + } +}; + class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -115,6 +139,9 @@ $$ } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker); -REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel, - ops::MomentumOpKernel); +REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::MomentumOpInferVarType); +REGISTER_OP_CPU_KERNEL( + momentum, ops::MomentumOpKernel, + ops::MomentumOpKernel); diff --git a/paddle/fluid/operators/momentum_op.cu b/paddle/fluid/operators/momentum_op.cu index 5dc920c70979ad5c3d4fb3acc8d2dacaffe386c0..b68fec34d43f0dee834f1045f192d5c6089d9356 100644 --- a/paddle/fluid/operators/momentum_op.cu +++ b/paddle/fluid/operators/momentum_op.cu @@ -15,76 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/momentum_op.h" -namespace paddle { -namespace operators { - -template -__global__ void MomentumKernel(const T* p, const T* g, const T* v, - const T* learning_rate, const T mu, - const int64_t num, bool use_nesterov, T* p_out, - T* v_out) { - T lr = learning_rate[0]; - if (use_nesterov) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { - T g_val = g[i]; - T v_new = v[i] * mu + g_val; - v_out[i] = v_new; - p_out[i] = p[i] - (g_val + v_new * mu) * lr; - } - } else { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { - T v_new = v[i] * mu + g[i]; - v_out[i] = v_new; - p_out[i] = p[i] - lr * v_new; - } - } -} - -template -class MomentumOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE(param_var->IsType(), - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.Inputs("Param").front(), param_var->Type().name()); - const auto* grad_var = ctx.InputVar("Grad"); - PADDLE_ENFORCE(grad_var->IsType(), - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.Inputs("Grad").front(), grad_var->Type().name()); - - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto grad = ctx.Input("Grad"); - auto learning_rate = ctx.Input("LearningRate"); - - T* p_out = param_out->mutable_data(ctx.GetPlace()); - T* v_out = velocity_out->mutable_data(ctx.GetPlace()); - - T mu = static_cast(ctx.Attr("mu")); - bool use_nesterov = ctx.Attr("use_nesterov"); - - auto* p = param->data(); - auto* v = velocity->data(); - auto* g = grad->data(); - auto* lr = learning_rate->data(); - - int block = 512; - int grid = (param->numel() + block - 1) / block; - MomentumKernel<<>>( - p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out); - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel, - ops::MomentumOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + momentum, ops::MomentumOpKernel, + ops::MomentumOpKernel); diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 40073d21b7186ad319e4988e667750c267581300..6b4d00f56ca06c402c07ecf770a390e88ae3edf1 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -13,35 +13,48 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { -template -class MomentumOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE(param_var->IsType(), - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.Inputs("Param").front(), param_var->Type().name()); - - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto grad = ctx.Input("Grad"); - auto learning_rate = ctx.Input("LearningRate"); +using framework::Tensor; +using framework::SelectedRows; +struct NoNesterov; +struct UseNesterov; - param_out->mutable_data(ctx.GetPlace()); - velocity_out->mutable_data(ctx.GetPlace()); +template +class CPUDenseMomentumFunctor { + private: + const Tensor* param; + const Tensor* grad; + const Tensor* velocity; + const Tensor* learning_rate; + const T mu; + const T use_nesterov; + Tensor* param_out; + Tensor* velocity_out; - T mu = static_cast(ctx.Attr("mu")); - bool use_nesterov = ctx.Attr("use_nesterov"); + public: + CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad, + const Tensor* velocity, const Tensor* learning_rate, + const T mu, const bool use_nesterov, + Tensor* param_out, Tensor* velocity_out) + : param(param), + grad(grad), + velocity(velocity), + learning_rate(learning_rate), + mu(mu), + use_nesterov(use_nesterov), + param_out(param_out), + velocity_out(velocity_out) {} + inline void operator()() { auto p_out = framework::EigenVector::Flatten(*param_out); auto v_out = framework::EigenVector::Flatten(*velocity_out); @@ -59,5 +72,283 @@ class MomentumOpKernel : public framework::OpKernel { } }; +template +class DenseMomentumFunctor; + +// NOTE(dzh) for performance. +// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two +// functor. +template +class DenseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t num_; + T* p_out_; + T* v_out_; + + public: + DenseMomentumFunctor(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, const int64_t num, + T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(learning_rate), + mu_(mu), + num_(num), + p_out_(p_out), + v_out_(v_out) {} + inline HOSTDEVICE void operator()(size_t i) const { + // put memory access in register + const T p = p_[i]; + const T g = g_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - (g + v_out * mu_) * lr; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class DenseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t num_; + T* p_out_; + T* v_out_; + + public: + DenseMomentumFunctor(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, const int64_t num, + T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(learning_rate), + mu_(mu), + num_(num), + p_out_(p_out), + v_out_(v_out) {} + inline HOSTDEVICE void operator()(size_t i) const { + // put memory access in register + const T p = p_[i]; + const T g = g_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - lr * v_out; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class SparseMomentumFunctor; + +template +class SparseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t* rows_; + const int64_t row_numel_; + const int64_t row_height_; + T* p_out_; + T* v_out_; + + public: + SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr, + const T mu, const int64_t* rows, int64_t row_numel, + int64_t row_height, T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(lr), + mu_(mu), + rows_(rows), + row_numel_(row_numel), + row_height_(row_height), + p_out_(p_out), + v_out_(v_out) {} + + inline HOSTDEVICE void operator()(size_t i) { + auto row_idx = + math::BinarySearch(rows_, row_height_, i / row_numel_); + T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0; + // put memory access in register + const T p = p_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - (g + v_out * mu_) * lr; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class SparseMomentumFunctor { + private: + const T* p_; + const T* g_; + const T* v_; + const T* lr_; + const T mu_; + const int64_t* rows_; + const int64_t row_numel_; + const int64_t row_height_; + T* p_out_; + T* v_out_; + + public: + SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr, + const T mu, const int64_t* rows, int64_t row_numel, + int64_t row_height, T* p_out, T* v_out) + : p_(p), + g_(g), + v_(v), + lr_(lr), + mu_(mu), + rows_(rows), + row_numel_(row_numel), + row_height_(row_height), + p_out_(p_out), + v_out_(v_out) {} + + inline HOSTDEVICE void operator()(size_t i) { + auto row_idx = + math::BinarySearch(rows_, row_height_, i / row_numel_); + T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0; + // put memory access in register + const T p = p_[i]; + const T lr = lr_[0]; + const T v = v_[i]; + T v_out = v * mu_ + g; + T p_out = p - v_out * lr; + // write reigster to memory + v_out_[i] = v_out; + p_out_[i] = p_out; + } +}; + +template +class MomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + T mu = static_cast(ctx.Attr("mu")); + bool use_nesterov = ctx.Attr("use_nesterov"); + + auto learning_rate = ctx.Input("LearningRate"); + auto param = ctx.Input("Param"); + auto param_out = ctx.Output("ParamOut"); + auto* velocity = ctx.Input("Velocity"); + auto velocity_out = ctx.Output("VelocityOut"); + param_out->mutable_data(ctx.GetPlace()); + velocity_out->mutable_data(ctx.GetPlace()); + + auto* grad_var = ctx.InputVar("Grad"); + if (grad_var->IsType()) { + auto grad = ctx.Input("Grad"); + if (platform::is_cpu_place(ctx.GetPlace())) { + CPUDenseMomentumFunctor functor(param, grad, velocity, learning_rate, + mu, use_nesterov, param_out, + velocity_out); + functor(); + } else if (platform::is_gpu_place(ctx.GetPlace())) { + platform::ForRange for_range( + static_cast(ctx.device_context()), + param->numel()); + if (use_nesterov) { + DenseMomentumFunctor functor( + param->data(), grad->data(), velocity->data(), + learning_rate->data(), mu, param->numel(), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + + } else { + DenseMomentumFunctor functor( + param->data(), grad->data(), velocity->data(), + learning_rate->data(), mu, param->numel(), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + } + } + + } else if (grad_var->IsType()) { + // sparse update embedding with selectedrows + auto grad = ctx.Input("Grad"); + + // sparse update maybe empty. + if (grad->rows().size() == 0) { + VLOG(3) << "Grad SelectedRows contains no data!"; + return; + } + auto* merged_grad = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + math::scatter::MergeAdd merge_func; + merge_func(ctx.template device_context(), *grad, + merged_grad); + + const int64_t* rows = nullptr; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + rows = merged_grad->rows().CUDAData(ctx.GetPlace()); + } else { +#endif + rows = merged_grad->rows().data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + int64_t row_numel = + merged_grad->value().numel() / merged_grad->rows().size(); + platform::ForRange for_range( + static_cast(ctx.device_context()), + param->numel()); + if (use_nesterov) { + SparseMomentumFunctor functor( + param->data(), merged_grad->value().data(), + velocity->data(), learning_rate->data(), mu, rows, row_numel, + static_cast(merged_grad->rows().size()), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + + } else { + SparseMomentumFunctor functor( + param->data(), merged_grad->value().data(), + velocity->data(), learning_rate->data(), mu, rows, row_numel, + static_cast(merged_grad->rows().size()), + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace())); + for_range(functor); + } + } else { + PADDLE_THROW( + string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows " + "gradient, but the received Variable Type is %s", + grad_var->Type().name())); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 7137fd0fdb7c503492107da684b95989037eb872..a3d89610b40ff9bd5002e843f8667ada87e67981 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -16,6 +16,8 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator from op_test import OpTest @@ -88,5 +90,97 @@ class TestMomentumOp2(OpTest): self.check_output() +class TestSparseMomentumOp(unittest.TestCase): + def setUp(self): + self.use_nesterov = False + + def check_with_place(self, place): + self.init_kernel() + scope = core.Scope() + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 7] + row_numel = 12 + mu = 1.0 + use_nesterov = self.use_nesterov + + # create and initialize Param Variable + param = scope.var('Param').get_tensor() + param_array = np.full((height, row_numel), 5.0).astype("float32") + param.set(param_array, place) + param_out = scope.var("ParamOut").get_tensor() + param_out_array = np.full((height, row_numel), 0.0).astype("float32") + param_out.set(param_out_array, place) + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + grad_np_array = np.ones((len(rows), row_numel)).astype("float32") + grad_np_array[0, 0] = 2.0 + grad_np_array[2, 8] = 4.0 + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(grad_np_array, place) + + velocity = scope.var('Velocity').get_tensor() + velocity_np_array = np.ones((height, row_numel)).astype("float32") + velocity.set(velocity_np_array, place) + velocity_out = scope.var('VelocityOut').get_tensor() + velocity_out_np_array = np.full((height, row_numel), + 0.0).astype("float32") + velocity_out.set(velocity_out_np_array, place) + + # create and initialize LeraningRate Variable + lr = scope.var('LearningRate').get_tensor() + lr_array = np.full((1), 2.0).astype("float32") + lr.set(lr_array, place) + + # create and run operator + op = Operator( + "momentum", + Param='Param', + Grad='Grad', + Velocity='Velocity', + ParamOut='ParamOut', + VelocityOut='VelocityOut', + LearningRate='LearningRate', + mu=mu, + use_nesterov=use_nesterov) + op.run(scope, place) + + # get and compare result + param_out_np_array = np.array(param_out) + velocity_out_np_array = np.array(velocity_out) + + # TODO(dzh): add a more suitable general numpy interface + # for sparse update. + _grad_np_array = np.full((height, row_numel), 0.0).astype("float32") + for i in range(len(rows)): + _grad_np_array[rows[i]] = grad_np_array[i] + _velocity_out = mu * velocity_np_array + _grad_np_array + _param = param_array + if use_nesterov: + _param_out = _param - (_grad_np_array + _velocity_out * mu + ) * lr_array + else: + _param_out = _param - lr_array * _velocity_out + self.assertTrue((_velocity_out == velocity_out_np_array).all()) + self.assertTrue((_param_out == param_out_np_array).all()) + + def init_kernel(self): + pass + + def test_sparse_momentum(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + +class TestSparseMomentumOp2(TestSparseMomentumOp): + def init_kernel(self): + self.use_nesterov = True + + if __name__ == "__main__": unittest.main()