Unverified · Commit b14d4cdd authored by 乔龙飞 Qiao Longfei and committed by GitHub

Merge pull request #14890 from jacquesqiao/multithread-sparse-adam

adam op: support multi-threaded sparse update
@@ -35,6 +35,7 @@ DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false,
            "Checking whether operator produce NAN/INF or not. It will be "
            "extremely slow so please use this flag wisely.");
DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op");
namespace paddle {
namespace framework {
...
@@ -34,6 +34,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
DECLARE_int32(inner_op_parallelism);
namespace paddle {
namespace framework {
...
@@ -114,6 +114,13 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
             "(bool, default false) "
             "only update the parameter that has gradient in sparse update")
        .SetDefault(false);
    AddAttr<int64_t>("min_row_size_to_use_multithread",
                     "(int64_t, default 1000) "
                     "when not zero, if param row size is larger than "
                     "min_row_size_to_use_multithread and "
                     "inner_op_parallelism is larger than 0, sparse update "
                     "will run in multithread mode")
        .SetDefault(1000);
AddComment(R"DOC( AddComment(R"DOC(
Adam Optimizer. Adam Optimizer.
......
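Note on the new attribute: the multithreaded sparse update is gated by both this per-op attribute and the global `FLAGS_inner_op_parallelism` flag defined in the framework change above. As a reading aid, here is a hedged Python restatement of the gating condition used in the kernel diff further down; the function name is illustrative and not part of the Paddle API.

```python
def should_use_multithread(inner_op_parallelism, min_row_size_to_use_multithread,
                           param_row_count, lazy_mode=False):
    """Illustrative mirror of the gating logic in AdamOpKernel.

    The multithreaded sparse path is taken only when lazy_mode is off,
    more than one worker thread is requested via FLAGS_inner_op_parallelism,
    the attribute is positive, and the parameter has more rows than the
    threshold.
    """
    return (not lazy_mode
            and inner_op_parallelism > 1
            and min_row_size_to_use_multithread > 0
            and param_row_count > min_row_size_to_use_multithread)

# With the default threshold of 1000 rows and 4 worker threads:
assert should_use_multithread(4, 1000, 50000)
assert not should_use_multithread(4, 1000, 500)
```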
@@ -17,6 +17,7 @@ limitations under the License. */
#include <Eigen/Dense>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -353,6 +354,8 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -353,6 +354,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref; using paddle::operators::detail::Ref;
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
@@ -473,8 +476,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
          lr.template data<T>(), grad_data, param.template data<T>(),
          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
          grad_merge.rows().size(), lazy_mode);
      if (lazy_mode) {
VLOG(3) << "run cpu lazy mode";
        size_t row_count = grad_merge.rows().size();
        std::vector<int64_t> cpu_rows(grad_merge.rows());
        for (size_t row_index = 0; row_index < row_count; ++row_index) {
@@ -483,6 +486,62 @@ class AdamOpKernel : public framework::OpKernel<T> {
            functor.adam_update(i, grad_data[row_index * row_numel + offset]);
          }
        }
} else if (FLAGS_inner_op_parallelism > 1 &&
min_row_size_to_use_multithread > 0 &&
param.dims()[0] > min_row_size_to_use_multithread) {
VLOG(3) << "use multi thread, inner_op_parallelism="
<< FLAGS_inner_op_parallelism
<< " min_row_size_to_use_multithread="
<< min_row_size_to_use_multithread;
if (FLAGS_inner_op_parallelism > 10) {
VLOG(1) << "FLAGS_inner_op_parallelism "
<< FLAGS_inner_op_parallelism << " is too large!";
}
auto& grad_rows = grad_merge.rows();
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
size_t param_row_count = param.numel() / row_numel;
if (param_row_count < 1000) {
VLOG(1) << "param_row_count should be larger than 1000 to use "
"multi-thread mode, currently "
<< param_row_count;
}
for (size_t i = 0; i < grad_rows.size(); ++i) {
row_id_to_grad_row_offset[grad_rows[i]] = i;
}
std::vector<std::future<void>> fs;
int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism + 1;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread;
if (start >= param_row_count) {
break;
}
if (end > param_row_count) {
end = param_row_count;
}
fs.push_back(
framework::Async([&functor, &row_id_to_grad_row_offset,
&grad_data, row_numel, start, end]() {
for (int64_t row_id = start; row_id < end; ++row_id) {
auto iter = row_id_to_grad_row_offset.find(row_id);
if (iter != row_id_to_grad_row_offset.end()) {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(
row_id * row_numel + row_offset,
grad_data[iter->second * row_numel + row_offset]);
}
} else {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(row_id * row_numel + row_offset, 0);
}
}
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
      } else {
        functor(param.numel());
      }
...
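The new branch above splits the parameter's rows into contiguous chunks, one per thread requested by `FLAGS_inner_op_parallelism`; each thread looks up its rows in the merged sparse gradient and applies either the real gradient or a zero gradient, so the moment estimates still decay for rows that have no gradient. The snippet below is a minimal Python sketch of that partitioning scheme, not the Paddle kernel itself: `adam_update` stands in for the per-element functor, and `ThreadPoolExecutor.submit` plays the role of `framework::Async` (in pure Python the GIL prevents a real speedup, so this is purely illustrative).

```python
from concurrent.futures import ThreadPoolExecutor

def threaded_sparse_update(param_row_count, row_numel, grad_rows, grad_data,
                           adam_update, num_threads):
    """Sketch of the row partitioning used by the multithreaded sparse path."""
    # Map a parameter row id to its row offset inside the merged sparse gradient.
    row_id_to_grad_row_offset = {row: i for i, row in enumerate(grad_rows)}
    # Each thread handles a contiguous block of parameter rows.
    rows_per_thread = param_row_count // num_threads + 1

    def work(start, end):
        for row_id in range(start, end):
            offset = row_id_to_grad_row_offset.get(row_id)
            for col in range(row_numel):
                if offset is not None:
                    grad = grad_data[offset * row_numel + col]
                else:
                    grad = 0.0  # absent rows still get a zero-gradient update
                adam_update(row_id * row_numel + col, grad)

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = []
        for i in range(num_threads):
            start = i * rows_per_thread
            end = min((i + 1) * rows_per_thread, param_row_count)
            if start >= param_row_count:
                break
            futures.append(pool.submit(work, start, end))
        for f in futures:
            f.result()  # propagate exceptions, analogous to fs[i].wait()

# Example: 6 param rows, 2 columns, sparse gradient present on rows 1 and 4.
updates = []
threaded_sparse_update(6, 2, [1, 4], [0.1, 0.2, 0.3, 0.4],
                       lambda i, g: updates.append((i, g)), num_threads=2)
```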
@@ -129,7 +129,7 @@ def __bootstrap__():
'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
'allocator_strategy', 'reader_queue_speed_test_mode',
'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
'inner_op_parallelism', 'enable_parallel_graph'
]
if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory')
...
@@ -822,7 +822,8 @@ class AdamOptimizer(Optimizer):
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon,
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000
},
stop_gradient=True)
...
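Note that the Python wrapper passes `min_row_size_to_use_multithread` as a fixed value of 1000 rather than exposing it as a constructor argument, so from user code the feature is driven only by the `FLAGS_inner_op_parallelism` environment flag registered in `__init__.py` above (the new CMake test below sets it the same way). A hedged end-to-end sketch, assuming the flag must be set before `paddle.fluid` is imported like other `FLAGS_*` options; the tiny network is a placeholder whose `is_sparse` embedding yields the selected-rows gradient this kernel handles:

```python
import os

# Enable the multithreaded sparse Adam path with 4 worker threads.
os.environ['FLAGS_inner_op_parallelism'] = '4'

import paddle.fluid as fluid  # noqa: E402

# Placeholder network: a sparse embedding produces a SelectedRows gradient.
word = fluid.layers.data(name='word', shape=[1], dtype='int64')
emb = fluid.layers.embedding(input=word, size=[100000, 64], is_sparse=True)
cost = fluid.layers.reduce_mean(emb)

# min_row_size_to_use_multithread is fixed to 1000 by AdamOptimizer, so the
# 100000-row embedding qualifies for the multithreaded update.
adam = fluid.optimizer.Adam(learning_rate=1e-3)
adam.minimize(cost)
```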
@@ -87,6 +87,7 @@ list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
...
@@ -261,7 +261,12 @@ class TestSparseAdamOp(unittest.TestCase):
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
self.attrs = {
'epsilon': epsilon,
'beta1': beta1,
'beta2': beta2,
'min_row_size_to_use_multithread': 2
}
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
...