提交 0c321fe3 编写于 作者: _青葱's avatar _青葱

Merge branch develop

......@@ -92,7 +92,7 @@ class BlockDesc {
/*
* Remove Op and its input/output variables.
* Note that for either input or ouput variable, if it is also an input or
* Note that for either input or output variable, if it is also an input or
* output variable of other ops, we should remain it.
*/
void RemoveOp(size_t s, size_t e);
......
......@@ -14,6 +14,8 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
......@@ -33,7 +35,7 @@ void ComputationOpHandle::RunImpl() {
}
}
op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}
std::string ComputationOpHandle::Name() const { return op_->Type(); }
......
......@@ -14,6 +14,9 @@
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
......@@ -57,7 +60,10 @@ void FetchOpHandle::RunImpl() {
for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i];
auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
auto &t = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(var_name)
->Get<framework::LoDTensor>();
if (platform::is_gpu_place(var->place_)) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
......
......@@ -24,6 +24,8 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
class OpHandleBase {
private:
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
......
......@@ -15,13 +15,15 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
namespace paddle {
namespace framework {
namespace details {
class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
......
......@@ -136,12 +136,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
ready_ops.clear();
};
// Create local scopes.
for (auto &scope : local_scopes_) {
auto &local_scope = scope->NewScope();
*scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
}
// Step 3. Execution
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
// 1. Run All Ready ops
......@@ -189,34 +183,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
PADDLE_ENFORCE(ready_ops.empty());
PADDLE_ENFORCE(delayed_ops.empty());
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
++computation_count_;
auto sync_computation = [&] {
computation_count_ = 0;
// Wait All computational streams
for (auto p : this->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
scope->DropKids();
}
};
// Wait FetchOps.
if (!fetch_ops.empty()) {
fetch_ops.clear();
sync_computation();
}
if (computation_count_ == max_async_computation) {
sync_computation();
}
// NOTE: the temp scope can be dropped lazily if needed.
// Drop tmp scopes;
for (auto &scope : local_scopes_) {
auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
kid = nullptr;
}
return fetch_data;
......
......@@ -99,9 +99,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
bool allow_op_delay_;
size_t computation_count_{0};
size_t max_async_computation{100};
};
} // namespace details
......
......@@ -46,7 +46,8 @@ proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
}
}
static DDim GetDims(const Scope& scope, const std::string& name) {
static DDim GetDims(const Scope& scope, const std::string& name,
bool get_actual_dim = false) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return DDim({-1});
......@@ -55,7 +56,11 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
if (get_actual_dim) {
return var->Get<SelectedRows>().value().dims();
} else {
return var->Get<SelectedRows>().GetCompleteDims();
}
} else {
return DDim({-1});
}
......@@ -129,7 +134,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (scope) {
ss << "[" << GetDims(*scope, input.second[i]) << "]";
ss << "[" << GetDims(*scope, input.second[i], true) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")";
}
if (i != input.second.size() - 1) {
......@@ -149,7 +154,7 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (scope) {
ss << "[" << GetDims(*scope, output.second[i]) << "]";
ss << "[" << GetDims(*scope, output.second[i], true) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")";
}
if (i != output.second.size() - 1) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include <string>
#include <tuple>
#include <vector>
#ifdef PADDLE_WITH_CUDA
......@@ -41,6 +42,8 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
};
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
......@@ -97,14 +100,9 @@ ParallelExecutor::ParallelExecutor(
allow_op_delay));
// Step 3. Create vars in each scope;
for (auto *scope : member_->local_scopes_) {
for (auto *var : main_program.Block(0).AllVars()) {
if (scope->FindVar(var->Name()) != nullptr) {
continue;
}
InitializeVariable(scope->Var(var->Name()), var->GetType());
}
for (auto *var : main_program.Block(0).AllVars()) {
member_->var_types_.emplace_back(var->Name(), var->GetType(),
var->Persistable());
}
}
......@@ -163,9 +161,42 @@ void ParallelExecutor::Run(
const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
platform::RecordBlock b(0);
SplitTensorToPlaces(feed_tensors);
// Create local scopes.
for (auto &scope : member_->local_scopes_) {
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &name_type_pair : member_->var_types_) {
if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
continue;
}
if (std::get<2>(name_type_pair)) { // Persistable
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
} else {
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
}
}
}
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
// Wait All computational streams
for (auto p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : member_->local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
local_scope = nullptr;
}
}
void ParallelExecutor::SplitTensorToPlaces(
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <string>
#include <vector>
......@@ -34,7 +35,10 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
if (n == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory";
}
auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size();
......
......@@ -161,6 +161,7 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
......
......@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <thread>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/listen_and_serv_op.h"
......@@ -88,8 +89,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
......@@ -97,18 +99,25 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Executor executor(dev_place);
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
if (blkid != prefetch_block->ID()) {
block_list.push_back(blkid);
}
}
auto prepared = executor.Prepare(*program, block_list);
auto optimize_prepared = executor.Prepare(*program, block_list);
// Insert placeholder for block0 which holds current op itself.
prepared.insert(prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......@@ -166,16 +175,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
parallel_blkids.push_back(1);
double ts = detail::GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
&recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
if (blkid != prefetch_block->ID()) {
if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
parallel_blkids.push_back(blkid);
}
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
&recv_scope);
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
program, &recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
......@@ -211,6 +222,8 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -27,6 +28,7 @@ namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
......
......@@ -78,6 +78,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future>
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
......@@ -50,8 +50,8 @@ class PrefetchOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << "to get "
<< outs[i] << "back";
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i],
outs[i]);
} else {
......@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out",
"(SelectedRows) result "
"(LoDTensor) result "
"to be fetched from parameter server")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -37,11 +37,11 @@ namespace m = paddle::operators::math;
std::unique_ptr<f::OperatorBase> listen_and_serv_op;
int selected_port;
void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
void InitTensorsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place);
for (int i = 0; i < 2; ++i) {
auto var_name = paddle::string::Sprintf("x%d", i);
auto var = scope.Var(var_name);
auto var = scope->Var(var_name);
auto tensor = var->GetMutable<f::LoDTensor>();
tensor->Resize({10, 10});
float *expect = tensor->mutable_data<float>(place);
......@@ -50,20 +50,20 @@ void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
}
}
auto out_var = scope.Var("Out");
auto out_var = scope->Var("Out");
auto out_tensor = out_var->GetMutable<f::LoDTensor>();
out_tensor->Resize({10, 10});
out_tensor->mutable_data<float>(place); // allocate
}
void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
void InitSelectedRowsInScope(const p::CPUPlace &place, f::Scope *scope) {
p::CPUDeviceContext ctx(place);
int64_t height = 10;
int64_t row_numel = 10;
m::SetConstant<p::CPUDeviceContext, float> set_one;
// init x0
std::vector<int64_t> rows0{0, 4, 7};
auto x0_var = scope.Var("x0");
auto x0_var = scope->Var("x0");
auto x0 = x0_var->GetMutable<f::SelectedRows>();
x0->set_rows(rows0);
x0->set_height(height);
......@@ -74,7 +74,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
// init x1
std::vector<int64_t> rows1{2, 9};
auto x1_var = scope.Var("x1");
auto x1_var = scope->Var("x1");
auto x1 = x1_var->GetMutable<f::SelectedRows>();
x1->set_rows(rows1);
x1->set_height(height);
......@@ -83,7 +83,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
f::make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), place);
set_one(ctx, x1_value, 1.0);
auto out_var = scope.Var("Out");
auto out_var = scope->Var("Out");
auto out = out_var->GetMutable<f::SelectedRows>();
auto out_value = out->mutable_value();
out->set_height(height);
......@@ -117,15 +117,16 @@ void StartServerNet(bool is_sparse) {
f::Scope scope;
p::CPUPlace place;
if (is_sparse) {
InitSelectedRowsInScope(scope, place);
InitSelectedRowsInScope(place, &scope);
} else {
InitTensorsInScope(scope, place);
InitTensorsInScope(place, &scope);
}
// sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program;
const auto &root_block = program.Block(0);
auto *optimize_block = program.AppendBlock(root_block);
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensers, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
......@@ -135,6 +136,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
LOG(INFO) << "selected port before run " << selected_port;
......@@ -148,7 +150,7 @@ TEST(SendRecvOp, CPUDense) {
// local net
f::Scope scope;
p::CPUPlace place;
InitTensorsInScope(scope, place);
InitTensorsInScope(place, &scope);
// create rpc client var
scope.Var("RPC_CLIENT_VAR");
......@@ -191,7 +193,7 @@ TEST(SendRecvOp, CPUSparse) {
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(scope, place);
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>(
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future>
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
......@@ -36,7 +36,7 @@ class SendVarsOp : public framework::OperatorBase {
auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_sent");
int sync_send = Attr<int>("sync_send");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......
......@@ -35,8 +35,8 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
// TODO(qijun): check dimensions of Param and Grad at complie
// and run time.
// TODO(qijun): check dimensions of Param and Grad at compile
// and runtime.
ctx->SetOutputDim("ParamOut", param_dim);
}
......
......@@ -48,11 +48,11 @@ class SplitIdsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);
auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
}
};
......@@ -60,8 +60,9 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR);
block->Var(out_var)->SetType(input_var->GetType());
}
}
};
......@@ -73,4 +74,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>);
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
......@@ -24,35 +24,63 @@ namespace operators {
template <typename DeviceContext, typename T>
class SplitIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("SplitIds do not support GPU kernel");
}
auto& ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T* ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
const auto *ids_var = ctx.InputVar("Ids");
if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto *shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto *shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
}
}
} else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), "");
const T *ids = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
// get rows for outputs
for (auto &id : ids_rows) {
size_t shard_id = static_cast<size_t>(id) % shard_num;
outs[shard_id]->mutable_rows()->push_back(id);
}
// create tensor for each shard and send to parameter server
for (size_t i = 0; i < out_ids.size(); ++i) {
auto* shard_t = outs[i];
std::vector<T> ids = out_ids[i];
auto* shard_data = shard_t->mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
for (size_t i = 0; i < ids.size(); ++i) {
shard_data[i] = ids[i];
int64_t row_width = ids_dims[1];
for (auto &out : outs) {
out->set_height(ids_selected_rows->height());
framework::DDim ddim = framework::make_ddim(
{static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (size_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
row_width * sizeof(T));
}
}
}
}
......
......@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
......@@ -37,7 +39,10 @@ class SumOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
if (N == 1) {
VLOG(3) << "Warning: sum have only one input, may waste memory";
}
framework::DDim in_dim({0});
for (auto& x_dim : x_dims) {
......
......@@ -218,6 +218,7 @@ def fc(input,
def embedding(input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
......@@ -268,8 +269,11 @@ def embedding(input,
inputs={'Ids': input,
'W': w},
outputs={'Out': tmp},
attrs={'is_sparse': is_sparse,
'padding_idx': padding_idx})
attrs={
'is_sparse': is_sparse,
'is_distributed': is_distributed,
'padding_idx': padding_idx
})
return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册