Commit fac87022 authored by Qiao Longfei

adam op: support multi-threaded execution

Parent e2130502
......
@@ -30,6 +30,8 @@ DECLARE_bool(benchmark);
 DEFINE_bool(check_nan_inf, false,
             "Checking whether operator produce NAN/INF or not. It will be "
             "extremely slow so please use this flag wisely.");
+DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op");
+DEFINE_int32(min_param_size_to_use_multithread, 0, "minimum parameter numel for adam op to use the multithread path");
 namespace paddle {
 namespace framework {
......
......
@@ -34,6 +34,9 @@ limitations under the License. */
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/variant.h"
+DECLARE_int32(inner_op_parallelism);
+DECLARE_int32(min_param_size_to_use_multithread);
 namespace paddle {
 namespace framework {
......
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <Eigen/Dense>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
......
@@ -352,10 +353,31 @@ class AdamOpKernel : public framework::OpKernel<T> {
           lr.template data<T>(), grad_data, param.template data<T>(),
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
           grad_merge.rows().size());
-      platform::ForRange<DeviceContext> for_range(
-          static_cast<const DeviceContext&>(ctx.device_context()),
-          param.numel());
-      for_range(functor);
+      int inner_op_parallelism = FLAGS_inner_op_parallelism;
+      if (inner_op_parallelism > 1 &&
+          FLAGS_min_param_size_to_use_multithread > 0 &&
+          param.numel() > FLAGS_min_param_size_to_use_multithread) {
+        std::vector<std::future<void>> fs;
+        int64_t block_size = param.numel() / inner_op_parallelism;
+        for (int i = 0; i < inner_op_parallelism; ++i) {
+          int64_t start = i * block_size;
+          // Give the remainder to the last block so trailing elements are
+          // not skipped when numel is not divisible by the thread count.
+          int64_t end = (i == inner_op_parallelism - 1) ? param.numel()
+                                                        : start + block_size;
+          fs.push_back(framework::Async([&functor, start, end]() {
+            for (int64_t j = start; j < end; ++j) {
+              functor(j);
+            }
+          }));
+        }
+        for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
+      } else {
+        platform::ForRange<DeviceContext> for_range(
+            static_cast<const DeviceContext&>(ctx.device_context()),
+            param.numel());
+        for_range(functor);
+      }
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
     }
......
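For readers who want the partitioning pattern in isolation: below is a minimal, self-contained sketch of the same chunked dispatch. `std::async` stands in for Paddle's `framework::Async`, and the `ParallelFor` helper name is hypothetical; this is an illustration of the technique, not the kernel's actual code.

```cpp
#include <cstdint>
#include <future>
#include <vector>

// Split [0, n) into num_threads contiguous blocks; the last block
// absorbs the remainder so no index is skipped.
template <typename Func>
void ParallelFor(int64_t n, int num_threads, Func func) {
  std::vector<std::future<void>> fs;
  int64_t block_size = n / num_threads;
  for (int t = 0; t < num_threads; ++t) {
    int64_t start = t * block_size;
    int64_t end = (t == num_threads - 1) ? n : start + block_size;
    fs.push_back(std::async(std::launch::async, [func, start, end] {
      for (int64_t i = start; i < end; ++i) func(i);
    }));
  }
  for (auto& f : fs) f.wait();  // block until every chunk finishes
}

int main() {
  std::vector<double> v(10, 1.0);
  // Doubles every element across 3 threads; index 9 lands in the last
  // block, which takes the remainder (blocks are [0,3), [3,6), [6,10)).
  ParallelFor(static_cast<int64_t>(v.size()), 3,
              [&v](int64_t i) { v[i] *= 2.0; });
  return 0;
}
```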
......
@@ -128,7 +128,8 @@ def __bootstrap__():
         'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size",
         'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
         'allocator_strategy', 'reader_queue_speed_test_mode',
-        'print_sub_graph_dir', 'pe_profile_fname'
+        'print_sub_graph_dir', 'pe_profile_fname', 'inner_op_parallelism',
+        'min_param_size_to_use_multithread'
     ]
     if 'Darwin' not in sysstr:
         read_env_flags.append('use_pinned_memory')
......
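Adding the names to `read_env_flags` means a Python user can supply them as `FLAGS_*` environment variables before `paddle.fluid` is imported. On the C++ side they are ordinary gflags, so a program linking against the framework could also set them directly. A minimal sketch, assuming standard gflags semantics; the surrounding `main` and the threshold value are illustrative, not Paddle code:

```cpp
#include <gflags/gflags.h>

// Declared in the framework header by this commit; the definitions live in
// the Paddle library this program would link against.
DECLARE_int32(inner_op_parallelism);
DECLARE_int32(min_param_size_to_use_multithread);

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  // Equivalent to exporting FLAGS_inner_op_parallelism=4 etc. in the
  // environment before launching a Python program that imports paddle.fluid.
  FLAGS_inner_op_parallelism = 4;
  FLAGS_min_param_size_to_use_multithread = 100000;  // illustrative threshold
  // ... build and run the program; adam_op now takes the threaded path for
  // parameters whose numel exceeds the threshold.
  return 0;
}
```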