Unverified commit 74605fc2, authored by zhaocaibei123 and committed by GitHub

upgrade async distributed training in pscore (#37515)

* test

* test

* rm test

* update

* update

* update

* add unittest

* update

* update save
Parent: 5607bcf2
......@@ -570,8 +570,6 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
......@@ -596,8 +594,6 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/pscore/distributed_push_sparse_op.h"
namespace paddle {
namespace operators {
constexpr int64_t kNoPadding = -1;
class DistributedPushSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("Ids"), true,
platform::errors::InvalidArgument(
"Input(Ids) of PushSparseOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true,
platform::errors::InvalidArgument(
"Output(Outs) of PushSparseOp should not be null."));
auto ids_dims = ctx->GetInputsDim("Ids");
for (auto &ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
platform::errors::InvalidArgument(
"The dimension of the 'Ids' tensor must be 2."));
}
// for fluid.embedding
auto push_sparse_version =
ctx->Attrs().Get<std::string>("push_sparse_version");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class DistributedPushSparseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.")
.AsDuplicable();
AddInput("Shows",
"(LoDTensor) Shows's type should be LoDTensor"
"THe shows default to be 1.")
.AsDuplicable();
AddInput("Clicks",
"(LoDTensor) Clicks's type should be LoDTensor"
"THe clicks usually equal to label.")
.AsDuplicable();
AddOutput("Outputs",
"(LoDTensor) The lookup results, which have the same type as W.")
.AsDuplicable();
AddAttr<int>("table_id", "sparse table id").SetDefault(0);
AddAttr<int>("size", "embedding size").SetDefault(8);
AddAttr<bool>("is_distributed",
"(boolean, default false) distributed lookup table.")
.SetDefault(false);
AddAttr<std::string>(
"push_sparse_version",
"(string, default push_sparse) "
"To distinguish between different versions of embedding OP")
.SetDefault(std::string("push_sparse"));
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training.")
.SetDefault(false);
AddComment(R"DOC(
Lookup Tablel Prefetch Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_push_sparse, ops::DistributedPushSparseOp,
ops::DistributedPushSparseOpMaker);
REGISTER_OP_CPU_KERNEL(
distributed_push_sparse,
ops::DistributedPushSparseKernel<paddle::platform::CPUDeviceContext, float>,
ops::DistributedPushSparseKernel<paddle::platform::CPUDeviceContext,
double>);
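For reference, a minimal sketch of how the operator registered above is appended from Python by the push-sparse fuse pass later in this diff; ids_vars, grad_var, out_grad_vars, show, clk, is_distributed, padding_idx, table_id and emb_size are illustrative names assumed to already exist in the program:

program.global_block().append_op(
    type="distributed_push_sparse",
    inputs={
        "Ids": ids_vars,           # list of id LoDTensors, one per removed grad op
        "W": grad_var,             # the embedding gradient variable
        "Outputs": out_grad_vars,  # Out@GRAD variables of the removed grad ops
        "Shows": show,
        "Clicks": clk
    },
    outputs={"Outputs": out_grad_vars},
    attrs={
        "is_distributed": is_distributed,
        "padding_idx": padding_idx,
        "table_id": table_id,
        "size": emb_size
    })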
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pscore/distributed_push_sparse_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
distributed_push_sparse,
ops::DistributedPushSparseKernel<plat::CUDADeviceContext, float>,
ops::DistributedPushSparseKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/fleet.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class DistributedPushSparseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto &scope = context.scope();
auto padding_idx = context.Attr<int64_t>("padding_idx");
auto table_id = context.Attr<int>("table_id");
auto emb_dim = context.Attr<int>("size");
VLOG(1) << "push_sparse.h::emb_dim: " << emb_dim;
bool is_test = context.Attr<bool>("is_test");
auto inputs = context.MultiInput<framework::LoDTensor>("Ids");
auto shows = context.Input<framework::LoDTensor>("Shows");
auto clks = context.Input<framework::LoDTensor>("Clicks");
auto outputs = context.MultiOutput<framework::LoDTensor>("Outputs");
auto fleet = distributed::FleetWrapper::GetInstance();
if (platform::is_cpu_place(context.GetPlace())) {
fleet->PushSparseFromTensorAsync(static_cast<uint64_t>(table_id), emb_dim,
static_cast<uint64_t>(padding_idx),
context.GetPlace(), &inputs, shows, clks,
&outputs);
} else {
auto inputs_variable = context.MultiInputVar("Ids");
auto outputs_variable = context.MultiOutputVar("Outputs");
auto inputs_name = context.InputNames("Ids");
auto outputs_name = context.OutputNames("Outputs");
auto cpu_place = platform::CPUPlace();
framework::Scope *tmp_scope = scope.NewTmpScope().release();
std::vector<const framework::LoDTensor *> tmp_input_vec;
auto input_var_size = inputs_variable.size();
std::vector<framework::LoDTensor *> tmp_output_vec;
auto output_var_size = outputs_variable.size();
// create temp input
for (size_t idx = 0; idx < input_var_size; ++idx) {
framework::Variable *tmp_input_var = tmp_scope->Var(inputs_name[idx]);
framework::LoDTensor *tmp_input_tensor =
tmp_input_var->GetMutable<framework::LoDTensor>();
framework::TensorCopy(inputs_variable[idx]->Get<framework::LoDTensor>(),
cpu_place, context.device_context(),
tmp_input_tensor);
tmp_input_vec.push_back(tmp_input_tensor);
}
// create temp output
for (size_t idx = 0; idx < output_var_size; ++idx) {
framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]);
framework::LoDTensor *tmp_output_tensor =
tmp_output_var->GetMutable<framework::LoDTensor>();
tmp_output_tensor->Resize(outputs[idx]->dims());
tmp_output_vec.push_back(tmp_output_tensor);
}
// use fleet->PullSparse
fleet->PullSparseToTensorSync(static_cast<uint64_t>(table_id), emb_dim,
static_cast<uint64_t>(padding_idx),
cpu_place, !is_test, &tmp_input_vec,
&tmp_output_vec);
// cp temp to origin
for (size_t idx = 0; idx < output_var_size; ++idx) {
framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]);
framework::LoDTensor *tmp_output_tensor =
tmp_output_var->GetMutable<framework::LoDTensor>();
framework::TensorCopy(
*tmp_output_tensor, context.GetPlace(), context.device_context(),
outputs_variable[idx]->GetMutable<framework::LoDTensor>());
}
delete tmp_scope;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/fleet.h"
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -42,14 +43,15 @@ class SendOp : public framework::OperatorBase {
const platform::Place& place) const override {
auto ins = Inputs("X");
// auto is_sparse = Attr<int>("is_sparse");
// auto table_id = Attr<int>("table_id");
auto table_id = Attr<int>("table_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto* communicator = paddle::distributed::Communicator::GetInstance();
if (communicator->Check(send_varnames)) {
communicator->Send(ins, scope);
}
auto fleet = paddle::distributed::FleetWrapper::GetInstance();
std::vector<::std::future<int32_t>> status;
// Note: only send push_dense now!
// communicator->Send(ins, scope) can be used to push_sparse or push_dense
fleet->PushDenseVarsAsync(scope, table_id, ins, &status, 0, -1);
// auto fleet = paddle::distributed::FleetWrapper::GetInstance();
// if (is_sparse == 0) {
......
......@@ -76,7 +76,9 @@ void BindDistFleetWrapper(py::module* m) {
.def("stop_server", &FleetWrapper::StopServer)
.def("stop_worker", &FleetWrapper::FinalizeWorker)
.def("barrier", &FleetWrapper::BarrierWithTable)
.def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable);
.def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable)
.def("create_client2client_connection",
&FleetWrapper::CreateClient2ClientConnection);
}
void BindPSHost(py::module* m) {
......@@ -159,8 +161,11 @@ void BindDistCommunicator(py::module* m) {
.def("push_sparse_param", &Communicator::RpcSendSparseParam)
.def("is_running", &Communicator::IsRunning)
.def("init_params", &Communicator::InitParams)
.def("pull_dense", &Communicator::PullDense);
// .def("recv", &Communicator::RecvNoBarrier);
.def("pull_dense", &Communicator::PullDense)
.def("create_client_to_client_connection",
&Communicator::CreateC2CConnection)
.def("get_client_info", &Communicator::GetClientInfo)
.def("set_clients", &Communicator::SetClients);
}
void BindHeterClient(py::module* m) {
......
......@@ -48,6 +48,7 @@ from .fleet import BoxPSDataset # noqa: F401
from .entry_attr import ProbabilityEntry # noqa: F401
from .entry_attr import CountFilterEntry # noqa: F401
from .entry_attr import ShowClickEntry # noqa: F401
from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401
......@@ -69,6 +70,7 @@ __all__ = [ # noqa
"QueueDataset",
"split",
"CountFilterEntry",
"ShowClickEntry",
"get_world_size",
"get_group",
"all_gather",
......
......@@ -137,3 +137,45 @@ class CountFilterEntry(EntryAttr):
def _to_attr(self):
return ":".join([self._name, str(self._count_filter)])
class ShowClickEntry(EntryAttr):
"""
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
sparse_feature_dim = 1024
embedding_size = 64
shows = paddle.static.data(name='show', shape=[1], dtype='int64')
clicks = paddle.static.data(name='click', shape=[1], dtype='int64')
input = paddle.static.data(name='ins', shape=[1], dtype='int64')
entry = paddle.distributed.ShowClickEntry("show", "click")
emb = paddle.static.nn.sparse_embedding(
input=input,
size=[sparse_feature_dim, embedding_size],
is_test=False,
entry=entry,
param_attr=paddle.ParamAttr(name="SparseFeatFactors",
initializer=paddle.nn.initializer.Uniform()))
"""
def __init__(self, show_name, click_name):
super(ShowClickEntry, self).__init__()
if not isinstance(show_name, str) or not isinstance(click_name, str):
raise ValueError("show_name click_name must be a str")
self._name = "show_click_entry"
self._show_name = show_name
self._click_name = click_name
def _to_attr(self):
return ":".join([self._name, self._show_name, self._click_name])
......@@ -392,6 +392,38 @@ class DistributedStrategy(object):
"""
return get_msg_dict(self.strategy.trainer_desc_configs)
@property
def adam_d2sum(self):
"""
Set adam_d2sum. When enabled and the local optimizer is Adam, the parameter server uses the adam_d2sum accessor.
Default value: True
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
role_maker = fleet.PaddleCloudRoleMaker()
fleet.init(role_maker)
strategy = fleet.DistributedStrategy()
strategy.adam_d2sum = True # by default this is True
# code block for defining loss and local optimizer
# sgd = fleet.distributed_optimizer(optimizer, strategy)
"""
return self.strategy.adam_d2sum
@adam_d2sum.setter
@is_strict_auto
def adam_d2sum(self, flag):
if isinstance(flag, bool):
self.strategy.adam_d2sum = flag
else:
raise ValueError(
"The type of `flag` is invalid, expected type is bool, but received {}".
format(type(flag)))
@trainer_desc_configs.setter
@is_strict_auto
def trainer_desc_configs(self, configs):
......@@ -399,6 +431,62 @@ class DistributedStrategy(object):
"trainer_desc_configs")
assign_configs_value(self.strategy.trainer_desc_configs, configs)
@property
def fs_client_param(self):
"""
Set fs client configurations.
**Notes**:
uri(str): the uri of the fs client
user(str): the user name of the fs client
passwd(str): the password of the fs client
hadoop_bin(str): the path of the hadoop binary used by the fs client
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
role_maker = fleet.PaddleCloudRoleMaker()
fleet.init(role_maker)
strategy = fleet.DistributedStrategy()
configs = {"uri": "xxx", "user": "xxx", passwd: "xxx"}
strategy.fs_client_param = configs
# code block for defining loss and local optimizer
# sgd = fleet.distributed_optimizer(optimizer, strategy)
"""
return self.strategy.fs_client_param
@fs_client_param.setter
@is_strict_auto
def fs_client_param(self, configs):
check_configs_key(self.strategy.fs_client_param, configs,
"fs_client_param")
assign_configs_value(self.strategy.fs_client_param, configs)
@property
def sparse_table_configs(self):
return self.strategy.downpour_table_param
@sparse_table_configs.setter
@is_strict_auto
def sparse_table_configs(self, configs):
from google.protobuf.descriptor import FieldDescriptor
table_param = self.strategy.downpour_table_param
def set_table_config(msg, config_name, configs):
for field in msg.DESCRIPTOR.fields:
name = config_name + "." + field.name
if field.type == FieldDescriptor.TYPE_MESSAGE:
print("message:", name)
set_table_config(getattr(msg, field.name), name, configs)
else:
print("not message:", name)
if name not in configs:
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
getattr(msg, field.name).extend(configs[name])
else:
setattr(msg, field.name, configs[name])
set_table_config(table_param, "table_parameters", configs)
@property
def amp(self):
"""
......
......@@ -24,6 +24,7 @@ from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready
import paddle.distributed.fleet as fleet
__all__ = []
......@@ -49,7 +50,7 @@ def parse_table_class(varname, o_main_program):
if op.has_attr('table_class') and op.attr("table_class") != "none":
return op.attr('table_class')
else:
return "CommonSparseTable"
return "MemorySparseTable"
class Accessor:
......@@ -95,6 +96,11 @@ class CommonAccessor:
opt_input_map["adam"] = [("Param", None), ("Moment1", None),
("Moment2", None), ("Beta1Pow", 1),
("Beta2Pow", 1), ("LearningRate", 1)]
opt_input_map["adam_d2sum"] = [
("Param", None), ("D2Sum", None), ("G2Sum", None), ("Moment", None),
("MomentDecayRate", 1), ("AdaDecayRate", 1), ("AdaEpsilon", 1),
("LearningRate", 1)
]
opt_input_map["sum"] = [("Param", None)]
opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
("LearningRate", 1)]
......@@ -105,6 +111,8 @@ class CommonAccessor:
opt_attr_map["naive_adagrad"] = []
opt_attr_map["adam"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"),
("epsilon", "f")]
opt_init_map = {}
opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
......@@ -162,7 +170,7 @@ class CommonAccessor:
return attr_str
def parse_by_optimizer(self, grad_name, is_sparse, total_dims,
compiled_strategy):
compiled_strategy, adam_d2sum):
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_optimize_ops
param_name = compiled_strategy.grad_name_to_param_name[grad_name]
main_program, startup_program = compiled_strategy.get_origin_programs()
......@@ -187,6 +195,10 @@ class CommonAccessor:
self.trainer_num = compiled_strategy.get_trainers()
if oop.type != 'adam' and adam_d2sum == True:
print('optimization algorithm is not adam, setting adam_d2sum to False')
adam_d2sum = False
print("adam_d2sum:", adam_d2sum)
if compiled_strategy.is_geo_mode():
param_varnames = self.opt_input_map["sum"]
attr_varnames = self.opt_attr_map["sum"]
......@@ -195,6 +207,10 @@ class CommonAccessor:
param_varnames = self.opt_input_map["naive_adagrad"]
attr_varnames = self.opt_attr_map["naive_adagrad"]
self.accessor_class = "sgd"
elif adam_d2sum:
param_varnames = self.opt_input_map["adam_d2sum"]
attr_varnames = self.opt_attr_map["adam_d2sum"]
self.accessor_class = "adam_d2sum"
else:
param_varnames = self.opt_input_map[oop.type]
attr_varnames = self.opt_attr_map[oop.type]
......@@ -202,17 +218,8 @@ class CommonAccessor:
for (formal_name, shape) in param_varnames:
params.append(formal_name)
if formal_name == "G2Sum":
dims.append(1)
initializer = "fill_constant&0"
initializers.append(initializer)
else:
param = main_program.global_block().vars[oop.input(formal_name)[
0]]
if formal_name == "LearningRate" and param.name != "learning_rate_0":
warnings.warn("will support decay soon")
param = main_program.global_block().vars["learning_rate_0"]
if self.accessor_class == "adam_d2sum":
#for dims
if shape is None:
if is_sparse:
shape = total_dims
......@@ -221,9 +228,51 @@ class CommonAccessor:
pserver_id)
dims.append(shape)
initializer = self.get_initializer_attr(param.name,
startup_program)
#for initializers
if formal_name == "Param" or formal_name == "LearningRate":
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
#TODO: for dense learning_rate, can be different from sparse lr
if formal_name == "LearningRate" and param.name != "learning_rate_0":
warnings.warn("will support decay soon")
param = main_program.global_block().vars[
"learning_rate_0"]
initializer = self.get_initializer_attr(param.name,
startup_program)
elif formal_name == "MomentDecayRate":
initializer = "fill_constant&0.99"
elif formal_name == "AdaDecayRate":
initializer = "fill_constant&0.9999"
elif formal_name == "AdaEpsilon":
initializer = "fill_constant&1.0e-8"
else:
initializer = "fill_constant&0"
initializers.append(initializer)
else:
if formal_name == "G2Sum":
dims.append(1)
initializer = "fill_constant&0"
initializers.append(initializer)
else:
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
if formal_name == "LearningRate" and param.name != "learning_rate_0":
warnings.warn("will support decay soon")
param = main_program.global_block().vars[
"learning_rate_0"]
if shape is None:
if is_sparse:
shape = total_dims
else:
shape = self.get_shard(total_dims, pserver_num,
pserver_id)
dims.append(shape)
initializer = self.get_initializer_attr(param.name,
startup_program)
initializers.append(initializer)
for (attr_varname, type_) in attr_varnames:
value = oop.attr(attr_varname)
......@@ -292,6 +341,7 @@ class Table:
self.accessor = None
self.common = None
self.tensor = None
self.accessor_proto = None
def to_string(self, indent):
table_str = "{}downpour_table_param {{{}\n{}}}"
......@@ -304,6 +354,14 @@ class Table:
attrs += "\n"
indent += 2
if self.accessor_proto is not None:
accessor_str = "{}accessor {{{}\n{}}}"
accessor_str = accessor_str.format(
conv_indent(indent), self.accessor_proto, conv_indent(indent))
attrs += accessor_str + "\n"
return table_str.format(
conv_indent(indent), attrs, conv_indent(indent))
if self.accessor is not None:
attrs += self.accessor.to_string(indent)
attrs += "\n"
......@@ -431,6 +489,24 @@ class Worker:
return worker_str.format(workers_str)
class fsClient:
def __init__(self, proto):
self.proto = proto
self.uri = proto.uri
self.user = proto.user
self.passwd = proto.passwd
self.hadoop_bin = proto.hadoop_bin
def to_string(self):
from google.protobuf import text_format
proto_txt = text_format.MessageToString(self.proto)
if proto_txt:
fs_str = "fs_client_param {{\n{}}}"
return fs_str.format(proto_txt)
else:
return ""
class TheOnePSRuntime(RuntimeBase):
def __init__(self):
super(TheOnePSRuntime, self).__init__()
......@@ -533,7 +609,6 @@ class TheOnePSRuntime(RuntimeBase):
trainer_config = self.async_strategy.get_trainer_runtime_config()
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("worker: \n{}".format(proto_txt))
print("communicator send_ctx:")
......@@ -565,6 +640,25 @@ class TheOnePSRuntime(RuntimeBase):
self._communicator.init_with_ctx(send_ctx, dense_map, proto_txt,
string_hosts, fluid.global_scope())
import paddle.distributed.fleet as fleet
fleet.util.barrier()
info = self._communicator.get_client_info()
if isinstance(info, list) and len(info) > 0:
all_info = self.role_maker._all_gather(info[0])
# for unittest
if not isinstance(all_info, list):
warnings.warn("gloo may not initialize correctly")
all_info = [all_info]
self._communicator.set_clients(all_info)
# create_c2c_connection default param:
# pserver_timeout_ms=500000
# pserver_connect_timeout_ms=10000
# max_retry=3
self._communicator.create_client_to_client_connection()
print('create c2c connection done')
else:
print('cannot create c2c connection')
dist_strategy = self.context["valid_strategy"]
is_test = bool(int(os.getenv("TEST_MODE", "0")))
......@@ -577,7 +671,6 @@ class TheOnePSRuntime(RuntimeBase):
else:
init_params = dense_map
import paddle.distributed.fleet as fleet
if not is_test:
self._communicator.init_params(init_params)
fleet.util.barrier()
......@@ -632,7 +725,7 @@ class TheOnePSRuntime(RuntimeBase):
int(os.getenv("FLAGS_selected_xpus", "0"))))
return executor
def _get_fleet_proto(self, is_server, is_sync):
def _get_fleet_proto(self, is_server, is_sync, **kwargs):
def _build_merge_accessor(ctx):
accessor = Accessor()
accessor.accessor_class = "CommMergeAccessor"
......@@ -736,6 +829,7 @@ class TheOnePSRuntime(RuntimeBase):
tables = []
for idx, (name, ctx) in enumerate(send_ctx.items()):
print(" wxm python test send_ctx.items-->", idx, (name, ctx))
if ctx.is_tensor_table() or len(ctx.origin_varnames()) < 1:
continue
......@@ -755,18 +849,64 @@ class TheOnePSRuntime(RuntimeBase):
else:
table.table_class = parse_table_class(
common.table_name, self.origin_main_program)
table_proto = self.context[
"user_defined_strategy"].sparse_table_configs
table.shard_num = table_proto.shard_num
from google.protobuf import text_format
table.accessor_proto = text_format.MessageToString(
table_proto.accessor)
print('table proto:', table_proto)
if table.table_class == 'MemorySparseTable' and table.accessor_proto == '':
emb_dim = ctx.sections()[1]
table.shard_num = 1950
table.accessor_proto = 'accessor_class: "CtrCommonAccessor"\n' \
'embed_sgd_param {\n' \
' name: "SparseAdaGradSGDRule"\n' \
' adagrad {\n' \
' learning_rate: 0.05\n' \
' initial_g2sum: 3.0\n' \
' initial_range: 0.0001\n' \
' weight_bounds: -10.0\n' \
' weight_bounds: 10.0\n' \
' }\n' \
'}\n' \
'embedx_sgd_param {\n' \
' name: "SparseAdaGradSGDRule"\n' \
' adagrad {\n' \
' learning_rate: 0.05\n' \
' initial_g2sum: 3.0\n' \
' initial_range: 0.0001\n' \
' weight_bounds: -10.0\n' \
' weight_bounds: 10.0\n' \
' }\n' \
'}\n' \
'fea_dim: ' + str(emb_dim+2) + '\n' \
'embedx_dim: ' + str(emb_dim-1) + '\n' \
'embedx_threshold: 10\n' \
'ctr_accessor_param {\n' \
' nonclk_coeff: 0.1\n' \
' click_coeff: 1.0\n' \
' base_threshold: 1.5\n' \
' delta_threshold: 0.25\n' \
' delta_keep_days: 16.0\n' \
' show_click_decay_rate: 0.98\n' \
' delete_threshold: 0.8\n' \
' delete_after_unseen_days: 30.0\n' \
' ssd_unseenday_threshold: 1\n' \
'}'
else:
table.type = "PS_DENSE_TABLE"
table.table_class = "CommonDenseTable"
table.shard_num = 256
common.table_name = "MergedDense"
adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
common.parse_by_optimizer(ctx.origin_varnames()[0],
ctx.is_sparse(),
ctx.sections()[1] if ctx.is_sparse()
else ctx.sections()[0],
self.compiled_strategy)
self.compiled_strategy, adam_d2sum)
if ctx.is_sparse():
common.parse_entry(common.table_name,
......@@ -779,8 +919,9 @@ class TheOnePSRuntime(RuntimeBase):
table.common = common
accessor = _build_merge_accessor(ctx)
table.accessor = accessor
if table.table_class != 'MemorySparseTable':
accessor = _build_merge_accessor(ctx)
table.accessor = accessor
tables.append(table)
tensor_table_dict = self.compiled_strategy.get_tensor_table_dict()
......@@ -828,6 +969,9 @@ class TheOnePSRuntime(RuntimeBase):
trainers += len(self.role_maker._get_heter_worker_endpoints())
server = self._get_fleet_proto(is_server=True, is_sync=is_sync)
proto_txt = str(server)
fs_client = fsClient(self.context["user_defined_strategy"]
.fs_client_param)
proto_txt = proto_txt + "\n" + fs_client.to_string()
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
......@@ -924,10 +1068,14 @@ class TheOnePSRuntime(RuntimeBase):
for id, names in context.items():
if names[0] not in distributed_varnames:
# only save sparse param to local
self._worker.recv_and_save_model(id, dirname)
try:
self._worker.recv_and_save_model(id, dirname)
except:
pass
# save sparse & distributed param on server
self._worker.save_one_model(id, dirname, mode)
values.extend(names)
# self._worker.save_all_model(dirname, mode)
return values
def _save_distributed_persistables(self,
......@@ -951,6 +1099,7 @@ class TheOnePSRuntime(RuntimeBase):
recv_dense_varnames = []
for id, names in denses.items():
recv_dense_varnames.extend(names)
self._communicator.pull_dense(denses)
saved_varnames = sparse_varnames
......@@ -1004,8 +1153,9 @@ class TheOnePSRuntime(RuntimeBase):
)
# Todo(MrChengmo): Save optimizer status
self._save_distributed_persistables(executor, dirname, main_program,
mode)
# self._save_distributed_persistables(executor, dirname, main_program,
# mode)
self._worker.save_all_model(dirname, mode)
def _ps_inference_save_inference_model(self,
executor,
......@@ -1046,10 +1196,45 @@ class TheOnePSRuntime(RuntimeBase):
infer_program._copy_dist_param_info_from(program)
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
model_basename = "__model__"
model_basename = os.path.join(dirname, model_basename)
model_basename = os.path.join(model_path, model_basename)
paddle.save(infer_program, model_basename)
sparses = self.compiled_strategy.get_the_one_recv_context(
is_dense=False,
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)
print("the one ps sparses:", sparses)
sparse_names = []
for id, name in sparses.items():
sparse_names.extend(name)
print("the one ps sparse names:", sparse_names)
denses = self.compiled_strategy.get_the_one_recv_context(
is_dense=True,
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)
self._communicator.pull_dense(denses)
generate_vars = self.context[
"user_defined_strategy"].trainer_desc_configs["stat_var_names"]
generate_vars = [var for var in generate_vars]
remaining_vars = list(
filter(
TheOnePSRuntime.__exclude_vars(generate_vars + sparse_names),
infer_program.list_vars()))
print("remain_vars:", [var.name for var in remaining_vars])
for var in remaining_vars:
tensor = var.get_value()
paddle.save(
tensor,
os.path.join(model_path, var.name),
use_binary_format=True)
self._ps_inference_save_persistables(executor, dirname, infer_program,
mode)
......@@ -1073,8 +1258,10 @@ class TheOnePSRuntime(RuntimeBase):
values.extend(names)
return values
def _load_distributed_persistables(self, dirname, main_program=None,
mode=0):
def _ps_inference_load_inference_model(self,
dirname,
mode=0,
main_program=None):
if main_program is None:
main_program = self.compiled_strategy.get_origin_ps_main_program()
......@@ -1106,17 +1293,27 @@ class TheOnePSRuntime(RuntimeBase):
TheOnePSRuntime.__exclude_vars(loaded_varnames),
main_program.list_vars()))
if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
model_path = "./dnn_plugin"
else:
model_path = os.path.join(dirname, "dnn_plugin")
import paddle
for var in remaining_vars:
if var.name not in recv_dense_varnames:
continue
tensor = paddle.load(os.path.join(dirname, var.name))
tensor = paddle.load(os.path.join(model_path, var.name))
var.set_value(tensor)
self._communicator.init_params(denses)
def _load_distributed_persistables(self, path, mode):
self._worker.load_model(path, mode)
def load_model(self, path, mode):
self._load_distributed_persistables(path, mode=mode)
if mode == 0 or mode == 3:
self._load_distributed_persistables(path, mode)
else:
self._ps_inference_load_inference_model(path, mode)
def _shrink(self, threshold):
import paddle.distributed.fleet as fleet
......
......@@ -99,6 +99,19 @@ class Communicator(object):
self.send_ctx_ = send_ctx
self.recv_ctx_ = recv_ctx
def create_client_to_client_connection(self,
pserver_timeout_ms=500000,
pserver_connect_timeout_ms=10000,
max_retry=3):
self.communicator_.create_client_to_client_connection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry)
def get_client_info(self):
return self.communicator_.get_client_info()
def set_clients(self, host_list):
self.communicator_.set_clients(host_list)
def start(self):
"""
Start communicator. Should call before training process.
......
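For context, the three new Communicator methods above are wired together during worker initialization in TheOnePSRuntime (see the runtime changes earlier in this diff); a hedged sketch of that handshake, assuming communicator and role_maker are already initialized:

# Gather every trainer's client endpoint, broadcast the full list,
# then open direct client-to-client connections.
info = communicator.get_client_info()
all_info = role_maker._all_gather(info[0])
communicator.set_clients(all_info)
# defaults: pserver_timeout_ms=500000, pserver_connect_timeout_ms=10000, max_retry=3
communicator.create_client_to_client_connection()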
......@@ -968,7 +968,7 @@ def sparse_embedding(input,
padding_idx=None,
is_test=False,
entry=None,
table_class="CommonSparseTable",
table_class="MemorySparseTable",
param_attr=None,
dtype='float32'):
r"""
......@@ -1100,18 +1100,21 @@ def sparse_embedding(input,
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx)
if table_class not in ["CommonSparseTable", "SSDSparseTable"]:
if table_class not in [
"CommonSparseTable", "SSDSparseTable", "MemorySparseTable"
]:
raise ValueError(
"table_class must be in [CommonSparseTable, SSDSparseTable]")
"table_class must be in [CommonSparseTable, SSDSparseTable, MemorySparseTable]"
)
entry_str = "none"
if entry is not None:
if entry.__class__.__name__ not in [
"ProbabilityEntry", "CountFilterEntry"
"ProbabilityEntry", "CountFilterEntry", "ShowClickEntry"
]:
raise ValueError(
"entry must be instance in [paddle.distributed.ProbabilityEntry, paddle.distributed.CountFilterEntry]"
"entry must be instance in [paddle.distributed.ProbabilityEntry, paddle.distributed.CountFilterEntry, paddle.distributed.ShowClickEntry]"
)
entry_str = entry._to_attr()
......
......@@ -43,6 +43,10 @@ OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
"lookup_table_grad": "W",
"lookup_table_v2_grad": "W"
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
......@@ -98,9 +102,14 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
trainer_id = config.get_role_id()
send_ctx = config.get_the_one_send_context(
split_dense_table=config.is_heter_ps_mode)
w_2_table_id = {}
emb_size = {}
def _get_pull_sparse_ops(_program):
pull_sparse_ops = {}
pull_sparse_ids = {}
push_sparse_ops = {}
ops = {}
for op in _program.global_block().ops:
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
......@@ -111,7 +120,18 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
ops = pull_sparse_ops.get(param_name, [])
ops.append(op)
pull_sparse_ops[param_name] = ops
return pull_sparse_ops
ids = pull_sparse_ids.get(param_name, [])
ids.append(op.input("Ids")[0])
pull_sparse_ids[param_name] = ids
for op in _program.global_block().ops:
if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys():
param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0]
if param_name in pull_sparse_ids and op.input("Ids")[
0] in pull_sparse_ids[param_name]:
ops = push_sparse_ops.get(param_name, [])
ops.append(op)
push_sparse_ops[param_name] = ops
return pull_sparse_ops, push_sparse_ops
def _pull_sparse_fuse(_program, pull_sparse_ops, use_ps_gpu):
def dag_check_up_and_reorder(program, inputs, outputs):
......@@ -218,6 +238,7 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].input("W")[0]]
emb_size[param] = w.shape[1]
grad_name = config.param_name_to_grad_name[w.name]
......@@ -231,6 +252,7 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
raise ValueError(
"can not find suitable sparse table, please check")
w_2_table_id[param] = table_id
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
......@@ -263,7 +285,6 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
outputs_idxs[out_id])
if min(outputs_idxs) - max(inputs_idxs) >= 1:
if max(inputs_idxs) == -1:
distributed_idx = min(op_idxs)
else:
......@@ -313,8 +334,123 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
"op_device": op_device
})
pull_sparse_ops = _get_pull_sparse_ops(program)
def _push_sparse_fuse(_program, push_sparse_ops, use_ps_gpu):
if use_ps_gpu:
# in ps_gpu_pass
return
if len(push_sparse_ops) == 0:
return
show = None
clk = None
use_entry = False
for param, ops in push_sparse_ops.items():
op_first = ops[0]
break
print(op_first)
if op_first.has_attr("entry"):
entry = op_first.attr("entry")
entry = entry.split(':')
if len(entry) == 3 and entry[0] == 'show_click_entry':
show_var_name = entry[1]
click_var_name = entry[2]
if show_var_name in program.global_block(
).vars and click_var_name in program.global_block().vars:
show = program.global_block().vars[show_var_name]
clk = program.global_block().vars[click_var_name]
use_entry = True
else:
warnings.warn(
'ShowClickEntry configured, but cannot find show/click var, will not use'
)
if not use_entry:
print('ShowClickEntry not configured, will not use')
show = program.global_block().create_var(
name="show",
dtype=core.VarDesc.VarType.INT64,
persistable=False,
stop_gradient=True)
program.global_block()._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': show},
attrs={
'shape': [1],
'dtype': show.dtype,
'value': 1,
#OP_ROLE_KEY: OpRole.Forward
})
clk = program.global_block().create_var(
name="clk",
dtype=core.VarDesc.VarType.INT64,
persistable=False,
stop_gradient=True)
program.global_block()._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': clk},
attrs={
'shape': [1],
'dtype': clk.dtype,
'value': 0,
#OP_ROLE_KEY: OpRole.Forward
})
for param, ops in push_sparse_ops.items():
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].output("W@GRAD")[0]]
table_id = w_2_table_id[param]
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
program.global_block().vars[op.input("Out@GRAD")[0]]
for op in ops
]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
# if use_ps_gpu:
# program.global_block().append_op(
# type="push_box_sparse",
# inputs={"Ids": inputs,
# 'Out': outputs},
# outputs={"Out": outputs},
# attrs={
# "size": w.shape[1],
# "is_distributed": True,
# "is_sparse": True
# })
# else:
program.global_block().append_op(
type="distributed_push_sparse",
inputs={
"Ids": inputs,
'W': w,
"Outputs": outputs,
"Shows": show,
"Clicks": clk
},
outputs={"Outputs": outputs},
attrs={
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"size": emb_size[param]
})
pull_sparse_ops, push_sparse_ops = _get_pull_sparse_ops(program)
_pull_sparse_fuse(program, pull_sparse_ops, use_ps_gpu)
_push_sparse_fuse(program, push_sparse_ops, use_ps_gpu)
return program
......@@ -367,6 +503,8 @@ def append_send_ops_pass(program, config):
split_dense_table=config.is_heter_ps_mode)
for merged_name, send in sends.items():
if send.is_sparse():
continue
is_sparse = 1 if send.is_sparse() else 0
is_sparse = 2 if send.is_distributed() else is_sparse
dummys.append(
......
......@@ -156,6 +156,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
return avg_cost
def check_model_right(self, dirname):
dirname = dirname + '/dnn_plugin/'
model_filename = os.path.join(dirname, "__model__")
with open(model_filename, "rb") as f:
......
......@@ -98,11 +98,13 @@ class TestDistCTR2x2(FleetDistRunnerBase):
else:
raise ValueError("error initializer code: {}".format(initializer))
entry = paddle.distributed.ShowClickEntry("show", "click")
dnn_layer_dims = [128, 64, 32]
dnn_embedding = fluid.contrib.layers.sparse_embedding(
input=dnn_data,
size=[dnn_input_dim, dnn_layer_dims[0]],
is_test=inference,
entry=entry,
param_attr=fluid.ParamAttr(
name="deep_embedding", initializer=init))
dnn_pool = fluid.layers.sequence_pool(
......@@ -123,6 +125,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
input=lr_data,
size=[lr_input_dim, 1],
is_test=inference,
entry=entry,
param_attr=fluid.ParamAttr(
name="wide_embedding",
initializer=fluid.initializer.Constant(value=0.01)))
......
......@@ -19,7 +19,7 @@ paddle.enable_static()
import unittest
import paddle.fluid as fluid
from paddle.distributed import ProbabilityEntry, CountFilterEntry
from paddle.distributed import ProbabilityEntry, CountFilterEntry, ShowClickEntry
class EntryAttrChecks(unittest.TestCase):
......@@ -51,6 +51,11 @@ class EntryAttrChecks(unittest.TestCase):
with self.assertRaises(ValueError):
counter2 = CountFilterEntry(-1)
def showclick_entry(self):
showclick = ShowClickEntry("show", "click")
ss = showclick._to_attr()
self.assertEqual("show_click_entry:show:click", ss)
def spaese_layer(self):
prog = fluid.Program()
scope = fluid.core.Scope()
......@@ -97,6 +102,9 @@ class TestEntryAttrs(EntryAttrChecks):
def test_counter(self):
self.countfilter_entry()
def test_showclick(self):
self.showclick_entry()
def test_spaese_embedding_layer(self):
self.spaese_layer()
......
......@@ -76,6 +76,7 @@ class TestFleetBase(unittest.TestCase):
fleet.fleet.save(dirname="/tmp")
fleet.load_model(path="/tmp", mode=0)
fleet.load_model(path="/tmp", mode=1)
self.assertRaises(
Exception,
......@@ -94,6 +95,15 @@ class TestFleetBase(unittest.TestCase):
executor=exe,
main_program=compiled_prog)
self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='afs:/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor=exe,
main_program=compiled_prog)
self.assertRaises(
Exception, fleet.save_persistables, executor=pe, dirname='/tmp/')
......
......@@ -255,6 +255,25 @@ class TestStrategyConfig(unittest.TestCase):
strategy.a_sync_configs = configs
self.assertEqual(strategy.a_sync_configs["k_steps"], 1000)
def test_sparse_table_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {
"table_parameters.accessor.embed_sgd_param.adagrad.learning_rate":
0.05
}
strategy.sparse_table_configs = configs
self.assertEqual(strategy.sparse_table_configs.accessor.embed_sgd_param.
adagrad.learning_rate, 0.05)
strategy.adam_d2sum = True
self.assertEqual(strategy.adam_d2sum, True)
strategy.fs_client_param = {
"uri": "123",
"user": "456",
"passwd": "789",
"hadoop_bin": "hadoop"
}
self.assertEqual(strategy.fs_client_param.user, "456")
def test_trainer_desc_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {
......