Unverified commit 4c55a602 authored by Qiao Longfei, committed by GitHub

Dist transpiler support prefetch (#9714)

* init

* add some check

* add dist transpile logic

* add insert op for block

* init change get_pserver_program

* optimize code

* fix a bug

* can run now

* start to do table split

* start to process table gradient

* complete pserver part

* can send_vars now

* revert cpplint

* fix a bug

* optimize code

* move dist test to models

* revert the interface of distribute_transpiler.transpile

* fix prefetch_block

* optimize trainspiler code

* add comment to sum_op

* add warning log

* fix comment

* fix test_send_recv

* fix test_send_recv

* fix train with no distributed table

* optimize GetDims
Parent ad73b331
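This change lets the distribute transpiler serve a large embedding table from the parameter servers through prefetch ops instead of replicating the whole table on every trainer. The sketch below is a hedged, minimal trainer-side entry point assuming the Fluid Python API of this release; the vocabulary size, endpoints, and the keyword order of transpile() are placeholders rather than values taken from this diff.

    import paddle.fluid as fluid

    ids = fluid.layers.data(name="ids", shape=[1], dtype="int64", lod_level=1)
    # is_distributed=True is the new switch: the transpiler rewrites this lookup
    # into split_ids -> prefetch -> concat and routes the table gradient through
    # split_ids -> send_vars instead of sending the whole table.
    emb = fluid.layers.embedding(
        input=ids, size=[10000000, 64], is_sparse=True, is_distributed=True)
    pooled = fluid.layers.sequence_pool(input=emb, pool_type="sum")
    loss = fluid.layers.mean(x=pooled)
    optimize_ops, params_grads = fluid.optimizer.SGD(
        learning_rate=0.01).minimize(loss)

    t = fluid.DistributeTranspiler()
    t.transpile(optimize_ops, params_grads, trainer_id=0,
                pservers="127.0.0.1:6174,127.0.0.1:6175", trainers=2)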
...@@ -92,7 +92,7 @@ class BlockDesc { ...@@ -92,7 +92,7 @@ class BlockDesc {
/* /*
* Remove Op and its input/output variables. * Remove Op and its input/output variables.
* Note that for either input or ouput variable, if it is also an input or * Note that for either input or output variable, if it is also an input or
* output variable of other ops, we should remain it. * output variable of other ops, we should remain it.
*/ */
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
......
...@@ -46,7 +46,8 @@ proto::VarType::Type GetDataTypeOfVar(const Variable* var) { ...@@ -46,7 +46,8 @@ proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
} }
} }
static DDim GetDims(const Scope& scope, const std::string& name) { static DDim GetDims(const Scope& scope, const std::string& name,
bool get_actual_dim = false) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
return DDim({-1}); return DDim({-1});
...@@ -55,7 +56,11 @@ static DDim GetDims(const Scope& scope, const std::string& name) { ...@@ -55,7 +56,11 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
if (get_actual_dim) {
return var->Get<SelectedRows>().value().dims();
} else {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<SelectedRows>().GetCompleteDims();
}
} else { } else {
return DDim({-1}); return DDim({-1});
} }
...@@ -129,7 +134,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -129,7 +134,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) { for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i]; ss << input.second[i];
if (scope) { if (scope) {
ss << "[" << GetDims(*scope, input.second[i]) << "]"; ss << "[" << GetDims(*scope, input.second[i], true) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")"; ss << "(" << GetLoD(*scope, input.second[i]) << ")";
} }
if (i != input.second.size() - 1) { if (i != input.second.size() - 1) {
...@@ -149,7 +154,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -149,7 +154,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) { for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i]; ss << output.second[i];
if (scope) { if (scope) {
ss << "[" << GetDims(*scope, output.second[i]) << "]"; ss << "[" << GetDims(*scope, output.second[i], true) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")"; ss << "(" << GetLoD(*scope, output.second[i]) << ")";
} }
if (i != output.second.size() - 1) { if (i != output.second.size() - 1) {
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -34,7 +35,10 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -34,7 +35,10 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis")); size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
const size_t n = ins.size(); const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1."); PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
if (n == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory";
}
auto out_dims = ins[0]; auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size(); size_t in_zero_dims_size = out_dims.size();
......
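The relaxed check matters for the prefetch path: when there is only one parameter server shard, the transpiler still stitches the prefetch outputs together with a one-input concat. A minimal sketch, assuming the Fluid Python API of this release:

    import paddle.fluid as fluid

    x = fluid.layers.data(name="x", shape=[4], dtype="float32")
    # A one-element input list used to trip "Input tensors count should > 1";
    # it is now accepted and only triggers the VLOG(3) warning above.
    out = fluid.layers.concat(input=[x], axis=0)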
...@@ -161,6 +161,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -161,6 +161,7 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply; ::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname(); std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name); auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope(); framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name); auto* var = local_scope->FindVar(var_name);
......
...@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <ostream> #include <ostream>
#include <thread> #include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
...@@ -88,8 +89,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -88,8 +89,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto ins = Inputs("X"); auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program(); auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
...@@ -97,18 +99,25 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -97,18 +99,25 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
std::vector<int> block_list; std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
if (blkid != prefetch_block->ID()) {
block_list.push_back(blkid); block_list.push_back(blkid);
} }
auto prepared = executor.Prepare(*program, block_list); }
auto optimize_prepared = executor.Prepare(*program, block_list);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
prepared.insert(prepared.begin(), optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr)); std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->SetScope(&recv_scope); rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx); rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update // TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor); rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0); VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program); rpc_service_->SetProgram(program);
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
...@@ -166,16 +175,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -166,16 +175,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
parallel_blkids.push_back(1); parallel_blkids.push_back(1);
double ts = detail::GetTimestamp(); double ts = detail::GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) { for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (blkid != prefetch_block->ID()) {
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program, ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
&recv_scope); program, &recv_scope);
parallel_blkids.clear(); parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent(); last_parent_blkid = program->Block(blkid).Parent();
} }
parallel_blkids.push_back(blkid); parallel_blkids.push_back(blkid);
} }
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program, }
&recv_scope); ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
...@@ -211,6 +222,8 @@ from send_op and send back variables to recv_op. ...@@ -211,6 +222,8 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side."); "BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.") AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1); .SetDefault(1);
} }
......
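For reference, a hedged sketch of how a hand-built parameter-server program would attach both block attributes, mirroring what the transpiler and the send/recv test below do; the endpoint and the empty input list are placeholders and may need real gradient variables in practice:

    import paddle.fluid as fluid

    pserver_prog = fluid.Program()
    root = pserver_prog.global_block()
    optimize_block = pserver_prog.create_block(root.idx)
    prefetch_block = pserver_prog.create_block(root.idx)
    root.append_op(
        type="listen_and_serv",
        inputs={"X": []},
        outputs={},
        attrs={
            "endpoint": "127.0.0.1:6174",
            "Fanin": 1,
            "OptimizeBlock": optimize_block,   # existing attribute
            "PrefetchBlock": prefetch_block,   # required by this commit
        })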
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <ostream> #include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -27,6 +28,7 @@ namespace paddle { ...@@ -27,6 +28,7 @@ namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service); void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
......
...@@ -78,6 +78,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -78,6 +78,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) " "(boolean, default false) "
"Sparse update.") "Sparse update.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx", AddAttr<int64_t>("padding_idx",
"(int64, default -1) " "(int64, default -1) "
"If the value is -1, it makes no effect to lookup. " "If the value is -1, it makes no effect to lookup. "
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <future> #include <future> // NOLINT
#include <ostream> #include <ostream>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -50,8 +50,8 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -50,8 +50,8 @@ class PrefetchOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << "to get " VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << "back"; << outs[i] << " back";
rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i], rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i],
outs[i]); outs[i]);
} else { } else {
...@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(RPCClient) The RPC client object which will be" "(RPCClient) The RPC client object which will be"
"initialized at most once."); "initialized at most once.");
AddOutput("Out", AddOutput("Out",
"(SelectedRows) result " "(LoDTensor) result "
"to be fetched from parameter server") "to be fetched from parameter server")
.AsDuplicable(); .AsDuplicable();
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
......
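To show how prefetch is driven from a trainer program, here is a hedged sketch of the op sequence that replaces a distributed lookup_table; it mirrors the transpiler changes later in this diff, and the helper name and variables are hypothetical:

    def insert_prefetch_ops(block, index, ids_var, prefetch_in_vars,
                            prefetch_out_vars, rpc_client_var, out_var, eplist):
        # shard the ids across parameter servers
        block.insert_op(index=index, type="split_ids",
                        inputs={"Ids": [ids_var]},
                        outputs={"Out": prefetch_in_vars})
        # fetch the matching embedding rows from each pserver
        block.insert_op(index=index + 1, type="prefetch",
                        inputs={"X": prefetch_in_vars},
                        outputs={"Out": prefetch_out_vars,
                                 "RPCClient": rpc_client_var},
                        attrs={"epmap": eplist})
        # stitch the per-pserver results back into one LoDTensor
        block.insert_op(index=index + 2, type="concat",
                        inputs={"X": prefetch_out_vars},
                        outputs={"Out": [out_var]},
                        attrs={"axis": 0})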
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#include <string> #include <string>
#include <thread> #include <thread> // NOLINT
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -37,11 +37,11 @@ namespace m = paddle::operators::math; ...@@ -37,11 +37,11 @@ namespace m = paddle::operators::math;
std::unique_ptr<f::OperatorBase> listen_and_serv_op; std::unique_ptr<f::OperatorBase> listen_and_serv_op;
int selected_port; int selected_port;
void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place); p::CPUDeviceContext ctx(place);
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
auto var_name = paddle::string::Sprintf("x%d", i); auto var_name = paddle::string::Sprintf("x%d", i);
auto var = scope.Var(var_name); auto var = scope->Var(var_name);
auto tensor = var->GetMutable<f::LoDTensor>(); auto tensor = var->GetMutable<f::LoDTensor>();
tensor->Resize({10, 10}); tensor->Resize({10, 10});
float *expect = tensor->mutable_data<float>(place); float *expect = tensor->mutable_data<float>(place);
...@@ -50,20 +50,20 @@ void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { ...@@ -50,20 +50,20 @@ void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
} }
} }
auto out_var = scope.Var("Out"); auto out_var = scope->Var("Out");
auto out_tensor = out_var->GetMutable<f::LoDTensor>(); auto out_tensor = out_var->GetMutable<f::LoDTensor>();
out_tensor->Resize({10, 10}); out_tensor->Resize({10, 10});
out_tensor->mutable_data<float>(place); // allocate out_tensor->mutable_data<float>(place); // allocate
} }
void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) { void InitSelectedRowsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place); p::CPUDeviceContext ctx(place);
int64_t height = 10; int64_t height = 10;
int64_t row_numel = 10; int64_t row_numel = 10;
m::SetConstant<p::CPUDeviceContext, float> set_one; m::SetConstant<p::CPUDeviceContext, float> set_one;
// init x0 // init x0
std::vector<int64_t> rows0{0, 4, 7}; std::vector<int64_t> rows0{0, 4, 7};
auto x0_var = scope.Var("x0"); auto x0_var = scope->Var("x0");
auto x0 = x0_var->GetMutable<f::SelectedRows>(); auto x0 = x0_var->GetMutable<f::SelectedRows>();
x0->set_rows(rows0); x0->set_rows(rows0);
x0->set_height(height); x0->set_height(height);
...@@ -74,7 +74,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) { ...@@ -74,7 +74,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
// init x1 // init x1
std::vector<int64_t> rows1{2, 9}; std::vector<int64_t> rows1{2, 9};
auto x1_var = scope.Var("x1"); auto x1_var = scope->Var("x1");
auto x1 = x1_var->GetMutable<f::SelectedRows>(); auto x1 = x1_var->GetMutable<f::SelectedRows>();
x1->set_rows(rows1); x1->set_rows(rows1);
x1->set_height(height); x1->set_height(height);
...@@ -83,7 +83,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) { ...@@ -83,7 +83,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
f::make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), place); f::make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), place);
set_one(ctx, x1_value, 1.0); set_one(ctx, x1_value, 1.0);
auto out_var = scope.Var("Out"); auto out_var = scope->Var("Out");
auto out = out_var->GetMutable<f::SelectedRows>(); auto out = out_var->GetMutable<f::SelectedRows>();
auto out_value = out->mutable_value(); auto out_value = out->mutable_value();
out->set_height(height); out->set_height(height);
...@@ -117,15 +117,16 @@ void StartServerNet(bool is_sparse) { ...@@ -117,15 +117,16 @@ void StartServerNet(bool is_sparse) {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
if (is_sparse) { if (is_sparse) {
InitSelectedRowsInScope(scope, place); InitSelectedRowsInScope(place, &scope);
} else { } else {
InitTensorsInScope(scope, place); InitTensorsInScope(place, &scope);
} }
// sub program run in listen_and_serv_op, for simple test we use sum // sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program; f::ProgramDesc program;
const auto &root_block = program.Block(0); const auto &root_block = program.Block(0);
auto *optimize_block = program.AppendBlock(root_block); auto *optimize_block = program.AppendBlock(root_block);
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensers, must be of same shape. // X for server side tensors, RX for received tensers, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
...@@ -135,6 +136,7 @@ void StartServerNet(bool is_sparse) { ...@@ -135,6 +136,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"ParamList", std::vector<std::string>({"Out"})}); attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
LOG(INFO) << "selected port before run " << selected_port; LOG(INFO) << "selected port before run " << selected_port;
...@@ -148,7 +150,7 @@ TEST(SendRecvOp, CPUDense) { ...@@ -148,7 +150,7 @@ TEST(SendRecvOp, CPUDense) {
// local net // local net
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
InitTensorsInScope(scope, place); InitTensorsInScope(place, &scope);
// create rpc client var // create rpc client var
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
...@@ -191,7 +193,7 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -191,7 +193,7 @@ TEST(SendRecvOp, CPUSparse) {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
p::CPUDeviceContext ctx(place); p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(scope, place); InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>( selected_port = static_cast<paddle::operators::ListenAndServOp *>(
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <future> #include <future> // NOLINT
#include <ostream> #include <ostream>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -36,7 +36,7 @@ class SendVarsOp : public framework::OperatorBase { ...@@ -36,7 +36,7 @@ class SendVarsOp : public framework::OperatorBase {
auto ins = Inputs("X"); auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_sent"); int sync_send = Attr<int>("sync_send");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
......
...@@ -35,8 +35,8 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -35,8 +35,8 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element"); "Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param"); auto param_dim = ctx->GetInputDim("Param");
// TODO(qijun): check dimensions of Param and Grad at complie // TODO(qijun): check dimensions of Param and Grad at compile
// and run time. // and runtime.
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
......
...@@ -48,20 +48,21 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -48,20 +48,21 @@ class SplitIdsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1); PADDLE_ENFORCE_EQ(ids_dims[1], 1);
} }
}
}; };
class SplitIdsOpInferVarType : public framework::VarTypeInference { class SplitIdsOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) { for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR); block->Var(out_var)->SetType(input_var->GetType());
} }
} }
}; };
...@@ -73,4 +74,5 @@ namespace ops = paddle::operators; ...@@ -73,4 +74,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker, REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType); ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>); split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
...@@ -24,14 +24,16 @@ namespace operators { ...@@ -24,14 +24,16 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SplitIdsOpKernel : public framework::OpKernel<T> { class SplitIdsOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) { if (!platform::is_cpu_place(place)) {
PADDLE_THROW("SplitIds do not support GPU kernel"); PADDLE_THROW("SplitIds do not support GPU kernel");
} }
auto& ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims(); const auto *ids_var = ctx.InputVar("Ids");
const T* ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>(); if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out"); auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size(); const size_t shard_num = outs.size();
...@@ -47,14 +49,40 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -47,14 +49,40 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
// create tensor for each shard and send to parameter server // create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) { for (size_t i = 0; i < out_ids.size(); ++i) {
auto* shard_t = outs[i]; auto *shard_t = outs[i];
std::vector<T> ids = out_ids[i]; std::vector<T> ids = out_ids[i];
auto* shard_data = shard_t->mutable_data<T>( auto *shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place); framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) { for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i]; shard_data[i] = ids[i];
} }
} }
} else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), "");
const T *ids = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
// get rows for outputs
for (auto &id : ids_rows) {
size_t shard_id = static_cast<size_t>(id) % shard_num;
outs[shard_id]->mutable_rows()->push_back(id);
}
int64_t row_width = ids_dims[1];
for (auto &out : outs) {
out->set_height(ids_selected_rows->height());
framework::DDim ddim = framework::make_ddim(
{static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (size_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
row_width * sizeof(T));
}
}
}
} }
}; };
......
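Both kernel paths route an id (or a SelectedRows row) to shard id % shard_num. A self-contained sketch of that rule:

    def shard_ids(ids, shard_num):
        # Group ids by destination shard: shard_id = id % shard_num.
        shards = [[] for _ in range(shard_num)]
        for i in ids:
            shards[i % shard_num].append(i)
        return shards

    print(shard_ids([1, 5, 7, 10], 3))  # [[], [1, 7, 10], [5]]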
...@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and ...@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sum_op.h" #include "paddle/fluid/operators/sum_op.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
...@@ -37,7 +39,10 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -37,7 +39,10 @@ class SumOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputsDim("X"); auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size(); size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1."); PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
if (N == 1) {
VLOG(3) << "Warning: sum have only one input, may waste memory";
}
framework::DDim in_dim({0}); framework::DDim in_dim({0});
for (auto& x_dim : x_dims) { for (auto& x_dim : x_dims) {
......
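This is the same relaxation as in concat_op: on a parameter server fed by a single trainer there is only one received table gradient to sum before the table optimizer runs. A minimal sketch, assuming the fluid.layers.sums helper of this release:

    import paddle.fluid as fluid

    x = fluid.layers.data(name="x", shape=[4], dtype="float32")
    # A single-input sum used to fail shape inference; now it only logs a warning.
    out = fluid.layers.sums(input=[x])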
...@@ -13,14 +13,17 @@ ...@@ -13,14 +13,17 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import framework
from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer
from layer_helper import LayerHelper
import distributed_splitter as splitter
import math import math
import distributed_splitter as splitter
import framework
from framework import Program, default_main_program, Variable
from . import core from . import core
import debuger
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
class VarBlock: class VarBlock:
...@@ -35,9 +38,9 @@ class VarBlock: ...@@ -35,9 +38,9 @@ class VarBlock:
class UnionFind(object): class UnionFind(object):
""" Union-find data struct. """ Union-find data structure.
Union-find is a data struct that keeps track of a set of elements partitioned Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets. into a number of disjoint (non-overlapping) subsets.
Reference: Reference:
...@@ -185,19 +188,66 @@ class DistributeTranspiler: ...@@ -185,19 +188,66 @@ class DistributeTranspiler:
assert (callable(split_method)) assert (callable(split_method))
if program is None: if program is None:
program = default_main_program() program = default_main_program()
self.program = program self.origin_program = program
self.trainers = trainers self.trainer_num = trainers
self.optimize_ops = optimize_ops self.optimize_ops = optimize_ops
# TODO(typhoonzero): currently trainer_id is fetched from cluster system # TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing # like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance. # fluid distributed training with fault-tolerance.
self.trainer_id = trainer_id self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",") pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = []
# support only one distributed_lookup_table now
self.table_name = None
for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attrs['is_distributed'] is True:
if self.table_name is None:
self.table_name = op.input("W")[0]
if self.table_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
distributed_lookup_table_ops.append(op)
else:
if self.table_name is not None:
assert op.input("W")[0] != self.table_name
self.has_distributed_lookup_table = len(
distributed_lookup_table_ops) > 0
# step1: For large parameters and gradients, split them into smaller # step1: For large parameters and gradients, split them into smaller
# blocks. # blocks.
param_list = [pg[0] for pg in params_grads] param_list = [pg[0] for pg in params_grads]
grad_list = [pg[1] for pg in params_grads] grad_list = [pg[1] for pg in params_grads]
if self.has_distributed_lookup_table:
param_list = [
param for param in param_list if param.name != self.table_name
]
grad_list = [
grad for grad in grad_list
if grad.name != framework.grad_var_name(self.table_name)
]
self.table_param_grad = [
param_grad for param_grad in params_grads
if param_grad[0].name == self.table_name
][0]
table_grad_var = self.table_param_grad[1]
self.table_grad_list = [
program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints))
]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
# step2: Create new vars for the parameters and gradients blocks and # step2: Create new vars for the parameters and gradients blocks and
...@@ -229,7 +279,7 @@ class DistributeTranspiler: ...@@ -229,7 +279,7 @@ class DistributeTranspiler:
self.param_grad_ep_mapping[ep]["grads"].append(grad) self.param_grad_ep_mapping[ep]["grads"].append(grad)
rpc_client_var = program.global_block().create_var( rpc_client_var = program.global_block().create_var(
name="RPC_CLIENT_VAR", name=RPC_CLIENT_VAR_NAME,
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
...@@ -252,13 +302,19 @@ class DistributeTranspiler: ...@@ -252,13 +302,19 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]}, outputs={"Out": [orig_param]},
attrs={"axis": 0}) attrs={"axis": 0})
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist)
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints)
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops) self.origin_program.global_block().delete_ops(self.optimize_ops)
self.program.sync_with_cpp() self.origin_program.sync_with_cpp()
# FIXME(typhoonzero): serialize once will fix error occurs when clone. # FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__() self.origin_program.__str__()
return self.program return self.origin_program
def get_pserver_program(self, endpoint): def get_pserver_program(self, endpoint):
""" """
...@@ -294,8 +350,8 @@ class DistributeTranspiler: ...@@ -294,8 +350,8 @@ class DistributeTranspiler:
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
if self.trainers > 1: if self.trainer_num > 1:
for trainer_id in xrange(self.trainers): for trainer_id in xrange(self.trainer_num):
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id), name="%s.trainer_%d" % (orig_var_name, trainer_id),
persistable=False, persistable=False,
...@@ -309,7 +365,7 @@ class DistributeTranspiler: ...@@ -309,7 +365,7 @@ class DistributeTranspiler:
# step3 # step3
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
# step 4 # step 4
# Create a union-find data struct from optimize ops, # Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops # If two ops are connected, we could add these two ops
# into one set. # into one set.
ufind = self._create_ufind(self.optimize_ops) ufind = self._create_ufind(self.optimize_ops)
...@@ -384,6 +440,23 @@ class DistributeTranspiler: ...@@ -384,6 +440,23 @@ class DistributeTranspiler:
# __append_optimize_op__(glb_op, optimize_block) # __append_optimize_op__(glb_op, optimize_block)
# break # break
# process distributed lookup_table
prefetch_block = None
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
self._create_table_optimize_block(pserver_index, pserver_program,
append_block)
prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, optimize_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
assert prefetch_block is not None
else:
assert prefetch_block is None
prefetch_block = pserver_program.global_block()
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="listen_and_serv", type="listen_and_serv",
...@@ -392,8 +465,10 @@ class DistributeTranspiler: ...@@ -392,8 +465,10 @@ class DistributeTranspiler:
attrs={ attrs={
"OptimizeBlock": optimize_block, "OptimizeBlock": optimize_block,
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainers "Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block
}) })
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
...@@ -451,6 +526,197 @@ class DistributeTranspiler: ...@@ -451,6 +526,197 @@ class DistributeTranspiler:
attrs=op.attrs) attrs=op.attrs)
return s_prog return s_prog
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None
self.prefetch_output_vars = None
continue_search_lookup_table_op = True
while continue_search_lookup_table_op:
continue_search_lookup_table_op = False
all_ops = program.global_block().ops
for op in all_ops:
if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True
op_index = list(all_ops).index(op)
ids_name = op.input("Ids")
out_name = op.output("Out")
if self.prefetch_input_vars is None:
ids_var = program.global_block().vars[ids_name[0]]
self.prefetch_input_vars = self.create_splited_vars(
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
if self.prefetch_output_vars is None:
out_var = program.global_block().vars[out_name[0]]
self.prefetch_output_vars = self.create_splited_vars(
source_var=out_var,
block=program.global_block(),
tag="_prefetch_out_")
# insert split_ids_op
program.global_block().insert_op(
index=op_index,
type="split_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
]
},
outputs={"Out": self.prefetch_input_vars})
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={
"Out": self.prefetch_output_vars,
"RPCClient": rpc_client_var
},
attrs={"epmap": eplist})
# insert concat_op
program.global_block().insert_op(
index=op_index + 2,
type="concat",
inputs={'X': self.prefetch_output_vars},
outputs={
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
},
attrs={"axis": 0})
# delete lookup_table_op
program.global_block().delete_ops([op])
program.sync_with_cpp()
# break for loop
break
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
table_grad_name = framework.grad_var_name(self.table_name)
for op in all_ops:
if table_grad_name in op.output_arg_names:
op_index = list(all_ops).index(op)
# insert split_ids_op
program.global_block().insert_op(
index=op_index + 1,
type="split_ids",
inputs={
'Ids': [program.global_block().vars[table_grad_name]]
},
outputs={"Out": self.table_grad_list})
program.global_block().insert_op(
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True,
"epmap": pserver_endpoints})
break
def _create_prefetch_block(self, pserver_index, pserver_program,
optimize_block):
# STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name]
prefetch_block = pserver_program.create_block(optimize_block.idx)
trainer_ids = self.prefetch_input_vars[pserver_index]
pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name,
type=trainer_ids.type,
shape=trainer_ids.shape,
dtype=trainer_ids.dtype)
trainer_out = self.prefetch_output_vars[pserver_index]
pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name,
type=trainer_out.type,
shape=trainer_out.shape,
dtype=trainer_out.dtype)
prefetch_block.append_op(
type=LOOKUP_TABLE_TYPE,
inputs={'Ids': pserver_ids,
"W": table_var},
outputs={"Out": pserver_out},
attrs={
"is_sparse": True, # has no effect on lookup_table op
"is_distributed": True,
"padding_idx": -1
})
return prefetch_block
def _create_table_optimize_block(self, pserver_index, pserver_program,
append_block):
def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=persistable)
# STEP: create table optimize block
# create table param and grad var in pserver program
param_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[self.table_name])
grad_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[framework.grad_var_name(
self.table_name)],
persistable=False)
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype) for index in range(self.trainer_num)
]
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name
][0]
table_opt_block = pserver_program.create_block(append_block.idx)
# only support sgd now
assert table_opt_op.type == "sgd"
# append sum op for table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
outputs={"Out": [grad_var]})
lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]]
inputs = {
"Param": [param_var],
"Grad": [grad_var],
"LearningRate": [lr_var]
}
outputs = {"ParamOut": [param_var]}
table_opt_block.append_op(
type=table_opt_op.type,
inputs=inputs,
outputs=outputs,
attrs=table_opt_op.attrs)
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self, def _create_vars_from_blocklist(self,
program, program,
...@@ -512,7 +778,17 @@ class DistributeTranspiler: ...@@ -512,7 +778,17 @@ class DistributeTranspiler:
program.global_block().sync_with_cpp() program.global_block().sync_with_cpp()
return var_mapping return var_mapping
def _clone_var(self, block, var): def create_splited_vars(self, source_var, block, tag):
return [
block.create_var(
name=str(source_var.name + tag + str(index)),
type=source_var.type,
shape=source_var.shape,
dtype=source_var.dtype)
for index in range(len(self.pserver_endpoints))
]
def _clone_var(self, block, var, persistable=True):
assert isinstance(var, Variable) assert isinstance(var, Variable)
return block.create_var( return block.create_var(
name=var.name, name=var.name,
...@@ -520,12 +796,12 @@ class DistributeTranspiler: ...@@ -520,12 +796,12 @@ class DistributeTranspiler:
dtype=var.dtype, dtype=var.dtype,
type=var.type, type=var.type,
lod_level=var.lod_level, lod_level=var.lod_level,
persistable=True) persistable=persistable)
def _append_split_op(self, program, gradblocks): def _append_split_op(self, program, gradblocks):
# Split variables that need to be split and append respective ops # Split variables that need to be split and append respective ops
add_suffix = False add_suffix = False
if self.trainers > 1: if self.trainer_num > 1:
add_suffix = True add_suffix = True
var_mapping = self._create_vars_from_blocklist( var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=add_suffix) program, gradblocks, add_trainer_suffix=add_suffix)
...@@ -616,9 +892,9 @@ class DistributeTranspiler: ...@@ -616,9 +892,9 @@ class DistributeTranspiler:
return return
merged_var = \ merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)] pserver_block.vars[self._orig_varname(grad_block.name)]
if self.trainers > 1: if self.trainer_num > 1:
vars2merge = [] vars2merge = []
for i in xrange(self.trainers): for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \ per_trainer_name = "%s.trainer_%d" % \
(self._orig_varname(grad_block.name), i) (self._orig_varname(grad_block.name), i)
vars2merge.append(pserver_block.vars[per_trainer_name]) vars2merge.append(pserver_block.vars[per_trainer_name])
...@@ -633,7 +909,7 @@ class DistributeTranspiler: ...@@ -633,7 +909,7 @@ class DistributeTranspiler:
type="scale", type="scale",
inputs={"X": merged_var}, inputs={"X": merged_var},
outputs={"Out": merged_var}, outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainers)}) attrs={"scale": 1.0 / float(self.trainer_num)})
new_inputs[key] = merged_var new_inputs[key] = merged_var
elif key == "Param": elif key == "Param":
# param is already created on global program # param is already created on global program
...@@ -669,7 +945,7 @@ class DistributeTranspiler: ...@@ -669,7 +945,7 @@ class DistributeTranspiler:
new_shape = None new_shape = None
if key in ["Param", "Grad", "LearningRate"]: if key in ["Param", "Grad", "LearningRate"]:
continue continue
var = self.program.global_block().vars[opt_op.input(key)[0]] var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape # update accumulator variable shape
param_shape = new_inputs["Param"].shape param_shape = new_inputs["Param"].shape
new_shape = self._get_optimizer_input_shape(opt_op.type, key, new_shape = self._get_optimizer_input_shape(opt_op.type, key,
...@@ -682,8 +958,8 @@ class DistributeTranspiler: ...@@ -682,8 +958,8 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
# change output's ParamOut variable # change output's ParamOut variable
outputs = self._get_output_map_from_op(self.program.global_block().vars, outputs = self._get_output_map_from_op(
opt_op) self.origin_program.global_block().vars, opt_op)
outputs["ParamOut"] = new_inputs["Param"] outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op( optimize_block.append_op(
...@@ -695,8 +971,8 @@ class DistributeTranspiler: ...@@ -695,8 +971,8 @@ class DistributeTranspiler:
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
program = optimize_block.program program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(self.program.global_block().vars, inputs = self._get_input_map_from_op(
opt_op) self.origin_program.global_block().vars, opt_op)
for varlist in inputs.itervalues(): for varlist in inputs.itervalues():
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
...@@ -709,8 +985,8 @@ class DistributeTranspiler: ...@@ -709,8 +985,8 @@ class DistributeTranspiler:
dtype=var.dtype, dtype=var.dtype,
shape=var.shape) shape=var.shape)
outputs = self._get_output_map_from_op(self.program.global_block().vars, outputs = self._get_output_map_from_op(
opt_op) self.origin_program.global_block().vars, opt_op)
for varlist in outputs.itervalues(): for varlist in outputs.itervalues():
if not isinstance(varlist, list): if not isinstance(varlist, list):
...@@ -783,7 +1059,6 @@ class DistributeTranspiler: ...@@ -783,7 +1059,6 @@ class DistributeTranspiler:
if same_or_split_var(n, param) and n != param: if same_or_split_var(n, param) and n != param:
return True return True
return False return False
return False
def _get_input_map_from_op(self, varmap, op): def _get_input_map_from_op(self, varmap, op):
"""Returns a dict from op input name to the vars in varmap.""" """Returns a dict from op input name to the vars in varmap."""
...@@ -821,7 +1096,7 @@ class DistributeTranspiler: ...@@ -821,7 +1096,7 @@ class DistributeTranspiler:
find_ops = [] find_ops = []
# find ops which output is lr var # find ops which output is lr var
block = self.program.global_block() block = self.origin_program.global_block()
for op in block.ops: for op in block.ops:
if set(op.output_arg_names) & lr_vars: if set(op.output_arg_names) & lr_vars:
find_ops.append(op) find_ops.append(op)
......
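After transpile(), each process picks its program by role. A hedged usage sketch continuing the trainer-side example near the top of this commit; the role flag, endpoints, and the get_startup_program signature are assumptions, not taken from this diff:

    exe = fluid.Executor(fluid.CPUPlace())

    if training_role == "PSERVER":
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
        exe.run(pserver_startup)
        exe.run(pserver_prog)  # blocks in listen_and_serv; PrefetchBlock runs on demand
    else:
        exe.run(fluid.default_startup_program())
        trainer_prog = t.get_trainer_program()
        # feed minibatches into trainer_prog here; table lookups go through prefetch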
...@@ -218,6 +218,7 @@ def fc(input, ...@@ -218,6 +218,7 @@ def fc(input,
def embedding(input, def embedding(input,
size, size,
is_sparse=False, is_sparse=False,
is_distributed=False,
padding_idx=None, padding_idx=None,
param_attr=None, param_attr=None,
dtype='float32'): dtype='float32'):
...@@ -268,8 +269,11 @@ def embedding(input, ...@@ -268,8 +269,11 @@ def embedding(input,
inputs={'Ids': input, inputs={'Ids': input,
'W': w}, 'W': w},
outputs={'Out': tmp}, outputs={'Out': tmp},
attrs={'is_sparse': is_sparse, attrs={
'padding_idx': padding_idx}) 'is_sparse': is_sparse,
'is_distributed': is_distributed,
'padding_idx': padding_idx
})
return tmp return tmp
......
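A short usage note for the new argument: is_sparse keeps controlling sparse gradient updates, while is_distributed only marks the table so the transpiler can replace the lookup with prefetch ops; the sizes and parameter name below are placeholders.

    import paddle.fluid as fluid

    ids = fluid.layers.data(name="ids", shape=[1], dtype="int64", lod_level=1)
    emb = fluid.layers.embedding(
        input=ids,
        size=[10000000, 64],
        is_sparse=True,
        is_distributed=True,  # new argument; forwarded to lookup_table's attr
        param_attr=fluid.ParamAttr(name="dist_table.w"))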