提交 d1612153 编写于 作者: Q Qiao Longfei

optimize adam multi thread

上级 7a58ad5c
......@@ -490,9 +490,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
<< FLAGS_inner_op_parallelism
<< " min_param_size_to_use_multithread="
<< FLAGS_min_param_size_to_use_multithread;
PADDLE_ENFORCE_LE(
FLAGS_inner_op_parallelism, 8,
"FLAGS_inner_op_parallelism should not be larger then 8");
auto& grad_rows = grad_merge.rows();
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
size_t param_row_count = param.numel() / row_numel;
if (param_row_count < 1000) {
LOG(WARNING) << "param_row_count should be larger then 1000 to use "
"multi thread, currently "
<< param_row_count;
}
for (size_t i = 0; i < param_row_count; ++i) {
row_id_to_grad_row_offset[i] = -1;
}
......@@ -501,10 +509,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
}
std::vector<std::future<void>> fs;
int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism;
param_row_count / FLAGS_inner_op_parallelism + 1;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread;
if (start >= param_row_count) {
break;
}
if (end > param_row_count) {
end = param_row_count;
}
......
......@@ -253,11 +253,11 @@ class TestSparseAdamOp(unittest.TestCase):
row_numel = 12
self.row_numel = row_numel
self.dense_inputs = {
"Param": np.full((height, row_numel), 5.0).astype("float32"),
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
'Beta1Pow': np.array([beta1**10]).astype("float32"),
'Beta2Pow': np.array([beta2**10]).astype("float32"),
"Param": np.full((height, row_numel), 1.0).astype("float32"),
"Moment1": np.full((height, row_numel), 1.0).astype("float32"),
"Moment2": np.full((height, row_numel), 1.0).astype("float32"),
'Beta1Pow': np.array([beta1**3]).astype("float32"),
'Beta2Pow': np.array([beta2**3]).astype("float32"),
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册