From d16121533295c04e407c6e25dc0a9aaf3079fe2d Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Thu, 27 Dec 2018 13:37:29 +0800
Subject: [PATCH] optimize adam multi thread

---
 paddle/fluid/operators/optimizers/adam_op.h         | 13 ++++++++++++-
 python/paddle/fluid/tests/unittests/test_adam_op.py | 10 +++++-----
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 6ff2a2bb6..f907522d5 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -490,9 +490,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
                 << FLAGS_inner_op_parallelism
                 << " min_param_size_to_use_multithread="
                 << FLAGS_min_param_size_to_use_multithread;
+        PADDLE_ENFORCE_LE(
+            FLAGS_inner_op_parallelism, 8,
+            "FLAGS_inner_op_parallelism should not be larger than 8");
         auto& grad_rows = grad_merge.rows();
         std::unordered_map<size_t, int64_t> row_id_to_grad_row_offset;
         size_t param_row_count = param.numel() / row_numel;
+        if (param_row_count < 1000) {
+          LOG(WARNING) << "param_row_count should be larger than 1000 to use "
+                          "multi thread, currently "
+                       << param_row_count;
+        }
         for (size_t i = 0; i < param_row_count; ++i) {
           row_id_to_grad_row_offset[i] = -1;
         }
@@ -501,10 +509,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
         std::vector<std::future<void>> fs;
         int64_t line_in_each_thread =
-            param_row_count / FLAGS_inner_op_parallelism;
+            param_row_count / FLAGS_inner_op_parallelism + 1;
         for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
           int64_t start = i * line_in_each_thread;
           int64_t end = (i + 1) * line_in_each_thread;
+          if (start >= param_row_count) {
+            break;
+          }
           if (end > param_row_count) {
             end = param_row_count;
           }
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index ff7fc5100..463a0655a 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -253,11 +253,11 @@ class TestSparseAdamOp(unittest.TestCase):
         row_numel = 12
         self.row_numel = row_numel
         self.dense_inputs = {
-            "Param": np.full((height, row_numel), 5.0).astype("float32"),
-            "Moment1": np.full((height, row_numel), 5.0).astype("float32"),
-            "Moment2": np.full((height, row_numel), 5.0).astype("float32"),
-            'Beta1Pow': np.array([beta1**10]).astype("float32"),
-            'Beta2Pow': np.array([beta2**10]).astype("float32"),
+            "Param": np.full((height, row_numel), 1.0).astype("float32"),
+            "Moment1": np.full((height, row_numel), 1.0).astype("float32"),
+            "Moment2": np.full((height, row_numel), 1.0).astype("float32"),
+            'Beta1Pow': np.array([beta1**3]).astype("float32"),
+            'Beta2Pow': np.array([beta2**3]).astype("float32"),
             "LearningRate": np.full((1), 2.0).astype("float32")
         }
         self.init_output = np.full((height, row_numel), 0.0).astype("float32")
--
GitLab
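
Why the chunking change in adam_op.h is needed: the old chunk size, param_row_count / FLAGS_inner_op_parallelism, rounds down, so up to FLAGS_inner_op_parallelism - 1 trailing rows were never assigned to any thread. Rounding the chunk size up with "+ 1" covers every row, but it also means a later thread's start index can land past the last row, which is why the patch adds the early break before the end clamp. Below is a minimal standalone C++ sketch of this partitioning scheme, not part of the patch; the names (ProcessRows, kRows, kThreads) are illustrative, and plain std::async stands in for the framework's thread pool:

#include <algorithm>
#include <cstdint>
#include <future>
#include <iostream>
#include <vector>

// Stand-in for the per-thread Adam update over rows [start, end).
void ProcessRows(int64_t start, int64_t end) {
  std::cout << "rows [" << start << ", " << end << ")\n";
}

int main() {
  const int64_t kRows = 10;  // param_row_count
  const int kThreads = 4;    // FLAGS_inner_op_parallelism

  // "+ 1" rounds the chunk size up; plain integer division would
  // drop the remainder and leave the last kRows % kThreads rows
  // unprocessed.
  const int64_t chunk = kRows / kThreads + 1;

  std::vector<std::future<void>> fs;
  for (int i = 0; i < kThreads; ++i) {
    int64_t start = i * chunk;
    int64_t end = (i + 1) * chunk;
    // Rounding up can push a later chunk wholly past the last row;
    // stop instead of launching empty or out-of-range work.
    if (start >= kRows) break;
    end = std::min(end, kRows);  // clamp the final partial chunk
    fs.push_back(std::async(std::launch::async, ProcessRows, start, end));
  }
  for (auto& f : fs) f.wait();  // join all workers
  return 0;
}

With 10 rows and 4 threads this yields the chunks [0, 3), [3, 6), [6, 9), and [9, 10): every row is covered exactly once, whereas the old floor division produced chunks of 2 and silently skipped rows 8 and 9.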