diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index fad6f80166cee17acec12929b79592159132f7b9..12b916fcebd425bd4a03d920f947829098a924a1 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -45,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel { "Output(VelocityOut) of Momentum should not be null."); auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - "Param and Grad input of MomentumOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Velocity"), - "Param and Velocity of MomentumOp should have the same dimension."); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad input of MomentumOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Velocity"), + "Param and Velocity of MomentumOp should have the same dimension."); + } PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, "Learning_rate should be a scalar"); diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 4a74c078e6a6cc10ec103df1dbc9fea52d7d3b8d..6b4d00f56ca06c402c07ecf770a390e88ae3edf1 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/algorithm.h" @@ -303,28 +304,30 @@ class MomentumOpKernel : public framework::OpKernel { auto* merged_grad = const_cast(ctx.scope()) .Var() ->GetMutable(); - math::scatter::MergeAdd merge_func; merge_func(ctx.template device_context(), *grad, merged_grad); - platform::ForRange for_range( - static_cast(ctx.device_context()), - param->numel()); - const int64_t* rows = nullptr; +#ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(ctx.GetPlace())) { rows = merged_grad->rows().CUDAData(ctx.GetPlace()); } else { +#endif rows = merged_grad->rows().data(); +#ifdef PADDLE_WITH_CUDA } - +#endif + int64_t row_numel = + merged_grad->value().numel() / merged_grad->rows().size(); + platform::ForRange for_range( + static_cast(ctx.device_context()), + param->numel()); if (use_nesterov) { SparseMomentumFunctor functor( param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), mu, rows, + velocity->data(), learning_rate->data(), mu, rows, row_numel, static_cast(merged_grad->rows().size()), - static_cast(merged_grad->height()), param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); @@ -332,9 +335,8 @@ class MomentumOpKernel : public framework::OpKernel { } else { SparseMomentumFunctor functor( param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), mu, rows, + velocity->data(), learning_rate->data(), mu, rows, row_numel, static_cast(merged_grad->rows().size()), - static_cast(merged_grad->height()), param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 9bbffaa7ebbad39d203afe99dbc347237b4f7485..a3d89610b40ff9bd5002e843f8667ada87e67981 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -121,22 +121,13 @@ class TestSparseMomentumOp(unittest.TestCase): grad_tensor = grad_selected_rows.get_tensor() grad_tensor.set(grad_np_array, place) - velocity_selected_rows = scope.var('Velocity').get_selected_rows() - velocity_selected_rows.set_height(height) - velocity_selected_rows.set_rows(rows) - velocity_np_array = np.ones((len(rows), row_numel)).astype("float32") - velocity_np_array[0, 0] = 2.0 - velocity_np_array[2, 8] = 2.0 - velocity_tensor = velocity_selected_rows.get_tensor() - velocity_tensor.set(velocity_np_array, place) - velocity_out_selected_rows = scope.var('VelocityOut').get_selected_rows( - ) - velocity_out_selected_rows.set_height(height) - velocity_out_selected_rows.set_rows(rows) - velocity_out_np_array = np.full((len(rows), row_numel), + velocity = scope.var('Velocity').get_tensor() + velocity_np_array = np.ones((height, row_numel)).astype("float32") + velocity.set(velocity_np_array, place) + velocity_out = scope.var('VelocityOut').get_tensor() + velocity_out_np_array = np.full((height, row_numel), 0.0).astype("float32") - velocity_out_tensor = velocity_out_selected_rows.get_tensor() - velocity_out_tensor.set(velocity_out_np_array, place) + velocity_out.set(velocity_out_np_array, place) # create and initialize LeraningRate Variable lr = scope.var('LearningRate').get_tensor() @@ -158,19 +149,22 @@ class TestSparseMomentumOp(unittest.TestCase): # get and compare result param_out_np_array = np.array(param_out) - velocity_out_np_array = np.array(velocity_out_tensor) + velocity_out_np_array = np.array(velocity_out) # TODO(dzh): add a more suitable general numpy interface # for sparse update. - _velocity_out = mu * velocity_np_array + grad_np_array - _param = param_array[rows] + _grad_np_array = np.full((height, row_numel), 0.0).astype("float32") + for i in range(len(rows)): + _grad_np_array[rows[i]] = grad_np_array[i] + _velocity_out = mu * velocity_np_array + _grad_np_array + _param = param_array if use_nesterov: - _param_out = _param - grad_np_array * lr_array - \ - _velocity_out * mu * lr_array + _param_out = _param - (_grad_np_array + _velocity_out * mu + ) * lr_array else: - _param_out = _param - lr * _velocity_out - self.assertTrue((_param_out == param_out_np_array[rows]).all()) + _param_out = _param - lr_array * _velocity_out self.assertTrue((_velocity_out == velocity_out_np_array).all()) + self.assertTrue((_param_out == param_out_np_array).all()) def init_kernel(self): pass