diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 6ff2a2bb6fcf76e214e232c1fd56535aa2ba2bef..f907522d5af13abcae27dea07914ba1fce5e6a62 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -490,9 +490,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
                   << FLAGS_inner_op_parallelism
                   << " min_param_size_to_use_multithread="
                   << FLAGS_min_param_size_to_use_multithread;
+          PADDLE_ENFORCE_LE(
+              FLAGS_inner_op_parallelism, 8,
+              "FLAGS_inner_op_parallelism should not be larger than 8");
           auto& grad_rows = grad_merge.rows();
           std::unordered_map<size_t, int64_t> row_id_to_grad_row_offset;
           size_t param_row_count = param.numel() / row_numel;
+          if (param_row_count < 1000) {
+            LOG(WARNING) << "param_row_count should be larger than 1000 to use "
+                            "multi thread, currently "
+                         << param_row_count;
+          }
           for (size_t i = 0; i < param_row_count; ++i) {
             row_id_to_grad_row_offset[i] = -1;
           }
@@ -501,10 +509,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
           std::vector<std::future<void>> fs;
           int64_t line_in_each_thread =
-              param_row_count / FLAGS_inner_op_parallelism;
+              param_row_count / FLAGS_inner_op_parallelism + 1;
           for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
             int64_t start = i * line_in_each_thread;
             int64_t end = (i + 1) * line_in_each_thread;
+            if (start >= param_row_count) {
+              break;
+            }
             if (end > param_row_count) {
               end = param_row_count;
             }
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index ff7fc5100ebaf12655d5963c600bbd5058720349..463a0655a8a3a311fb14b361416b7ef1cd4c7a70 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -253,11 +253,11 @@ class TestSparseAdamOp(unittest.TestCase):
         row_numel = 12
         self.row_numel = row_numel
         self.dense_inputs = {
-            "Param": np.full((height, row_numel), 5.0).astype("float32"),
-            "Moment1": np.full((height, row_numel), 5.0).astype("float32"),
-            "Moment2": np.full((height, row_numel), 5.0).astype("float32"),
-            'Beta1Pow': np.array([beta1**10]).astype("float32"),
-            'Beta2Pow': np.array([beta2**10]).astype("float32"),
+            "Param": np.full((height, row_numel), 1.0).astype("float32"),
+            "Moment1": np.full((height, row_numel), 1.0).astype("float32"),
+            "Moment2": np.full((height, row_numel), 1.0).astype("float32"),
+            'Beta1Pow': np.array([beta1**3]).astype("float32"),
+            'Beta2Pow': np.array([beta2**3]).astype("float32"),
             "LearningRate": np.full((1), 2.0).astype("float32")
         }
         self.init_output = np.full((height, row_numel), 0.0).astype("float32")
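
The following is a minimal standalone sketch (not part of the patch, and not Paddle code) of the row-partitioning logic the adam_op.h change adjusts: with `param_row_count / FLAGS_inner_op_parallelism + 1` rows per thread plus the new `start >= param_row_count` break, every row is assigned to exactly one thread even when the row count is not divisible by the thread count, whereas the old formula without `+ 1` could leave trailing rows unprocessed. The variable names and the values 10 and 4 below are illustrative only, and plain loops stand in for Paddle's async thread dispatch.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative values: row count not divisible by the thread count.
  const int64_t param_row_count = 10;
  const int nthreads = 4;  // stands in for FLAGS_inner_op_parallelism

  // New formula from the patch: +1 rounds the per-thread chunk size up.
  const int64_t line_in_each_thread = param_row_count / nthreads + 1;

  std::vector<int> cover(param_row_count, 0);
  for (int i = 0; i < nthreads; ++i) {
    int64_t start = i * line_in_each_thread;
    int64_t end = (i + 1) * line_in_each_thread;
    if (start >= param_row_count) break;  // surplus thread: nothing to do
    if (end > param_row_count) end = param_row_count;
    std::cout << "thread " << i << " handles rows [" << start << ", " << end
              << ")\n";
    for (int64_t r = start; r < end; ++r) ++cover[r];
  }

  // With the old formula (10 / 4 = 2 rows per thread) rows 8 and 9 would be
  // left uncovered; with the patched formula every row is covered once.
  for (int64_t r = 0; r < param_row_count; ++r) assert(cover[r] == 1);
  std::cout << "all rows covered exactly once\n";
  return 0;
}
```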