未验证 提交 95e90aa1 编写于 作者: 1 123malin 提交者: GitHub

test=develop, add communicator_is_sgd_optimizer flag (#20677)

* test=develop, communicator_is_sgd_optimizer flags
上级 74a28f5e
...@@ -89,6 +89,8 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -89,6 +89,8 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
VLOG(0) << "communicator_merge_sparse_grad: " VLOG(0) << "communicator_merge_sparse_grad: "
<< FLAGS_communicator_merge_sparse_grad; << FLAGS_communicator_merge_sparse_grad;
VLOG(0) << "communicator_is_sgd_optimizer: "
<< FLAGS_communicator_is_sgd_optimizer;
if (send_varname_to_ctx.size() == 0) { if (send_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be send, will not start send_thread"; VLOG(0) << "nothing need to be send, will not start send_thread";
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
...@@ -37,6 +38,8 @@ limitations under the License. */ ...@@ -37,6 +38,8 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
DECLARE_bool(communicator_is_sgd_optimizer);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
...@@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name, ...@@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
auto in = EigenVector<float>::Flatten(in_t); auto in = EigenVector<float>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in; result.device(*cpu_ctx.eigen_device()) = result + in;
} }
if (!FLAGS_communicator_is_sgd_optimizer) {
result.device(*cpu_ctx.eigen_device()) = result.device(*cpu_ctx.eigen_device()) =
result / static_cast<float>(vars.size()); result / static_cast<float>(vars.size());
}
} else if (var0->IsType<framework::SelectedRows>()) { } else if (var0->IsType<framework::SelectedRows>()) {
auto& slr0 = var0->Get<framework::SelectedRows>(); auto& slr0 = var0->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>(); auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
...@@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name, ...@@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name,
inputs.push_back(&var->Get<framework::SelectedRows>()); inputs.push_back(&var->Get<framework::SelectedRows>());
} }
auto dev_ctx = paddle::platform::CPUDeviceContext(); auto dev_ctx = paddle::platform::CPUDeviceContext();
if (FLAGS_communicator_is_sgd_optimizer) {
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
merge_add;
merge_add(dev_ctx, inputs, out_slr);
} else {
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float> math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
merge_average; merge_average;
merge_average(dev_ctx, inputs, out_slr); merge_average(dev_ctx, inputs, out_slr);
}
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims(); << " dims: " << slr0.value().dims();
} else { } else {
......
...@@ -42,7 +42,6 @@ TEST(communicator, merge_lod_tensors) { ...@@ -42,7 +42,6 @@ TEST(communicator, merge_lod_tensors) {
} }
out_value += static_cast<float>(i); out_value += static_cast<float>(i);
} }
out_value = out_value / 10.0;
const std::string out_name = "Out"; const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope; std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope()); scope.reset(new framework::Scope());
...@@ -96,7 +95,7 @@ TEST(communicator, merge_selected_rows) { ...@@ -96,7 +95,7 @@ TEST(communicator, merge_selected_rows) {
std::vector<float> out_values; std::vector<float> out_values;
out_values.reserve(10); out_values.reserve(10);
for (auto i = 0; i < 10; ++i) { for (auto i = 0; i < 10; ++i) {
out_values.push_back(static_cast<float>((i * (10 - i)) / 10.0)); out_values.push_back(static_cast<float>(i * (10 - i)));
} }
for (auto i = 0; i < out_slr.rows().size(); ++i) { for (auto i = 0; i < out_slr.rows().size(); ++i) {
ASSERT_EQ(out_slr.rows()[i], i); ASSERT_EQ(out_slr.rows()[i], i);
......
...@@ -139,6 +139,13 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -139,6 +139,13 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows(); auto &send_rows = send_slr.rows();
if (send_rows.size() == 0) {
LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
"may cause an unknown error. Please check the state of "
"use_double_buffer in pyreader async mode, you need to "
"turn it false.";
}
std::vector<std::vector<size_t>> outs_rows_idx; std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx; std::vector<std::vector<size_t>> outs_dense_idx;
......
...@@ -199,7 +199,9 @@ DEFINE_bool( ...@@ -199,7 +199,9 @@ DEFINE_bool(
*/ */
DEFINE_int32(communicator_max_merge_var_num, 20, DEFINE_int32(communicator_max_merge_var_num, 20,
"max var num to merge and send"); "max var num to merge and send");
/**
 * Distributed related FLAG
 * Name: FLAGS_communicator_is_sgd_optimizer
 * Values accepted: bool, default true
 * Note: Selects how MergeVars combines the gradients it is given for one
 *       variable. When true, the gradients are summed: dense tensors are
 *       accumulated and the division by vars.size() is skipped, and
 *       SelectedRows are merged with MergeAdd. When false, the gradients
 *       are averaged: the dense sum is divided by vars.size(), and
 *       SelectedRows are merged with MergeAverage.
 */
DEFINE_bool(communicator_is_sgd_optimizer, true,
"gradient sent to the server is the sum of the gradients "
"calculated by each thread if optimizer is sgd");
/** /**
* Distributed related FLAG * Distributed related FLAG
* Name: FLAGS_communicator_send_queue_size * Name: FLAGS_communicator_send_queue_size
......
...@@ -206,6 +206,7 @@ def __bootstrap__(): ...@@ -206,6 +206,7 @@ def __bootstrap__():
read_env_flags.append('communicator_fake_rpc') read_env_flags.append('communicator_fake_rpc')
read_env_flags.append('communicator_send_wait_times') read_env_flags.append('communicator_send_wait_times')
read_env_flags.append('communicator_merge_sparse_grad') read_env_flags.append('communicator_merge_sparse_grad')
read_env_flags.append('communicator_is_sgd_optimizer')
if core.is_compiled_with_brpc(): if core.is_compiled_with_brpc():
read_env_flags.append('max_body_size') read_env_flags.append('max_body_size')
#set brpc max body size #set brpc max body size
......
...@@ -113,7 +113,8 @@ class TestDistCTR2x2_ASYNC2(TestDistBase): ...@@ -113,7 +113,8 @@ class TestDistCTR2x2_ASYNC2(TestDistBase):
"FLAGS_communicator_send_queue_size": "2", "FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2", "FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2", "FLAGS_communicator_max_send_grad_num_before_recv": "2",
"FLAGS_communicator_independent_recv_thread": "0" "FLAGS_communicator_independent_recv_thread": "0",
"FLAGS_communicator_is_sgd_optimizer": "0"
} }
self.check_with_place( self.check_with_place(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册