Unverified commit 5baf1b23, authored by 123malin, committed by GitHub

test=develop, add communicator_is_sgd_optimizer flag (#20677) (#20734)

* test=develop, communicator_is_sgd_optimizer flags
Parent a7d0d888
@@ -89,6 +89,8 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
   VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
   VLOG(0) << "communicator_merge_sparse_grad: "
           << FLAGS_communicator_merge_sparse_grad;
+  VLOG(0) << "communicator_is_sgd_optimizer: "
+          << FLAGS_communicator_is_sgd_optimizer;
   if (send_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be send, will not start send_thread";
......
@@ -24,6 +24,7 @@ limitations under the License. */
 #include <unordered_set>
 #include <utility>
 #include <vector>
+#include "gflags/gflags.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable.h"
@@ -37,6 +38,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
+DECLARE_bool(communicator_is_sgd_optimizer);
 namespace paddle {
 namespace operators {
 namespace distributed {
@@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
       auto in = EigenVector<float>::Flatten(in_t);
       result.device(*cpu_ctx.eigen_device()) = result + in;
     }
+    if (!FLAGS_communicator_is_sgd_optimizer) {
+      result.device(*cpu_ctx.eigen_device()) =
+          result / static_cast<float>(vars.size());
+    }
   } else if (var0->IsType<framework::SelectedRows>()) {
     auto& slr0 = var0->Get<framework::SelectedRows>();
     auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
@@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name,
       inputs.push_back(&var->Get<framework::SelectedRows>());
     }
     auto dev_ctx = paddle::platform::CPUDeviceContext();
+    if (FLAGS_communicator_is_sgd_optimizer) {
+      math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
+          merge_add;
+      merge_add(dev_ctx, inputs, out_slr);
+    } else {
+      math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
+          merge_average;
+      merge_average(dev_ctx, inputs, out_slr);
+    }
     VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
             << " dims: " << slr0.value().dims();
   } else {
......
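The two branches above define the merge semantics behind the new flag: when `FLAGS_communicator_is_sgd_optimizer` is true (the default), per-thread gradients are summed before being sent to the pserver; otherwise they are averaged. A minimal sketch of those semantics, using numpy arrays in place of the dense LoDTensor gradients (illustrative, not Paddle code):

```python
import numpy as np

def merge_dense_grads(grads, is_sgd_optimizer=True):
    """Mirror the dense branch of MergeVars: always sum, and divide
    by the number of inputs only when the optimizer is not SGD."""
    result = np.sum(grads, axis=0)
    if not is_sgd_optimizer:
        result = result / len(grads)
    return result

grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
print(merge_dense_grads(grads))                          # sum:  [4. 6.]
print(merge_dense_grads(grads, is_sgd_optimizer=False))  # mean: [2. 3.]
```

For SelectedRows the same choice is delegated to MergeAdd versus MergeAverage, as the hunk above shows.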
@@ -42,7 +42,6 @@ TEST(communicator, merge_lod_tensors) {
     }
     out_value += static_cast<float>(i);
   }
-  out_value = out_value / 10.0;
   const std::string out_name = "Out";
   std::unique_ptr<framework::Scope> scope;
   scope.reset(new framework::Scope());
@@ -96,7 +95,7 @@ TEST(communicator, merge_selected_rows) {
   std::vector<float> out_values;
   out_values.reserve(10);
   for (auto i = 0; i < 10; ++i) {
-    out_values.push_back(static_cast<float>((i * (10 - i)) / 10.0));
+    out_values.push_back(static_cast<float>(i * (10 - i)));
   }
   for (auto i = 0; i < out_slr.rows().size(); ++i) {
     ASSERT_EQ(out_slr.rows()[i], i);
......
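The updated test expectations follow from the flag's default of true: merged inputs are now summed rather than averaged, so each expected value loses its `/ 10.0`. In the SelectedRows test, the expected entry for row i is `i * (10 - i)`, consistent with row i being contributed with value i by 10 - i of the inputs (an inference from the expected values, not spelled out in the hunk):

```python
# Hypothetical check of the test arithmetic: summing (10 - i) copies
# of the value i reproduces the new expected entry i * (10 - i).
for i in range(10):
    assert sum([float(i)] * (10 - i)) == i * (10 - i)
```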
@@ -139,6 +139,13 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
   auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
   auto &send_rows = send_slr.rows();
+  if (send_rows.size() == 0) {
+    LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
+                    "may cause an unknown error. Please check the state of "
+                    "use_double_buffer in pyreader async mode; you may need "
+                    "to set it to false.";
+  }
   std::vector<std::vector<size_t>> outs_rows_idx;
   std::vector<std::vector<size_t>> outs_dense_idx;
......
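The warning's suggested remedy is to turn off `use_double_buffer` when feeding data through py_reader in async mode. A hedged sketch of what that looks like on the Python side, with placeholder capacity, shapes, and dtypes (not taken from this patch):

```python
import paddle.fluid as fluid

# Disable double buffering so async-mode sends do not see empty variables.
reader = fluid.layers.py_reader(
    capacity=64,
    shapes=[[-1, 13], [-1, 1]],
    dtypes=['float32', 'int64'],
    use_double_buffer=False)
```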
@@ -199,7 +199,9 @@ DEFINE_bool(
  */
 DEFINE_int32(communicator_max_merge_var_num, 20,
              "max var num to merge and send");
+DEFINE_bool(communicator_is_sgd_optimizer, true,
+            "the gradient sent to the pserver is the sum of the gradients "
+            "computed by each thread when the optimizer is SGD");
 /**
  * Distributed related FLAG
  * Name: FLAGS_communicator_send_queue_size
......
@@ -205,6 +205,7 @@ def __bootstrap__():
         read_env_flags.append('communicator_fake_rpc')
         read_env_flags.append('communicator_send_wait_times')
         read_env_flags.append('communicator_merge_sparse_grad')
+        read_env_flags.append('communicator_is_sgd_optimizer')
         if core.is_compiled_with_brpc():
             read_env_flags.append('max_body_size')
             #set brpc max body size
......
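Because `__bootstrap__` reads these flags from the environment at import time, the new flag can be set as an environment variable before `paddle.fluid` is imported. A usage sketch (the surrounding training-script context is illustrative):

```python
import os

# Must be set before paddle.fluid is imported, since __bootstrap__
# reads environment flags during import.
os.environ['FLAGS_communicator_is_sgd_optimizer'] = '0'  # average merged grads

import paddle.fluid as fluid
```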
@@ -113,7 +113,8 @@ class TestDistCTR2x2_ASYNC2(TestDistBase):
             "FLAGS_communicator_send_queue_size": "2",
             "FLAGS_communicator_max_merge_var_num": "2",
             "FLAGS_communicator_max_send_grad_num_before_recv": "2",
-            "FLAGS_communicator_independent_recv_thread": "0"
+            "FLAGS_communicator_independent_recv_thread": "0",
+            "FLAGS_communicator_is_sgd_optimizer": "0"
         }
         self.check_with_place(
......