提交 315e44ac 编写于 作者: Y Yancey1989

add fetch_barrier_op

上级 b35ea1a4
...@@ -353,7 +353,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -353,7 +353,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} else { } else {
// Delete the local scopes created in operators. // Delete the local scopes created in operators.
scope->DropKids(); // scope->DropKids();
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------"; VLOG(2) << "-------------------------------------------------------";
......
...@@ -199,11 +199,13 @@ if(WITH_DISTRIBUTE) ...@@ -199,11 +199,13 @@ if(WITH_DISTRIBUTE)
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
else() else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op)
endif() endif()
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
namespace paddle {
namespace operators {
class FetchBarrierOp : public framework::OperatorBase {
public:
FetchBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
};
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class FetchBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
ops::FetchBarrierOpVarTypeInference,
ops::FetchBarrierOpShapeInference);
...@@ -315,12 +315,22 @@ class DistributeTranspiler: ...@@ -315,12 +315,22 @@ class DistributeTranspiler:
# step 3.1: insert send op to send gradient vars to parameter servers # step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset() ps_dispatcher.reset()
send_vars = [] send_vars = []
for varname, splited_vars in grad_var_mapping.items(): for orig_varname, splited_vars in grad_var_mapping.items():
index = find_op_by_output_arg(program.global_block(), varname)
eplist = ps_dispatcher.dispatch(splited_vars) eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) > 1: if len(splited_vars) == 1:
self._insert_split_op(program, varname, splited_vars) orig_varname = splited_vars[0].name
index = find_op_by_output_arg(program.global_block(),
orig_varname)
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[orig_varname]
index = find_op_by_output_arg(program.global_block(),
orig_varname)
self._insert_split_op(program, orig_var, index, splited_vars)
index += 1 index += 1
else:
AssertionError("Can not insert the send op by original "
"variable name :", orig_varname)
program.global_block().insert_op( program.global_block().insert_op(
index=index + 1, index=index + 1,
type="send_vars", type="send_vars",
...@@ -351,6 +361,12 @@ class DistributeTranspiler: ...@@ -351,6 +361,12 @@ class DistributeTranspiler:
"RPCClient": rpc_client_var}, "RPCClient": rpc_client_var},
attrs={"epmap": eplist}) attrs={"epmap": eplist})
program.global_block().append_op(
type="fetch_barrier",
inputs={},
outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints})
for i, ep in enumerate(eplist): for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
...@@ -859,9 +875,7 @@ class DistributeTranspiler: ...@@ -859,9 +875,7 @@ class DistributeTranspiler:
lod_level=var.lod_level, lod_level=var.lod_level,
persistable=persistable) persistable=persistable)
def _insert_split_op(self, program, orig_varname, splited_vars): def _insert_split_op(self, program, orig_var, index, splited_vars):
orig_var = program.global_block().vars[orig_varname]
index = find_op_by_output_arg(program.global_block(), orig_varname)
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = [] height_sections = []
for v in splited_vars: for v in splited_vars:
...@@ -887,45 +901,6 @@ class DistributeTranspiler: ...@@ -887,45 +901,6 @@ class DistributeTranspiler:
AssertionError("Variable type should be in set " AssertionError("Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]") "[LOD_TENSOR, SELECTED_ROWS]")
def _append_split_op(self, program, gradblocks):
# Split variables that need to be split and append respective ops
add_suffix = False
if self.trainer_num > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=add_suffix)
for varname, splited_vars in var_mapping.iteritems():
# variable that don't need to split have empty splited_vars
if len(splited_vars) <= 1:
continue
orig_var = program.global_block().vars[varname]
index = find_op_by_output_arg(program.global_block(), orig_var.name)
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = []
for v in splited_vars:
height_sections.append(v.shape[0])
program.global_block().insert_op(
index=index + 1,
type="split_selected_rows",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"height_sections": height_sections})
elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = []
for v in splited_vars:
sections.append(v.shape[0])
program.global_block().insert_op(
index=index + 1,
type="split_byref",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"sections": sections} # assume split evenly
)
else:
AssertionError("Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]")
return var_mapping
def _get_optimizer_input_shape(self, op_type, varkey, orig_shape, def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
param_shape): param_shape):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册