未验证 提交 4c55a602 编写于 作者: Q Qiao Longfei 提交者: GitHub

Dist transpiler support prefetch (#9714)

* init

* add some check

* add dist transpile logic

* add insert op for block

* init change get_pserver_program

* optimize code

* fix a bug

* can run now

* start to do table split

* start to process table gradient

* complete pserver part

* can send_vars now

* revert cpplint

* fix a bug

* optimize code

* move dist test to models

* revert the interface of distribute_transpiler.transpile

* fix prefetch_block

* optimize trainspiler code

* add comment to sum_op

* add warning log

* fix comment

* fix test_send_recv

* fix test_send_recv

* fix train with no distributed table

* optimize GetDims
上级 ad73b331
......@@ -92,7 +92,7 @@ class BlockDesc {
/*
* Remove Op and its input/output variables.
* Note that for either input or ouput variable, if it is also an input or
* Note that for either input or output variable, if it is also an input or
* output variable of other ops, we should remain it.
*/
void RemoveOp(size_t s, size_t e);
......
......@@ -46,7 +46,8 @@ proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
}
}
static DDim GetDims(const Scope& scope, const std::string& name) {
static DDim GetDims(const Scope& scope, const std::string& name,
bool get_actual_dim = false) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return DDim({-1});
......@@ -55,7 +56,11 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
if (get_actual_dim) {
return var->Get<SelectedRows>().value().dims();
} else {
return var->Get<SelectedRows>().GetCompleteDims();
}
} else {
return DDim({-1});
}
......@@ -129,7 +134,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (scope) {
ss << "[" << GetDims(*scope, input.second[i]) << "]";
ss << "[" << GetDims(*scope, input.second[i], true) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")";
}
if (i != input.second.size() - 1) {
......@@ -149,7 +154,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (scope) {
ss << "[" << GetDims(*scope, output.second[i]) << "]";
ss << "[" << GetDims(*scope, output.second[i], true) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")";
}
if (i != output.second.size() - 1) {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <string>
#include <vector>
......@@ -34,7 +35,10 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
if (n == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory";
}
auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size();
......
......@@ -161,6 +161,7 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
......
......@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <thread>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/listen_and_serv_op.h"
......@@ -88,8 +89,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
......@@ -97,18 +99,25 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Executor executor(dev_place);
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
if (blkid != prefetch_block->ID()) {
block_list.push_back(blkid);
}
}
auto prepared = executor.Prepare(*program, block_list);
auto optimize_prepared = executor.Prepare(*program, block_list);
// Insert placeholder for block0 which holds current op itself.
prepared.insert(prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......@@ -166,16 +175,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
parallel_blkids.push_back(1);
double ts = detail::GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
&recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
if (blkid != prefetch_block->ID()) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
parallel_blkids.push_back(blkid);
}
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
&recv_scope);
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
......@@ -211,6 +222,8 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -27,6 +28,7 @@ namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
......
......@@ -78,6 +78,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future>
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
......@@ -50,8 +50,8 @@ class PrefetchOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << "to get "
<< outs[i] << "back";
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i],
outs[i]);
} else {
......@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out",
"(SelectedRows) result "
"(LoDTensor) result "
"to be fetched from parameter server")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -37,11 +37,11 @@ namespace m = paddle::operators::math;
std::unique_ptr<f::OperatorBase> listen_and_serv_op;
int selected_port;
void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place);
for (int i = 0; i < 2; ++i) {
auto var_name = paddle::string::Sprintf("x%d", i);
auto var = scope.Var(var_name);
auto var = scope->Var(var_name);
auto tensor = var->GetMutable<f::LoDTensor>();
tensor->Resize({10, 10});
float *expect = tensor->mutable_data<float>(place);
......@@ -50,20 +50,20 @@ void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
}
}
auto out_var = scope.Var("Out");
auto out_var = scope->Var("Out");
auto out_tensor = out_var->GetMutable<f::LoDTensor>();
out_tensor->Resize({10, 10});
out_tensor->mutable_data<float>(place); // allocate
}
void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
void InitSelectedRowsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place);
int64_t height = 10;
int64_t row_numel = 10;
m::SetConstant<p::CPUDeviceContext, float> set_one;
// init x0
std::vector<int64_t> rows0{0, 4, 7};
auto x0_var = scope.Var("x0");
auto x0_var = scope->Var("x0");
auto x0 = x0_var->GetMutable<f::SelectedRows>();
x0->set_rows(rows0);
x0->set_height(height);
......@@ -74,7 +74,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
// init x1
std::vector<int64_t> rows1{2, 9};
auto x1_var = scope.Var("x1");
auto x1_var = scope->Var("x1");
auto x1 = x1_var->GetMutable<f::SelectedRows>();
x1->set_rows(rows1);
x1->set_height(height);
......@@ -83,7 +83,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
f::make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), place);
set_one(ctx, x1_value, 1.0);
auto out_var = scope.Var("Out");
auto out_var = scope->Var("Out");
auto out = out_var->GetMutable<f::SelectedRows>();
auto out_value = out->mutable_value();
out->set_height(height);
......@@ -117,15 +117,16 @@ void StartServerNet(bool is_sparse) {
f::Scope scope;
p::CPUPlace place;
if (is_sparse) {
InitSelectedRowsInScope(scope, place);
InitSelectedRowsInScope(place, &scope);
} else {
InitTensorsInScope(scope, place);
InitTensorsInScope(place, &scope);
}
// sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program;
const auto &root_block = program.Block(0);
auto *optimize_block = program.AppendBlock(root_block);
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensers, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
......@@ -135,6 +136,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
LOG(INFO) << "selected port before run " << selected_port;
......@@ -148,7 +150,7 @@ TEST(SendRecvOp, CPUDense) {
// local net
f::Scope scope;
p::CPUPlace place;
InitTensorsInScope(scope, place);
InitTensorsInScope(place, &scope);
// create rpc client var
scope.Var("RPC_CLIENT_VAR");
......@@ -191,7 +193,7 @@ TEST(SendRecvOp, CPUSparse) {
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(scope, place);
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>(
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future>
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
......@@ -36,7 +36,7 @@ class SendVarsOp : public framework::OperatorBase {
auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_sent");
int sync_send = Attr<int>("sync_send");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......
......@@ -35,8 +35,8 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
// TODO(qijun): check dimensions of Param and Grad at complie
// and run time.
// TODO(qijun): check dimensions of Param and Grad at compile
// and runtime.
ctx->SetOutputDim("ParamOut", param_dim);
}
......
......@@ -48,11 +48,11 @@ class SplitIdsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);
auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
}
};
......@@ -60,8 +60,9 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR);
block->Var(out_var)->SetType(input_var->GetType());
}
}
};
......@@ -73,4 +74,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>);
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
......@@ -24,35 +24,63 @@ namespace operators {
template <typename DeviceContext, typename T>
class SplitIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("SplitIds do not support GPU kernel");
}
auto& ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T* ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
const auto *ids_var = ctx.InputVar("Ids");
if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto *shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto *shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
}
}
} else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), "");
const T *ids = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
// get rows for outputs
for (auto &id : ids_rows) {
size_t shard_id = static_cast<size_t>(id) % shard_num;
outs[shard_id]->mutable_rows()->push_back(id);
}
// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto* shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto* shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
int64_t row_width = ids_dims[1];
for (auto &out : outs) {
out->set_height(ids_selected_rows->height());
framework::DDim ddim = framework::make_ddim(
{static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (size_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
row_width * sizeof(T));
}
}
}
}
......
......@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
......@@ -37,7 +39,10 @@ class SumOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
if (N == 1) {
VLOG(3) << "Warning: sum have only one input, may waste memory";
}
framework::DDim in_dim({0});
for (auto& x_dim : x_dims) {
......
......@@ -13,14 +13,17 @@
# limitations under the License.
from __future__ import print_function
import framework
from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer
from layer_helper import LayerHelper
import distributed_splitter as splitter
import math
import distributed_splitter as splitter
import framework
from framework import Program, default_main_program, Variable
from . import core
import debuger
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
class VarBlock:
......@@ -35,9 +38,9 @@ class VarBlock:
class UnionFind(object):
""" Union-find data struct.
""" Union-find data structure.
Union-find is a data struct that keeps track of a set of elements partitioned
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
......@@ -185,19 +188,66 @@ class DistributeTranspiler:
assert (callable(split_method))
if program is None:
program = default_main_program()
self.program = program
self.trainers = trainers
self.origin_program = program
self.trainer_num = trainers
self.optimize_ops = optimize_ops
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = []
# support only one distributed_lookup_table now
self.table_name = None
for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attrs['is_distributed'] is True:
if self.table_name is None:
self.table_name = op.input("W")[0]
if self.table_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
distributed_lookup_table_ops.append(op)
else:
if self.table_name is not None:
assert op.input("W")[0] != self.table_name
self.has_distributed_lookup_table = len(
distributed_lookup_table_ops) > 0
# step1: For large parameters and gradients, split them into smaller
# blocks.
param_list = [pg[0] for pg in params_grads]
grad_list = [pg[1] for pg in params_grads]
if self.has_distributed_lookup_table:
param_list = [
param for param in param_list if param.name != self.table_name
]
grad_list = [
grad for grad in grad_list
if grad.name != framework.grad_var_name(self.table_name)
]
self.table_param_grad = [
param_grad for param_grad in params_grads
if param_grad[0].name == self.table_name
][0]
table_grad_var = self.table_param_grad[1]
self.table_grad_list = [
program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints))
]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
# step2: Create new vars for the parameters and gradients blocks and
......@@ -229,7 +279,7 @@ class DistributeTranspiler:
self.param_grad_ep_mapping[ep]["grads"].append(grad)
rpc_client_var = program.global_block().create_var(
name="RPC_CLIENT_VAR",
name=RPC_CLIENT_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
......@@ -252,13 +302,19 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]},
attrs={"axis": 0})
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist)
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints)
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
self.program.sync_with_cpp()
self.origin_program.global_block().delete_ops(self.optimize_ops)
self.origin_program.sync_with_cpp()
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program
self.origin_program.__str__()
return self.origin_program
def get_pserver_program(self, endpoint):
"""
......@@ -294,8 +350,8 @@ class DistributeTranspiler:
type=v.type,
dtype=v.dtype,
shape=v.shape)
if self.trainers > 1:
for trainer_id in xrange(self.trainers):
if self.trainer_num > 1:
for trainer_id in xrange(self.trainer_num):
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
persistable=False,
......@@ -309,7 +365,7 @@ class DistributeTranspiler:
# step3
optimize_block = pserver_program.create_block(0)
# step 4
# Create a union-find data struct from optimize ops,
# Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops
# into one set.
ufind = self._create_ufind(self.optimize_ops)
......@@ -384,6 +440,23 @@ class DistributeTranspiler:
# __append_optimize_op__(glb_op, optimize_block)
# break
# process distributed lookup_table
prefetch_block = None
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
self._create_table_optimize_block(pserver_index, pserver_program,
append_block)
prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, optimize_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
assert prefetch_block is not None
else:
assert prefetch_block is None
prefetch_block = pserver_program.global_block()
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
......@@ -392,8 +465,10 @@ class DistributeTranspiler:
attrs={
"OptimizeBlock": optimize_block,
"endpoint": endpoint,
"Fanin": self.trainers
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block
})
pserver_program.sync_with_cpp()
return pserver_program
......@@ -451,6 +526,197 @@ class DistributeTranspiler:
attrs=op.attrs)
return s_prog
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None
self.prefetch_output_vars = None
continue_search_lookup_table_op = True
while continue_search_lookup_table_op:
continue_search_lookup_table_op = False
all_ops = program.global_block().ops
for op in all_ops:
if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True
op_index = list(all_ops).index(op)
ids_name = op.input("Ids")
out_name = op.output("Out")
if self.prefetch_input_vars is None:
ids_var = program.global_block().vars[ids_name[0]]
self.prefetch_input_vars = self.create_splited_vars(
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
if self.prefetch_output_vars is None:
out_var = program.global_block().vars[out_name[0]]
self.prefetch_output_vars = self.create_splited_vars(
source_var=out_var,
block=program.global_block(),
tag="_prefetch_out_")
# insert split_ids_op
program.global_block().insert_op(
index=op_index,
type="split_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
]
},
outputs={"Out": self.prefetch_input_vars})
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={
"Out": self.prefetch_output_vars,
"RPCClient": rpc_client_var
},
attrs={"epmap": eplist})
# insert concat_op
program.global_block().insert_op(
index=op_index + 2,
type="concat",
inputs={'X': self.prefetch_output_vars},
outputs={
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
},
attrs={"axis": 0})
# delete lookup_table_op
program.global_block().delete_ops([op])
program.sync_with_cpp()
# break for loop
break
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
table_grad_name = framework.grad_var_name(self.table_name)
for op in all_ops:
if table_grad_name in op.output_arg_names:
op_index = list(all_ops).index(op)
# insert split_ids_op
program.global_block().insert_op(
index=op_index + 1,
type="split_ids",
inputs={
'Ids': [program.global_block().vars[table_grad_name]]
},
outputs={"Out": self.table_grad_list})
program.global_block().insert_op(
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True,
"epmap": pserver_endpoints})
break
def _create_prefetch_block(self, pserver_index, pserver_program,
optimize_block):
# STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name]
prefetch_block = pserver_program.create_block(optimize_block.idx)
trainer_ids = self.prefetch_input_vars[pserver_index]
pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name,
type=trainer_ids.type,
shape=trainer_ids.shape,
dtype=trainer_ids.dtype)
trainer_out = self.prefetch_output_vars[pserver_index]
pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name,
type=trainer_out.type,
shape=trainer_out.shape,
dtype=trainer_out.dtype)
prefetch_block.append_op(
type=LOOKUP_TABLE_TYPE,
inputs={'Ids': pserver_ids,
"W": table_var},
outputs={"Out": pserver_out},
attrs={
"is_sparse": True, # has no effect on lookup_table op
"is_distributed": True,
"padding_idx": -1
})
return prefetch_block
def _create_table_optimize_block(self, pserver_index, pserver_program,
append_block):
def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=persistable)
# STEP: create table optimize block
# create table param and grad var in pserver program
param_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[self.table_name])
grad_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[framework.grad_var_name(
self.table_name)],
persistable=False)
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype) for index in range(self.trainer_num)
]
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name
][0]
table_opt_block = pserver_program.create_block(append_block.idx)
# only support sgd now
assert table_opt_op.type == "sgd"
# append sum op for table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
outputs={"Out": [grad_var]})
lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]]
inputs = {
"Param": [param_var],
"Grad": [grad_var],
"LearningRate": [lr_var]
}
outputs = {"ParamOut": [param_var]}
table_opt_block.append_op(
type=table_opt_op.type,
inputs=inputs,
outputs=outputs,
attrs=table_opt_op.attrs)
# ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self,
program,
......@@ -512,7 +778,17 @@ class DistributeTranspiler:
program.global_block().sync_with_cpp()
return var_mapping
def _clone_var(self, block, var):
def create_splited_vars(self, source_var, block, tag):
return [
block.create_var(
name=str(source_var.name + tag + str(index)),
type=source_var.type,
shape=source_var.shape,
dtype=source_var.dtype)
for index in range(len(self.pserver_endpoints))
]
def _clone_var(self, block, var, persistable=True):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
......@@ -520,12 +796,12 @@ class DistributeTranspiler:
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
persistable=persistable)
def _append_split_op(self, program, gradblocks):
# Split variables that need to be split and append respective ops
add_suffix = False
if self.trainers > 1:
if self.trainer_num > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=add_suffix)
......@@ -616,9 +892,9 @@ class DistributeTranspiler:
return
merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)]
if self.trainers > 1:
if self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainers):
for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
(self._orig_varname(grad_block.name), i)
vars2merge.append(pserver_block.vars[per_trainer_name])
......@@ -633,7 +909,7 @@ class DistributeTranspiler:
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainers)})
attrs={"scale": 1.0 / float(self.trainer_num)})
new_inputs[key] = merged_var
elif key == "Param":
# param is already created on global program
......@@ -669,7 +945,7 @@ class DistributeTranspiler:
new_shape = None
if key in ["Param", "Grad", "LearningRate"]:
continue
var = self.program.global_block().vars[opt_op.input(key)[0]]
var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
......@@ -682,8 +958,8 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar
# change output's ParamOut variable
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op(
......@@ -695,8 +971,8 @@ class DistributeTranspiler:
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(self.program.global_block().vars,
opt_op)
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op)
for varlist in inputs.itervalues():
if not isinstance(varlist, list):
varlist = [varlist]
......@@ -709,8 +985,8 @@ class DistributeTranspiler:
dtype=var.dtype,
shape=var.shape)
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for varlist in outputs.itervalues():
if not isinstance(varlist, list):
......@@ -783,7 +1059,6 @@ class DistributeTranspiler:
if same_or_split_var(n, param) and n != param:
return True
return False
return False
def _get_input_map_from_op(self, varmap, op):
"""Returns a dict from op input name to the vars in varmap."""
......@@ -821,7 +1096,7 @@ class DistributeTranspiler:
find_ops = []
# find ops which output is lr var
block = self.program.global_block()
block = self.origin_program.global_block()
for op in block.ops:
if set(op.output_arg_names) & lr_vars:
find_ops.append(op)
......
......@@ -218,6 +218,7 @@ def fc(input,
def embedding(input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
......@@ -268,8 +269,11 @@ def embedding(input,
inputs={'Ids': input,
'W': w},
outputs={'Out': tmp},
attrs={'is_sparse': is_sparse,
'padding_idx': padding_idx})
attrs={
'is_sparse': is_sparse,
'is_distributed': is_distributed,
'padding_idx': padding_idx
})
return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册