Unverified commit 4c55a602, authored by Qiao Longfei, committed by GitHub

Dist transpiler support prefetch (#9714)

* init

* add some check

* add dist transpile logic

* add insert op for block

* init change get_pserver_program

* optimize code

* fix a bug

* can run now

* start to do table split

* start to process table gradient

* complete pserver part

* can send_vars now

* revert cpplint

* fix a bug

* optimize code

* move dist test to models

* revert the interface of distribute_transpiler.transpile

* fix prefetch_block

* optimize transpiler code

* add comment to sum_op

* add warning log

* fix comment

* fix test_send_recv

* fix test_send_recv

* fix train with no distributed table

* optimize GetDims
Parent ad73b331
@@ -92,7 +92,7 @@ class BlockDesc {
   /*
    * Remove Op and its input/output variables.
-   * Note that for either input or ouput variable, if it is also an input or
+   * Note that for either input or output variable, if it is also an input or
    * output variable of other ops, we should remain it.
    */
   void RemoveOp(size_t s, size_t e);
......
@@ -46,7 +46,8 @@ proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
   }
 }
 
-static DDim GetDims(const Scope& scope, const std::string& name) {
+static DDim GetDims(const Scope& scope, const std::string& name,
+                    bool get_actual_dim = false) {
   Variable* var = scope.FindVar(name);
   if (var == nullptr) {
     return DDim({-1});
@@ -55,7 +56,11 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
   if (var->IsType<LoDTensor>()) {
     return var->Get<LoDTensor>().dims();
   } else if (var->IsType<SelectedRows>()) {
-    return var->Get<SelectedRows>().GetCompleteDims();
+    if (get_actual_dim) {
+      return var->Get<SelectedRows>().value().dims();
+    } else {
+      return var->Get<SelectedRows>().GetCompleteDims();
+    }
   } else {
     return DDim({-1});
   }
@@ -129,7 +134,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
     for (size_t i = 0; i < input.second.size(); ++i) {
       ss << input.second[i];
       if (scope) {
-        ss << "[" << GetDims(*scope, input.second[i]) << "]";
+        ss << "[" << GetDims(*scope, input.second[i], true) << "]";
         ss << "(" << GetLoD(*scope, input.second[i]) << ")";
       }
       if (i != input.second.size() - 1) {
@@ -149,7 +154,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
     for (size_t i = 0; i < output.second.size(); ++i) {
       ss << output.second[i];
       if (scope) {
-        ss << "[" << GetDims(*scope, output.second[i]) << "]";
+        ss << "[" << GetDims(*scope, output.second[i], true) << "]";
         ss << "(" << GetLoD(*scope, output.second[i]) << ")";
       }
       if (i != output.second.size() - 1) {
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/concat_op.h"
 #include <string>
 #include <vector>
@@ -34,7 +35,10 @@ class ConcatOp : public framework::OperatorWithKernel {
     size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
     const size_t n = ins.size();
 
-    PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
+    PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
+    if (n == 1) {
+      VLOG(3) << "Warning: concat op have only one input, may waste memory";
+    }
     auto out_dims = ins[0];
     size_t in_zero_dims_size = out_dims.size();
......
@@ -161,6 +161,7 @@ class RequestPrefetch final : public RequestBase {
     ::grpc::ByteBuffer reply;
 
     std::string var_name = request_->OutVarname();
+    VLOG(3) << "prefetch var " << var_name;
     auto var_desc = program_->Block(0).FindVar(var_name);
     framework::Scope* local_scope = &scope_->NewScope();
     auto* var = local_scope->FindVar(var_name);
......
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <ostream>
-#include <thread>
+#include <thread>  // NOLINT
+#include <vector>
 
 #include "paddle/fluid/operators/listen_and_serv_op.h"
@@ -88,8 +89,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   auto ins = Inputs("X");
   auto fan_in = Attr<int>("Fanin");
 
-  auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-  auto *program = block->Program();
+  auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+  auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
+  auto *program = optimize_block->Program();
   size_t num_blocks = program->Size();
   PADDLE_ENFORCE_GE(num_blocks, 2,
                     "server program should have at least 2 blocks");
@@ -97,18 +99,25 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   framework::Executor executor(dev_place);
   std::vector<int> block_list;
   for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
-    block_list.push_back(blkid);
+    if (blkid != prefetch_block->ID()) {
+      block_list.push_back(blkid);
+    }
   }
-  auto prepared = executor.Prepare(*program, block_list);
+  auto optimize_prepared = executor.Prepare(*program, block_list);
   // Insert placeholder for block0 which holds current op itself.
-  prepared.insert(prepared.begin(),
-                  std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
+  optimize_prepared.insert(
+      optimize_prepared.begin(),
+      std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
 
   rpc_service_->SetScope(&recv_scope);
   rpc_service_->SetDevCtx(&dev_ctx);
   // TODO(qiao) set proper fields for table lookup and update
   rpc_service_->SetExecutor(&executor);
-  rpc_service_->SetPrefetchBlkdId(0);
+  VLOG(3) << "prefetch block id is " << prefetch_block->ID();
+  auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
+  rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
+  rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
+  prefetch_prepared.release();
   rpc_service_->SetProgram(program);
   // start the server listening after all member initialized.
   server_thread_.reset(new std::thread(RunServer, rpc_service_));
@@ -166,16 +175,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
       parallel_blkids.push_back(1);
       double ts = detail::GetTimestamp();
       for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
-        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
-                                &recv_scope);
-          parallel_blkids.clear();
-          last_parent_blkid = program->Block(blkid).Parent();
+        if (blkid != prefetch_block->ID()) {
+          if (program->Block(blkid).Parent() != last_parent_blkid) {
+            ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
+                                  program, &recv_scope);
+            parallel_blkids.clear();
+            last_parent_blkid = program->Block(blkid).Parent();
+          }
+          parallel_blkids.push_back(blkid);
         }
-        parallel_blkids.push_back(blkid);
       }
-      ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
-                            &recv_scope);
+      ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
+                            program, &recv_scope);
       VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
 
       // Reset the received sparse variables, the sum operator would not
@@ -211,6 +222,8 @@ from send_op and send back variables to recv_op.
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
     AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                     "BlockID to run on server side.");
+    AddAttr<framework::BlockDesc *>(kPrefetchBlock,
+                                    "prefetch block to run on server side.");
    AddAttr<int>("Fanin", "How many clients send to this server.")
        .SetDefault(1);
  }
......
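For context, here is a minimal sketch of how a server-side program could wire the new `PrefetchBlock` attribute from Python, mirroring the unit-test setup further below. This is not taken from the diff: the `endpoint` attribute name/value, the empty inputs, and the block layout are illustrative assumptions.

```python
import paddle.fluid as fluid

pserver_program = fluid.Program()
root_idx = pserver_program.global_block().idx
optimize_block = pserver_program.create_block(root_idx)   # e.g. optimize ops
prefetch_block = pserver_program.create_block(root_idx)   # e.g. lookup_table for prefetch

# listen_and_serv now takes both block attributes (kOptimizeBlock / kPrefetchBlock).
pserver_program.global_block().append_op(
    type="listen_and_serv",
    inputs={},
    outputs={},
    attrs={
        "endpoint": "127.0.0.1:6174",      # assumed attribute name and value
        "OptimizeBlock": optimize_block,
        "PrefetchBlock": prefetch_block,   # new attribute added by this commit
        "Fanin": 1,
    })
```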
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <stdint.h>
 #include <ostream>
+#include <string>
 
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -27,6 +28,7 @@ namespace paddle {
 namespace operators {
 
 constexpr char kOptimizeBlock[] = "OptimizeBlock";
+constexpr char kPrefetchBlock[] = "PrefetchBlock";
 
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
......
@@ -78,6 +78,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(boolean, default false) "
                   "Sparse update.")
         .SetDefault(false);
+    AddAttr<bool>("is_distributed",
+                  "(boolean, default false) distributed lookup table.")
+        .SetDefault(false);
     AddAttr<int64_t>("padding_idx",
                      "(int64, default -1) "
                      "If the value is -1, it makes no effect to lookup. "
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <future>
+#include <future>  // NOLINT
 #include <ostream>
 
 #include "paddle/fluid/framework/data_type.h"
@@ -50,8 +50,8 @@ class PrefetchOp : public framework::OperatorBase {
 
     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {
-        VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << "to get "
-                << outs[i] << "back";
+        VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
+                << outs[i] << " back";
         rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i],
                                           outs[i]);
       } else {
@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
              "(RPCClient) The RPC client object which will be"
              "initialized at most once.");
     AddOutput("Out",
-              "(SelectedRows) result "
+              "(LoDTensor) result "
              "to be fetched from parameter server")
        .AsDuplicable();
    AddAttr<std::vector<std::string>>(
......
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include <unistd.h>
 #include <string>
-#include <thread>
+#include <thread>  // NOLINT
 
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -37,11 +37,11 @@ namespace m = paddle::operators::math;
 std::unique_ptr<f::OperatorBase> listen_and_serv_op;
 int selected_port;
 
-void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
+void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
   p::CPUDeviceContext ctx(place);
   for (int i = 0; i < 2; ++i) {
     auto var_name = paddle::string::Sprintf("x%d", i);
-    auto var = scope.Var(var_name);
+    auto var = scope->Var(var_name);
     auto tensor = var->GetMutable<f::LoDTensor>();
     tensor->Resize({10, 10});
     float *expect = tensor->mutable_data<float>(place);
@@ -50,20 +50,20 @@ void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
     }
   }
 
-  auto out_var = scope.Var("Out");
+  auto out_var = scope->Var("Out");
   auto out_tensor = out_var->GetMutable<f::LoDTensor>();
   out_tensor->Resize({10, 10});
   out_tensor->mutable_data<float>(place);  // allocate
 }
 
-void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
+void InitSelectedRowsInScope(const p::CPUPlace &place, f::Scope *scope) {
   p::CPUDeviceContext ctx(place);
   int64_t height = 10;
   int64_t row_numel = 10;
   m::SetConstant<p::CPUDeviceContext, float> set_one;
   // init x0
   std::vector<int64_t> rows0{0, 4, 7};
-  auto x0_var = scope.Var("x0");
+  auto x0_var = scope->Var("x0");
   auto x0 = x0_var->GetMutable<f::SelectedRows>();
   x0->set_rows(rows0);
   x0->set_height(height);
@@ -74,7 +74,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
 
   // init x1
   std::vector<int64_t> rows1{2, 9};
-  auto x1_var = scope.Var("x1");
+  auto x1_var = scope->Var("x1");
   auto x1 = x1_var->GetMutable<f::SelectedRows>();
   x1->set_rows(rows1);
   x1->set_height(height);
@@ -83,7 +83,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
       f::make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), place);
   set_one(ctx, x1_value, 1.0);
 
-  auto out_var = scope.Var("Out");
+  auto out_var = scope->Var("Out");
   auto out = out_var->GetMutable<f::SelectedRows>();
   auto out_value = out->mutable_value();
   out->set_height(height);
@@ -117,15 +117,16 @@ void StartServerNet(bool is_sparse) {
   f::Scope scope;
   p::CPUPlace place;
   if (is_sparse) {
-    InitSelectedRowsInScope(scope, place);
+    InitSelectedRowsInScope(place, &scope);
   } else {
-    InitTensorsInScope(scope, place);
+    InitTensorsInScope(place, &scope);
   }
 
   // sub program run in listen_and_serv_op, for simple test we use sum
   f::ProgramDesc program;
   const auto &root_block = program.Block(0);
   auto *optimize_block = program.AppendBlock(root_block);
+  auto *prefetch_block = program.AppendBlock(root_block);
   // X for server side tensors, RX for received tensers, must be of same shape.
   AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
@@ -135,6 +136,7 @@ void StartServerNet(bool is_sparse) {
   attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
   attrs.insert({"GradList", std::vector<std::string>({"x1"})});
   attrs.insert({"OptimizeBlock", optimize_block});
+  attrs.insert({"PrefetchBlock", prefetch_block});
   listen_and_serv_op =
       f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
   LOG(INFO) << "selected port before run " << selected_port;
@@ -148,7 +150,7 @@ TEST(SendRecvOp, CPUDense) {
   // local net
   f::Scope scope;
   p::CPUPlace place;
-  InitTensorsInScope(scope, place);
+  InitTensorsInScope(place, &scope);
   // create rpc client var
   scope.Var("RPC_CLIENT_VAR");
@@ -191,7 +193,7 @@ TEST(SendRecvOp, CPUSparse) {
   f::Scope scope;
   p::CPUPlace place;
   p::CPUDeviceContext ctx(place);
-  InitSelectedRowsInScope(scope, place);
+  InitSelectedRowsInScope(place, &scope);
   scope.Var("RPC_CLIENT_VAR");
   f::AttributeMap attrs;
   selected_port = static_cast<paddle::operators::ListenAndServOp *>(
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <future>
+#include <future>  // NOLINT
 #include <ostream>
 
 #include "paddle/fluid/framework/data_type.h"
@@ -36,7 +36,7 @@ class SendVarsOp : public framework::OperatorBase {
     auto ins = Inputs("X");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
-    int sync_send = Attr<int>("sync_sent");
+    int sync_send = Attr<int>("sync_send");
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
......
@@ -35,8 +35,8 @@ class SGDOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                       "Learning rate should have 1 element");
     auto param_dim = ctx->GetInputDim("Param");
-    // TODO(qijun): check dimensions of Param and Grad at complie
-    // and run time.
+    // TODO(qijun): check dimensions of Param and Grad at compile
+    // and runtime.
     ctx->SetOutputDim("ParamOut", param_dim);
   }
......
@@ -48,11 +48,11 @@ class SplitIdsOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
 
     auto ids_var_type = ctx->GetInputsVarType("Ids").front();
-    PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);
     auto ids_dims = ctx->GetInputDim("Ids");
-    PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
-    PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+    if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
+      PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+    }
   }
 };
@@ -60,8 +60,9 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
+    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
     for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR);
+      block->Var(out_var)->SetType(input_var->GetType());
     }
   }
 };
@@ -73,4 +74,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
                   ops::SplitIdsOpInferVarType);
 REGISTER_OP_CPU_KERNEL(
-    split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>);
+    split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
+    ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
@@ -24,35 +24,63 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class SplitIdsOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     auto place = ctx.GetPlace();
     if (!platform::is_cpu_place(place)) {
       PADDLE_THROW("SplitIds do not support GPU kernel");
     }
 
-    auto& ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
-    const T* ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
-    auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
-    const size_t shard_num = outs.size();
-
-    std::vector<std::vector<T>> out_ids;
-    out_ids.resize(outs.size());
-
-    // split id by their shard_num.
-    for (int i = 0; i < ids_dims[0]; ++i) {
-      T id = ids[i];
-      size_t shard_id = static_cast<size_t>(id) % shard_num;
-      out_ids[shard_id].push_back(id);
-    }
-
-    // create tensor for each shard and send to parameter server
-    for (size_t i = 0; i < out_ids.size(); ++i) {
-      auto* shard_t = outs[i];
-      std::vector<T> ids = out_ids[i];
-      auto* shard_data = shard_t->mutable_data<T>(
-          framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
-      for (size_t i = 0; i < ids.size(); ++i) {
-        shard_data[i] = ids[i];
+    const auto *ids_var = ctx.InputVar("Ids");
+    if (ids_var->IsType<framework::LoDTensor>()) {
+      const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
+      const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
+      auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
+      const size_t shard_num = outs.size();
+
+      std::vector<std::vector<T>> out_ids;
+      out_ids.resize(outs.size());
+
+      // split id by their shard_num.
+      for (int i = 0; i < ids_dims[0]; ++i) {
+        T id = ids[i];
+        size_t shard_id = static_cast<size_t>(id) % shard_num;
+        out_ids[shard_id].push_back(id);
+      }
+
+      // create tensor for each shard and send to parameter server
+      for (size_t i = 0; i < out_ids.size(); ++i) {
+        auto *shard_t = outs[i];
+        std::vector<T> ids = out_ids[i];
+        auto *shard_data = shard_t->mutable_data<T>(
+            framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
+        for (size_t i = 0; i < ids.size(); ++i) {
+          shard_data[i] = ids[i];
+        }
+      }
+    } else if (ids_var->IsType<framework::SelectedRows>()) {
+      const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
+      auto &ids_dims = ids_selected_rows->value().dims();
+      PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), "");
+      const T *ids = ids_selected_rows->value().data<T>();
+      const auto &ids_rows = ids_selected_rows->rows();
+      auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
+      const size_t shard_num = outs.size();
+      // get rows for outputs
+      for (auto &id : ids_rows) {
+        size_t shard_id = static_cast<size_t>(id) % shard_num;
+        outs[shard_id]->mutable_rows()->push_back(id);
+      }
+
+      int64_t row_width = ids_dims[1];
+      for (auto &out : outs) {
+        out->set_height(ids_selected_rows->height());
+        framework::DDim ddim = framework::make_ddim(
            {static_cast<int64_t>(out->rows().size()), row_width});
+        T *output = out->mutable_value()->mutable_data<T>(ddim, place);
+        for (size_t i = 0; i < ddim[0]; ++i) {
+          memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
+                 row_width * sizeof(T));
+        }
       }
     }
   }
......
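The routing rule in the split_ids kernel above is the same on both the dense (LoDTensor) and the SelectedRows path: an id goes to shard `id % shard_num`. A tiny standalone sketch of that rule in plain Python, purely for illustration:

```python
def split_ids(ids, shard_num):
    """Route each id to shard id % shard_num, as the kernel above does."""
    shards = [[] for _ in range(shard_num)]
    for i in ids:
        shards[i % shard_num].append(i)
    return shards

# With the row ids used in the send_recv test above (x0: 0, 4, 7 and x1: 2, 9)
# and two shards:
#   split_ids([0, 4, 7, 2, 9], 2) -> [[0, 4, 2], [7, 9]]
```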
@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sum_op.h"
+
 #include <algorithm>
 #include <string>
 #include <vector>
+
 #include "paddle/fluid/framework/var_type_inference.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -37,7 +39,10 @@ class SumOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputsDim("X");
     size_t N = x_dims.size();
-    PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+    PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
+    if (N == 1) {
+      VLOG(3) << "Warning: sum have only one input, may waste memory";
+    }
 
     framework::DDim in_dim({0});
     for (auto& x_dim : x_dims) {
......
@@ -218,6 +218,7 @@ def fc(input,
 def embedding(input,
               size,
               is_sparse=False,
+              is_distributed=False,
               padding_idx=None,
               param_attr=None,
               dtype='float32'):
@@ -268,8 +269,11 @@ def embedding(input,
         inputs={'Ids': input,
                 'W': w},
         outputs={'Out': tmp},
-        attrs={'is_sparse': is_sparse,
-               'padding_idx': padding_idx})
+        attrs={
+            'is_sparse': is_sparse,
+            'is_distributed': is_distributed,
+            'padding_idx': padding_idx
+        })
     return tmp
......
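A trainer-side usage sketch for the new `is_distributed` flag, under assumptions: the vocabulary size, endpoints, and the exact `transpile()` signature of this fluid version are illustrative and not part of this diff.

```python
import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(
    input=ids,
    size=[100000, 64],
    is_sparse=True,
    is_distributed=True)  # new argument: mark the table as a distributed lookup table
# ... build the rest of the network, loss and optimizer as usual ...

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0,
            pservers="127.0.0.1:6174,127.0.0.1:6175",
            trainers=2)
pserver_prog = t.get_pserver_program("127.0.0.1:6174")
```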