From 74605fc2795d6a366de75513d388b225e2a9eaba Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Fri, 26 Nov 2021 18:56:31 +0800 Subject: [PATCH] upgrade async distributed training in pscore (#37515) * test * test * rm test * update * update * update * add unittest * update * update save --- paddle/fluid/distributed/fleet.cc | 4 - .../pscore/distributed_push_sparse_op.cc | 132 +++++++++ .../pscore/distributed_push_sparse_op.cu.cc | 23 ++ .../pscore/distributed_push_sparse_op.h | 104 +++++++ paddle/fluid/operators/pscore/send_op.cc | 12 +- paddle/fluid/pybind/fleet_py.cc | 11 +- python/paddle/distributed/__init__.py | 2 + python/paddle/distributed/entry_attr.py | 42 +++ .../fleet/base/distributed_strategy.py | 88 ++++++ .../distributed/fleet/runtime/the_one_ps.py | 257 ++++++++++++++++-- python/paddle/fluid/communicator.py | 13 + python/paddle/fluid/contrib/layers/nn.py | 13 +- .../fleet/parameter_server/ir/trainer_pass.py | 144 +++++++++- .../fluid/tests/unittests/dist_fleet_ctr.py | 1 + .../dist_fleet_sparse_embedding_ctr.py | 3 + .../fluid/tests/unittests/test_entry_attr.py | 10 +- .../tests/unittests/test_fleet_base_2.py | 10 + .../test_fleet_distributed_strategy.py | 19 ++ 18 files changed, 837 insertions(+), 51 deletions(-) create mode 100644 paddle/fluid/operators/pscore/distributed_push_sparse_op.cc create mode 100644 paddle/fluid/operators/pscore/distributed_push_sparse_op.cu.cc create mode 100644 paddle/fluid/operators/pscore/distributed_push_sparse_op.h diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index 871e503ca4..ba614179b3 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -570,8 +570,6 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; - sleep(sleep_seconds_before_fail_exit_); - exit(-1); } } @@ -596,8 +594,6 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { int32_t feasign_cnt = ret.get(); if (feasign_cnt == -1) { LOG(ERROR) << "save model failed"; - sleep(sleep_seconds_before_fail_exit_); - exit(-1); } } diff --git a/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc new file mode 100644 index 0000000000..3a1e2ea786 --- /dev/null +++ b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/pscore/distributed_push_sparse_op.h" + +namespace paddle { +namespace operators { + +constexpr int64_t kNoPadding = -1; + +class DistributedPushSparseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInputs("Ids"), true, + platform::errors::InvalidArgument( + "Input(Ids) of PushSparseOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true, + platform::errors::InvalidArgument( + "Output(Outs) of PushSparseOp should not be null.")); + + auto ids_dims = ctx->GetInputsDim("Ids"); + + for (auto &ids_dim : ids_dims) { + PADDLE_ENFORCE_EQ(ids_dim.size(), 2, + platform::errors::InvalidArgument( + "The dimension of the 'Ids' tensor must be 2.")); + } + + // for fluid.embedding + auto push_sparse_version = + ctx->Attrs().Get("push_sparse_version"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } +}; + +class DistributedPushSparseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Ids", + "(LoDTensor) Ids's type should be LoDTensor" + "THe ids to be looked up in W.") + .AsDuplicable(); + + AddInput("Shows", + "(LoDTensor) Shows's type should be LoDTensor" + "THe shows default to be 1.") + .AsDuplicable(); + + AddInput("Clicks", + "(LoDTensor) Clicks's type should be LoDTensor" + "THe clicks usually equal to label.") + .AsDuplicable(); + + AddOutput("Outputs", + "(LoDTensor) The lookup results, which have the same type as W.") + .AsDuplicable(); + + AddAttr("table_id", "sparse table id").SetDefault(0); + AddAttr("size", "embedding size").SetDefault(8); + + AddAttr("is_distributed", + "(boolean, default false) distributed lookup table.") + .SetDefault(false); + + AddAttr( + "push_sparse_version", + "(string, default push_sparse) " + "To distinguish between different versions of embedding OP") + .SetDefault(std::string("push_sparse")); + + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(kNoPadding); + AddAttr("dtype", + "(int, default 5 (FP32)) " + "Output data type") + .SetDefault(framework::proto::VarType::FP32); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training.") + .SetDefault(false); + + AddComment(R"DOC( +Lookup Tablel Prefetch Operator. +This operator is used to perform lookup on parameter W, +then concatenated into a sparse tensor. +The type of Ids(Input) is SelectedRows, the rows of Ids contains +the ids to be looked up in W; +if the Id is not in the sparse table, this operator will return a +random value and set the value into the table for the next looking up. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(distributed_push_sparse, ops::DistributedPushSparseOp, + ops::DistributedPushSparseOpMaker); + +REGISTER_OP_CPU_KERNEL( + distributed_push_sparse, + ops::DistributedPushSparseKernel, + ops::DistributedPushSparseKernel); diff --git a/paddle/fluid/operators/pscore/distributed_push_sparse_op.cu.cc b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cu.cc new file mode 100644 index 0000000000..5c4ae3bdcf --- /dev/null +++ b/paddle/fluid/operators/pscore/distributed_push_sparse_op.cu.cc @@ -0,0 +1,23 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/pscore/distributed_push_sparse_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + distributed_push_sparse, + ops::DistributedPushSparseKernel, + ops::DistributedPushSparseKernel); diff --git a/paddle/fluid/operators/pscore/distributed_push_sparse_op.h b/paddle/fluid/operators/pscore/distributed_push_sparse_op.h new file mode 100644 index 0000000000..1e27411ad6 --- /dev/null +++ b/paddle/fluid/operators/pscore/distributed_push_sparse_op.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/distributed/fleet.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class DistributedPushSparseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto &scope = context.scope(); + + auto padding_idx = context.Attr("padding_idx"); + auto table_id = context.Attr("table_id"); + auto emb_dim = context.Attr("size"); + VLOG(1) << "push_sparse.h::emb_dim: " << emb_dim; + bool is_test = context.Attr("is_test"); + + auto inputs = context.MultiInput("Ids"); + auto shows = context.Input("Shows"); + auto clks = context.Input("Clicks"); + auto outputs = context.MultiOutput("Outputs"); + + auto fleet = distributed::FleetWrapper::GetInstance(); + + if (platform::is_cpu_place(context.GetPlace())) { + fleet->PushSparseFromTensorAsync(static_cast(table_id), emb_dim, + static_cast(padding_idx), + context.GetPlace(), &inputs, shows, clks, + &outputs); + } else { + auto inputs_variable = context.MultiInputVar("Ids"); + auto outputs_variable = context.MultiOutputVar("Outputs"); + auto inputs_name = context.InputNames("Ids"); + auto outputs_name = context.OutputNames("Outputs"); + + auto cpu_place = platform::CPUPlace(); + framework::Scope *tmp_scope = scope.NewTmpScope().release(); + + std::vector tmp_input_vec; + auto input_var_size = inputs_variable.size(); + std::vector tmp_output_vec; + auto output_var_size = outputs_variable.size(); + + // create temp input + for (size_t idx = 0; idx < input_var_size; ++idx) { + framework::Variable *tmp_input_var = tmp_scope->Var(inputs_name[idx]); + framework::LoDTensor *tmp_input_tensor = + tmp_input_var->GetMutable(); + framework::TensorCopy(inputs_variable[idx]->Get(), + cpu_place, context.device_context(), + tmp_input_tensor); + tmp_input_vec.push_back(tmp_input_tensor); + } + + // create temp output + for (size_t idx = 0; idx < output_var_size; ++idx) { + framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]); + framework::LoDTensor *tmp_output_tensor = + tmp_output_var->GetMutable(); + tmp_output_tensor->Resize(outputs[idx]->dims()); + tmp_output_vec.push_back(tmp_output_tensor); + } + + // use fleet->PullSparse + fleet->PullSparseToTensorSync(static_cast(table_id), emb_dim, + static_cast(padding_idx), + cpu_place, !is_test, &tmp_input_vec, + &tmp_output_vec); + + // cp temp to origin + for (size_t idx = 0; idx < output_var_size; ++idx) { + framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]); + framework::LoDTensor *tmp_output_tensor = + tmp_output_var->GetMutable(); + framework::TensorCopy( + *tmp_output_tensor, context.GetPlace(), context.device_context(), + outputs_variable[idx]->GetMutable()); + } + delete tmp_scope; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/pscore/send_op.cc b/paddle/fluid/operators/pscore/send_op.cc index cdb445252b..a496d0d5a0 100644 --- a/paddle/fluid/operators/pscore/send_op.cc +++ b/paddle/fluid/operators/pscore/send_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/distributed/fleet.h" #include "paddle/fluid/distributed/service/communicator.h" #include "paddle/fluid/framework/op_registry.h" @@ -42,14 +43,15 @@ class SendOp : public framework::OperatorBase { const platform::Place& place) const override { auto ins = Inputs("X"); // auto is_sparse = Attr("is_sparse"); - // auto table_id = Attr("table_id"); + auto table_id = Attr("table_id"); auto send_varnames = Attr>("send_varnames"); - auto* communicator = paddle::distributed::Communicator::GetInstance(); - if (communicator->Check(send_varnames)) { - communicator->Send(ins, scope); - } + auto fleet = paddle::distributed::FleetWrapper::GetInstance(); + std::vector<::std::future> status; + // Note: only send push_dense now! + // communicator->Send(ins, scope) can be used to push_sparse or push_dense + fleet->PushDenseVarsAsync(scope, table_id, ins, &status, 0, -1); // auto fleet = paddle::distributed::FleetWrapper::GetInstance(); // if (is_sparse == 0) { diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index f81bbd69a0..aeb4f533f4 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -76,7 +76,9 @@ void BindDistFleetWrapper(py::module* m) { .def("stop_server", &FleetWrapper::StopServer) .def("stop_worker", &FleetWrapper::FinalizeWorker) .def("barrier", &FleetWrapper::BarrierWithTable) - .def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable); + .def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable) + .def("create_client2client_connection", + &FleetWrapper::CreateClient2ClientConnection); } void BindPSHost(py::module* m) { @@ -159,8 +161,11 @@ void BindDistCommunicator(py::module* m) { .def("push_sparse_param", &Communicator::RpcSendSparseParam) .def("is_running", &Communicator::IsRunning) .def("init_params", &Communicator::InitParams) - .def("pull_dense", &Communicator::PullDense); - // .def("recv", &Communicator::RecvNoBarrier); + .def("pull_dense", &Communicator::PullDense) + .def("create_client_to_client_connection", + &Communicator::CreateC2CConnection) + .def("get_client_info", &Communicator::GetClientInfo) + .def("set_clients", &Communicator::SetClients); } void BindHeterClient(py::module* m) { diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 600327e4a5..fc299bc7b5 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -48,6 +48,7 @@ from .fleet import BoxPSDataset # noqa: F401 from .entry_attr import ProbabilityEntry # noqa: F401 from .entry_attr import CountFilterEntry # noqa: F401 +from .entry_attr import ShowClickEntry # noqa: F401 from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401 @@ -69,6 +70,7 @@ __all__ = [ # noqa "QueueDataset", "split", "CountFilterEntry", + "ShowClickEntry", "get_world_size", "get_group", "all_gather", diff --git a/python/paddle/distributed/entry_attr.py b/python/paddle/distributed/entry_attr.py index d74a46f530..1b3e40ec34 100644 --- a/python/paddle/distributed/entry_attr.py +++ b/python/paddle/distributed/entry_attr.py @@ -137,3 +137,45 @@ class CountFilterEntry(EntryAttr): def _to_attr(self): return ":".join([self._name, str(self._count_filter)]) + + +class ShowClickEntry(EntryAttr): + """ + Examples: + .. code-block:: python + + import paddle + paddle.enable_static() + + sparse_feature_dim = 1024 + embedding_size = 64 + + shows = paddle.static.data(name='show', shape=[1], dtype='int64') + clicks = paddle.static.data(name='click', shape=[1], dtype='int64') + input = paddle.static.data(name='ins', shape=[1], dtype='int64') + + entry = paddle.distributed.ShowClickEntry("show", "click") + + emb = paddle.static.nn.sparse_embedding( + input=input, + size=[sparse_feature_dim, embedding_size], + is_test=False, + entry=entry, + param_attr=paddle.ParamAttr(name="SparseFeatFactors", + initializer=paddle.nn.initializer.Uniform())) + + + """ + + def __init__(self, show_name, click_name): + super(ShowClickEntry, self).__init__() + + if not isinstance(show_name, str) or not isinstance(click_name, str): + raise ValueError("show_name click_name must be a str") + + self._name = "show_click_entry" + self._show_name = show_name + self._click_name = click_name + + def _to_attr(self): + return ":".join([self._name, self._show_name, self._click_name]) diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 975c7b3f74..cdbc7bd0cd 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -392,6 +392,38 @@ class DistributedStrategy(object): """ return get_msg_dict(self.strategy.trainer_desc_configs) + @property + def adam_d2sum(self): + """ + set adam_d2sum + Default value: True + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + role_maker = fleet.PaddleCloudRoleMaker() + fleet.init(role_maker) + + strategy = fleet.DistributedStrategy() + strategy.adam_d2sum = True # by default this is True + + # code block for defining loss and local optimizer + # sgd = fleet.distributed_optimizer(optimizer, strategy) + """ + return self.strategy.adam_d2sum + + @adam_d2sum.setter + @is_strict_auto + def adam_d2sum(self, flag): + if isinstance(flag, bool): + self.strategy.adam_d2sum = flag + else: + raise ValueError( + "The type of `flag` is invalid, expected type is bool, but received {}". + format(type(flag))) + @trainer_desc_configs.setter @is_strict_auto def trainer_desc_configs(self, configs): @@ -399,6 +431,62 @@ class DistributedStrategy(object): "trainer_desc_configs") assign_configs_value(self.strategy.trainer_desc_configs, configs) + @property + def fs_client_param(self): + """ + Set fs client configurations. + **Notes**: + uri(str): the uri of fs client + user(str): the user_name of fs client + passwd(str): the passwd of fs client + hadoop_bin(str): + Examples: + .. code-block:: python + import paddle.distributed.fleet as fleet + role_maker = fleet.PaddleCloudRoleMaker() + fleet.init(role_maker) + strategy = fleet.DistributedStrategy() + configs = {"uri": "xxx", "user": "xxx", passwd: "xxx"} + strategy.fs_client_param = configs + # code block for defining loss and local optimizer + # sgd = fleet.distributed_optimizer(optimizer, strategy) + """ + return self.strategy.fs_client_param + + @fs_client_param.setter + @is_strict_auto + def fs_client_param(self, configs): + check_configs_key(self.strategy.fs_client_param, configs, + "fs_client_param") + assign_configs_value(self.strategy.fs_client_param, configs) + + @property + def sparse_table_configs(self): + return self.strategy.downpour_table_param + + @sparse_table_configs.setter + @is_strict_auto + def sparse_table_configs(self, configs): + from google.protobuf.descriptor import FieldDescriptor + table_param = self.strategy.downpour_table_param + + def set_table_config(msg, config_name, configs): + for field in msg.DESCRIPTOR.fields: + name = config_name + "." + field.name + if field.type == FieldDescriptor.TYPE_MESSAGE: + print("message:", name) + set_table_config(getattr(msg, field.name), name, configs) + else: + print("not message:", name) + if name not in configs: + continue + if field.label == FieldDescriptor.LABEL_REPEATED: + getattr(msg, field.name).extend(configs[name]) + else: + setattr(msg, field.name, configs[name]) + + set_table_config(table_param, "table_parameters", configs) + @property def amp(self): """ diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index dc555b5ae2..81613cc1ef 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -24,6 +24,7 @@ from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.framework import Variable, Parameter from .runtime_base import RuntimeBase from ..base.private_helper_function import wait_server_ready +import paddle.distributed.fleet as fleet __all__ = [] @@ -49,7 +50,7 @@ def parse_table_class(varname, o_main_program): if op.has_attr('table_class') and op.attr("table_class") != "none": return op.attr('table_class') else: - return "CommonSparseTable" + return "MemorySparseTable" class Accessor: @@ -95,6 +96,11 @@ class CommonAccessor: opt_input_map["adam"] = [("Param", None), ("Moment1", None), ("Moment2", None), ("Beta1Pow", 1), ("Beta2Pow", 1), ("LearningRate", 1)] + opt_input_map["adam_d2sum"] = [ + ("Param", None), ("D2Sum", None), ("G2Sum", None), ("Moment", None), + ("MomentDecayRate", 1), ("AdaDecayRate", 1), ("AdaEpsilon", 1), + ("LearningRate", 1) + ] opt_input_map["sum"] = [("Param", None)] opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1), ("LearningRate", 1)] @@ -105,6 +111,8 @@ class CommonAccessor: opt_attr_map["naive_adagrad"] = [] opt_attr_map["adam"] = [("beta1", "f"), ("beta2", "f"), ("epsilon", "f")] + opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), + ("epsilon", "f")] opt_init_map = {} opt_init_map["gaussian_random"] = ["seed", "mean", "std"] @@ -162,7 +170,7 @@ class CommonAccessor: return attr_str def parse_by_optimizer(self, grad_name, is_sparse, total_dims, - compiled_strategy): + compiled_strategy, adam_d2sum): from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_optimize_ops param_name = compiled_strategy.grad_name_to_param_name[grad_name] main_program, startup_program = compiled_strategy.get_origin_programs() @@ -187,6 +195,10 @@ class CommonAccessor: self.trainer_num = compiled_strategy.get_trainers() + if oop.type != 'adam' and adam_d2sum == True: + print('optimization algorithm is not adam, set adam_d2sum False') + adam_d2sum = False + print("adam_d2sum:", adam_d2sum) if compiled_strategy.is_geo_mode(): param_varnames = self.opt_input_map["sum"] attr_varnames = self.opt_attr_map["sum"] @@ -195,6 +207,10 @@ class CommonAccessor: param_varnames = self.opt_input_map["naive_adagrad"] attr_varnames = self.opt_attr_map["naive_adagrad"] self.accessor_class = "sgd" + elif adam_d2sum: + param_varnames = self.opt_input_map["adam_d2sum"] + attr_varnames = self.opt_attr_map["adam_d2sum"] + self.accessor_class = "adam_d2sum" else: param_varnames = self.opt_input_map[oop.type] attr_varnames = self.opt_attr_map[oop.type] @@ -202,17 +218,8 @@ class CommonAccessor: for (formal_name, shape) in param_varnames: params.append(formal_name) - if formal_name == "G2Sum": - dims.append(1) - initializer = "fill_constant&0" - initializers.append(initializer) - else: - param = main_program.global_block().vars[oop.input(formal_name)[ - 0]] - if formal_name == "LearningRate" and param.name != "learning_rate_0": - warnings.warn("will support decay soon") - param = main_program.global_block().vars["learning_rate_0"] - + if self.accessor_class == "adam_d2sum": + #for dims if shape is None: if is_sparse: shape = total_dims @@ -221,9 +228,51 @@ class CommonAccessor: pserver_id) dims.append(shape) - initializer = self.get_initializer_attr(param.name, - startup_program) + #for initializers + if formal_name == "Param" or formal_name == "LearningRate": + param = main_program.global_block().vars[oop.input( + formal_name)[0]] + #TODO: for dense learning_rate, can be different from sparse lr + if formal_name == "LearningRate" and param.name != "learning_rate_0": + warnings.warn("will support decay soon") + param = main_program.global_block().vars[ + "learning_rate_0"] + + initializer = self.get_initializer_attr(param.name, + startup_program) + elif formal_name == "MomentDecayRate": + initializer = "fill_constant&0.99" + elif formal_name == "AdaDecayRate": + initializer = "fill_constant&0.9999" + elif formal_name == "AdaEpsilon": + initializer = "fill_constant&1.0e-8" + else: + initializer = "fill_constant&0" initializers.append(initializer) + else: + if formal_name == "G2Sum": + dims.append(1) + initializer = "fill_constant&0" + initializers.append(initializer) + else: + param = main_program.global_block().vars[oop.input( + formal_name)[0]] + if formal_name == "LearningRate" and param.name != "learning_rate_0": + warnings.warn("will support decay soon") + param = main_program.global_block().vars[ + "learning_rate_0"] + + if shape is None: + if is_sparse: + shape = total_dims + else: + shape = self.get_shard(total_dims, pserver_num, + pserver_id) + dims.append(shape) + + initializer = self.get_initializer_attr(param.name, + startup_program) + initializers.append(initializer) for (attr_varname, type_) in attr_varnames: value = oop.attr(attr_varname) @@ -292,6 +341,7 @@ class Table: self.accessor = None self.common = None self.tensor = None + self.accessor_proto = None def to_string(self, indent): table_str = "{}downpour_table_param {{{}\n{}}}" @@ -304,6 +354,14 @@ class Table: attrs += "\n" indent += 2 + if self.accessor_proto is not None: + accessor_str = "{}accessor {{{}\n{}}}" + accessor_str = accessor_str.format( + conv_indent(indent), self.accessor_proto, conv_indent(indent)) + attrs += accessor_str + "\n" + return table_str.format( + conv_indent(indent), attrs, conv_indent(indent)) + if self.accessor is not None: attrs += self.accessor.to_string(indent) attrs += "\n" @@ -431,6 +489,24 @@ class Worker: return worker_str.format(workers_str) +class fsClient: + def __init__(self, proto): + self.proto = proto + self.uri = proto.uri + self.user = proto.user + self.passwd = proto.passwd + self.hadoop_bin = proto.hadoop_bin + + def to_string(self): + from google.protobuf import text_format + proto_txt = text_format.MessageToString(self.proto) + if proto_txt: + fs_str = "fs_client_param {{\n{}}}" + return fs_str.format(proto_txt) + else: + return "" + + class TheOnePSRuntime(RuntimeBase): def __init__(self): super(TheOnePSRuntime, self).__init__() @@ -533,7 +609,6 @@ class TheOnePSRuntime(RuntimeBase): trainer_config = self.async_strategy.get_trainer_runtime_config() debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) - if debug: print("worker: \n{}".format(proto_txt)) print("communicator send_ctx:") @@ -565,6 +640,25 @@ class TheOnePSRuntime(RuntimeBase): self._communicator.init_with_ctx(send_ctx, dense_map, proto_txt, string_hosts, fluid.global_scope()) + import paddle.distributed.fleet as fleet + fleet.util.barrier() + info = self._communicator.get_client_info() + if isinstance(info, list) and len(info) > 0: + all_info = self.role_maker._all_gather(info[0]) + # for unittest + if not isinstance(all_info, list): + warnings.warn("gloo may not initialize correctly") + all_info = [all_info] + self._communicator.set_clients(all_info) + # create_c2c_connection default param: + # pserver_timeout_ms=500000 + # pserver_connect_timeout_ms=10000 + # max_retry=3 + self._communicator.create_client_to_client_connection() + print('create c2c connection done') + else: + print('cannot create c2c connection') + dist_strategy = self.context["valid_strategy"] is_test = bool(int(os.getenv("TEST_MODE", "0"))) @@ -577,7 +671,6 @@ class TheOnePSRuntime(RuntimeBase): else: init_params = dense_map - import paddle.distributed.fleet as fleet if not is_test: self._communicator.init_params(init_params) fleet.util.barrier() @@ -632,7 +725,7 @@ class TheOnePSRuntime(RuntimeBase): int(os.getenv("FLAGS_selected_xpus", "0")))) return executor - def _get_fleet_proto(self, is_server, is_sync): + def _get_fleet_proto(self, is_server, is_sync, **kwargs): def _build_merge_accessor(ctx): accessor = Accessor() accessor.accessor_class = "CommMergeAccessor" @@ -736,6 +829,7 @@ class TheOnePSRuntime(RuntimeBase): tables = [] for idx, (name, ctx) in enumerate(send_ctx.items()): + print(" wxm python test send_ctx.items-->", idx, (name, ctx)) if ctx.is_tensor_table() or len(ctx.origin_varnames()) < 1: continue @@ -755,18 +849,64 @@ class TheOnePSRuntime(RuntimeBase): else: table.table_class = parse_table_class( common.table_name, self.origin_main_program) - + table_proto = self.context[ + "user_defined_strategy"].sparse_table_configs + table.shard_num = table_proto.shard_num + from google.protobuf import text_format + table.accessor_proto = text_format.MessageToString( + table_proto.accessor) + + print('table proto:', table_proto) + if table.table_class == 'MemorySparseTable' and table.accessor_proto == '': + emb_dim = ctx.sections()[1] + table.shard_num = 1950 + table.accessor_proto = 'accessor_class: "CtrCommonAccessor"\n' \ + 'embed_sgd_param {\n' \ + ' name: "SparseAdaGradSGDRule"\n' \ + ' adagrad {\n' \ + ' learning_rate: 0.05\n' \ + ' initial_g2sum: 3.0\n' \ + ' initial_range: 0.0001\n' \ + ' weight_bounds: -10.0\n' \ + ' weight_bounds: 10.0\n' \ + ' }\n' \ + '}\n' \ + 'embedx_sgd_param {\n' \ + ' name: "SparseAdaGradSGDRule"\n' \ + ' adagrad {\n' \ + ' learning_rate: 0.05\n' \ + ' initial_g2sum: 3.0\n' \ + ' initial_range: 0.0001\n' \ + ' weight_bounds: -10.0\n' \ + ' weight_bounds: 10.0\n' \ + ' }\n' \ + '}\n' \ + 'fea_dim: ' + str(emb_dim+2) + '\n' \ + 'embedx_dim: ' + str(emb_dim-1) + '\n' \ + 'embedx_threshold: 10\n' \ + 'ctr_accessor_param {\n' \ + ' nonclk_coeff: 0.1\n' \ + ' click_coeff: 1.0\n' \ + ' base_threshold: 1.5\n' \ + ' delta_threshold: 0.25\n' \ + ' delta_keep_days: 16.0\n' \ + ' show_click_decay_rate: 0.98\n' \ + ' delete_threshold: 0.8\n' \ + ' delete_after_unseen_days: 30.0\n' \ + ' ssd_unseenday_threshold: 1\n' \ + '}' else: table.type = "PS_DENSE_TABLE" table.table_class = "CommonDenseTable" table.shard_num = 256 common.table_name = "MergedDense" + adam_d2sum = self.context["user_defined_strategy"].adam_d2sum common.parse_by_optimizer(ctx.origin_varnames()[0], ctx.is_sparse(), ctx.sections()[1] if ctx.is_sparse() else ctx.sections()[0], - self.compiled_strategy) + self.compiled_strategy, adam_d2sum) if ctx.is_sparse(): common.parse_entry(common.table_name, @@ -779,8 +919,9 @@ class TheOnePSRuntime(RuntimeBase): table.common = common - accessor = _build_merge_accessor(ctx) - table.accessor = accessor + if table.table_class != 'MemorySparseTable': + accessor = _build_merge_accessor(ctx) + table.accessor = accessor tables.append(table) tensor_table_dict = self.compiled_strategy.get_tensor_table_dict() @@ -828,6 +969,9 @@ class TheOnePSRuntime(RuntimeBase): trainers += len(self.role_maker._get_heter_worker_endpoints()) server = self._get_fleet_proto(is_server=True, is_sync=is_sync) proto_txt = str(server) + fs_client = fsClient(self.context["user_defined_strategy"] + .fs_client_param) + proto_txt = proto_txt + "\n" + fs_client.to_string() debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) if debug: @@ -924,10 +1068,14 @@ class TheOnePSRuntime(RuntimeBase): for id, names in context.items(): if names[0] not in distributed_varnames: # only save sparse param to local - self._worker.recv_and_save_model(id, dirname) + try: + self._worker.recv_and_save_model(id, dirname) + except: + pass # save sparse & distributed param on server self._worker.save_one_model(id, dirname, mode) values.extend(names) + # self._worker.save_all_model(dirname, mode) return values def _save_distributed_persistables(self, @@ -951,6 +1099,7 @@ class TheOnePSRuntime(RuntimeBase): recv_dense_varnames = [] for id, names in denses.items(): recv_dense_varnames.extend(names) + self._communicator.pull_dense(denses) saved_varnames = sparse_varnames @@ -1004,8 +1153,9 @@ class TheOnePSRuntime(RuntimeBase): ) # Todo(MrChengmo): Save optimizer status - self._save_distributed_persistables(executor, dirname, main_program, - mode) + # self._save_distributed_persistables(executor, dirname, main_program, + # mode) + self._worker.save_all_model(dirname, mode) def _ps_inference_save_inference_model(self, executor, @@ -1046,10 +1196,45 @@ class TheOnePSRuntime(RuntimeBase): infer_program._copy_dist_param_info_from(program) + if dirname.startswith("afs:") or dirname.startswith("hdfs:"): + model_path = "./dnn_plugin" + else: + model_path = os.path.join(dirname, "dnn_plugin") model_basename = "__model__" - model_basename = os.path.join(dirname, model_basename) + model_basename = os.path.join(model_path, model_basename) paddle.save(infer_program, model_basename) + sparses = self.compiled_strategy.get_the_one_recv_context( + is_dense=False, + split_dense_table=self.role_maker._is_heter_parameter_server_mode, + use_origin_program=True) + print("the one ps sparses:", sparses) + sparse_names = [] + for id, name in sparses.items(): + sparse_names.extend(name) + print("the one ps sparse names:", sparse_names) + + denses = self.compiled_strategy.get_the_one_recv_context( + is_dense=True, + split_dense_table=self.role_maker._is_heter_parameter_server_mode, + use_origin_program=True) + self._communicator.pull_dense(denses) + + generate_vars = self.context[ + "user_defined_strategy"].trainer_desc_configs["stat_var_names"] + generate_vars = [var for var in generate_vars] + remaining_vars = list( + filter( + TheOnePSRuntime.__exclude_vars(generate_vars + sparse_names), + infer_program.list_vars())) + print("remain_vars:", [var.name for var in remaining_vars]) + for var in remaining_vars: + tensor = var.get_value() + paddle.save( + tensor, + os.path.join(model_path, var.name), + use_binary_format=True) + self._ps_inference_save_persistables(executor, dirname, infer_program, mode) @@ -1073,8 +1258,10 @@ class TheOnePSRuntime(RuntimeBase): values.extend(names) return values - def _load_distributed_persistables(self, dirname, main_program=None, - mode=0): + def _ps_inference_load_inference_model(self, + dirname, + mode=0, + main_program=None): if main_program is None: main_program = self.compiled_strategy.get_origin_ps_main_program() @@ -1106,17 +1293,27 @@ class TheOnePSRuntime(RuntimeBase): TheOnePSRuntime.__exclude_vars(loaded_varnames), main_program.list_vars())) + if dirname.startswith("afs:") or dirname.startswith("hdfs:"): + model_path = "./dnn_plugin" + else: + model_path = os.path.join(dirname, "dnn_plugin") import paddle for var in remaining_vars: if var.name not in recv_dense_varnames: continue - tensor = paddle.load(os.path.join(dirname, var.name)) + tensor = paddle.load(os.path.join(model_path, var.name)) var.set_value(tensor) self._communicator.init_params(denses) + def _load_distributed_persistables(self, path, mode): + self._worker.load_model(path, mode) + def load_model(self, path, mode): - self._load_distributed_persistables(path, mode=mode) + if mode == 0 or mode == 3: + self._load_distributed_persistables(path, mode) + else: + self._ps_inference_load_inference_model(path, mode) def _shrink(self, threshold): import paddle.distributed.fleet as fleet diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index eb8739b15b..392edb65ba 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -99,6 +99,19 @@ class Communicator(object): self.send_ctx_ = send_ctx self.recv_ctx_ = recv_ctx + def create_client_to_client_connection(self, + pserver_timeout_ms=500000, + pserver_connect_timeout_ms=10000, + max_retry=3): + self.communicator_.create_client_to_client_connection( + pserver_timeout_ms, pserver_connect_timeout_ms, max_retry) + + def get_client_info(self): + return self.communicator_.get_client_info() + + def set_clients(self, host_list): + self.communicator_.set_clients(host_list) + def start(self): """ Start communicator. Should call before training process. diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index cb26f05b54..3f00b49dc3 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -968,7 +968,7 @@ def sparse_embedding(input, padding_idx=None, is_test=False, entry=None, - table_class="CommonSparseTable", + table_class="MemorySparseTable", param_attr=None, dtype='float32'): r""" @@ -1100,18 +1100,21 @@ def sparse_embedding(input, padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( size[0] + padding_idx) - if table_class not in ["CommonSparseTable", "SSDSparseTable"]: + if table_class not in [ + "CommonSparseTable", "SSDSparseTable", "MemorySparseTable" + ]: raise ValueError( - "table_class must be in [CommonSparseTable, SSDSparseTable]") + "table_class must be in [CommonSparseTable, SSDSparseTable, MemorySparseTable]" + ) entry_str = "none" if entry is not None: if entry.__class__.__name__ not in [ - "ProbabilityEntry", "CountFilterEntry" + "ProbabilityEntry", "CountFilterEntry", "ShowClickEntry" ]: raise ValueError( - "entry must be instance in [paddle.distributed.ProbabilityEntry, paddle.distributed.CountFilterEntry]" + "entry must be instance in [paddle.distributed.ProbabilityEntry, paddle.distributed.CountFilterEntry, paddle.distributed.ShowClickEntry]" ) entry_str = entry._to_attr() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 59d26f4837..11fa70b70b 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -43,6 +43,10 @@ OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"} +SPARSE_GRAD_OP_TYPE_DICT = { + "lookup_table_grad": "W", + "lookup_table_v2_grad": "W" +} DEVICE_LIST = ["cpu", "gpu", "xpu"] COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"] DEFAULT_DEVICE = 'cpu' @@ -98,9 +102,14 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): trainer_id = config.get_role_id() send_ctx = config.get_the_one_send_context( split_dense_table=config.is_heter_ps_mode) + w_2_table_id = {} + emb_size = {} def _get_pull_sparse_ops(_program): pull_sparse_ops = {} + pull_sparse_ids = {} + push_sparse_ops = {} + ops = {} for op in _program.global_block().ops: if op.type in SPARSE_OP_TYPE_DICT.keys() \ and op.attr('remote_prefetch') is True: @@ -111,7 +120,18 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): ops = pull_sparse_ops.get(param_name, []) ops.append(op) pull_sparse_ops[param_name] = ops - return pull_sparse_ops + ids = pull_sparse_ids.get(param_name, []) + ids.append(op.input("Ids")[0]) + pull_sparse_ids[param_name] = ids + for op in _program.global_block().ops: + if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys(): + param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0] + if param_name in pull_sparse_ids and op.input("Ids")[ + 0] in pull_sparse_ids[param_name]: + ops = push_sparse_ops.get(param_name, []) + ops.append(op) + push_sparse_ops[param_name] = ops + return pull_sparse_ops, push_sparse_ops def _pull_sparse_fuse(_program, pull_sparse_ops, use_ps_gpu): def dag_check_up_and_reorder(program, inputs, outputs): @@ -218,6 +238,7 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): program.global_block().vars[op.input("Ids")[0]] for op in ops ] w = program.global_block().vars[ops[0].input("W")[0]] + emb_size[param] = w.shape[1] grad_name = config.param_name_to_grad_name[w.name] @@ -231,6 +252,7 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): raise ValueError( "can not find suitable sparse table, please check") + w_2_table_id[param] = table_id padding_idx = ops[0].attr("padding_idx") is_distributed = ops[0].attr("is_distributed") op_type = ops[0].type @@ -263,7 +285,6 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): outputs_idxs[out_id]) if min(outputs_idxs) - max(inputs_idxs) >= 1: - if max(inputs_idxs) == -1: distributed_idx = min(op_idxs) else: @@ -313,8 +334,123 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): "op_device": op_device }) - pull_sparse_ops = _get_pull_sparse_ops(program) + def _push_sparse_fuse(_program, push_sparse_ops, use_ps_gpu): + if use_ps_gpu: + # in ps_gpu_pass + return + if len(push_sparse_ops) == 0: + return + show = None + clk = None + use_entry = False + for param, ops in push_sparse_ops.items(): + op_first = ops[0] + break + print(op_first) + if op_first.has_attr("entry"): + entry = op_first.attr("entry") + entry = entry.split(':') + if len(entry) == 3 and entry[0] == 'show_click_entry': + show_var_name = entry[1] + click_var_name = entry[2] + if show_var_name in program.global_block( + ).vars and click_var_name in program.global_block().vars: + show = program.global_block().vars[show_var_name] + clk = program.global_block().vars[click_var_name] + use_entry = True + else: + warnings.warn( + 'ShowClickEntry configured, but cannot find show/click var, will not use' + ) + + if not use_entry: + print('ShowClickEntry not configured, will not use') + show = program.global_block().create_var( + name="show", + dtype=core.VarDesc.VarType.INT64, + persistable=False, + stop_gradient=True) + program.global_block()._insert_op( + index=0, + type='fill_constant', + inputs={}, + outputs={'Out': show}, + attrs={ + 'shape': [1], + 'dtype': show.dtype, + 'value': 1, + #OP_ROLE_KEY: OpRole.Forward + }) + + clk = program.global_block().create_var( + name="clk", + dtype=core.VarDesc.VarType.INT64, + persistable=False, + stop_gradient=True) + program.global_block()._insert_op( + index=0, + type='fill_constant', + inputs={}, + outputs={'Out': clk}, + attrs={ + 'shape': [1], + 'dtype': clk.dtype, + 'value': 0, + #OP_ROLE_KEY: OpRole.Forward + }) + + for param, ops in push_sparse_ops.items(): + all_ops = program.global_block().ops + op_idxs = [all_ops.index(op) for op in ops] + inputs = [ + program.global_block().vars[op.input("Ids")[0]] for op in ops + ] + w = program.global_block().vars[ops[0].output("W@GRAD")[0]] + table_id = w_2_table_id[param] + + padding_idx = ops[0].attr("padding_idx") + is_distributed = ops[0].attr("is_distributed") + op_type = ops[0].type + outputs = [ + program.global_block().vars[op.input("Out@GRAD")[0]] + for op in ops + ] + + for idx in op_idxs[::-1]: + program.global_block()._remove_op(idx) + +# if use_ps_gpu: +# program.global_block().append_op( +# type="push_box_sparse", +# inputs={"Ids": inputs, +# 'Out': outputs}, +# outputs={"Out": outputs}, +# attrs={ +# "size": w.shape[1], +# "is_distributed": True, +# "is_sparse": True +# }) +# else: + program.global_block().append_op( + type="distributed_push_sparse", + inputs={ + "Ids": inputs, + 'W': w, + "Outputs": outputs, + "Shows": show, + "Clicks": clk + }, + outputs={"Outputs": outputs}, + attrs={ + "is_distributed": is_distributed, + "padding_idx": padding_idx, + "table_id": table_id, + "size": emb_size[param] + }) + + pull_sparse_ops, push_sparse_ops = _get_pull_sparse_ops(program) _pull_sparse_fuse(program, pull_sparse_ops, use_ps_gpu) + _push_sparse_fuse(program, push_sparse_ops, use_ps_gpu) return program @@ -367,6 +503,8 @@ def append_send_ops_pass(program, config): split_dense_table=config.is_heter_ps_mode) for merged_name, send in sends.items(): + if send.is_sparse(): + continue is_sparse = 1 if send.is_sparse() else 0 is_sparse = 2 if send.is_distributed() else is_sparse dummys.append( diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 2a8ee8bc72..65c8a7500f 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -156,6 +156,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): return avg_cost def check_model_right(self, dirname): + dirname = dirname + '/dnn_plugin/' model_filename = os.path.join(dirname, "__model__") with open(model_filename, "rb") as f: diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py index ad2b66f3c2..d013266e83 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py @@ -98,11 +98,13 @@ class TestDistCTR2x2(FleetDistRunnerBase): else: raise ValueError("error initializer code: {}".format(initializer)) + entry = paddle.distributed.ShowClickEntry("show", "click") dnn_layer_dims = [128, 64, 32] dnn_embedding = fluid.contrib.layers.sparse_embedding( input=dnn_data, size=[dnn_input_dim, dnn_layer_dims[0]], is_test=inference, + entry=entry, param_attr=fluid.ParamAttr( name="deep_embedding", initializer=init)) dnn_pool = fluid.layers.sequence_pool( @@ -123,6 +125,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): input=lr_data, size=[lr_input_dim, 1], is_test=inference, + entry=entry, param_attr=fluid.ParamAttr( name="wide_embedding", initializer=fluid.initializer.Constant(value=0.01))) diff --git a/python/paddle/fluid/tests/unittests/test_entry_attr.py b/python/paddle/fluid/tests/unittests/test_entry_attr.py index efcad103de..bdfe95560e 100644 --- a/python/paddle/fluid/tests/unittests/test_entry_attr.py +++ b/python/paddle/fluid/tests/unittests/test_entry_attr.py @@ -19,7 +19,7 @@ paddle.enable_static() import unittest import paddle.fluid as fluid -from paddle.distributed import ProbabilityEntry, CountFilterEntry +from paddle.distributed import ProbabilityEntry, CountFilterEntry, ShowClickEntry class EntryAttrChecks(unittest.TestCase): @@ -51,6 +51,11 @@ class EntryAttrChecks(unittest.TestCase): with self.assertRaises(ValueError): counter2 = CountFilterEntry(-1) + def showclick_entry(self): + showclick = ShowClickEntry("show", "click") + ss = showclick._to_attr() + self.assertEqual("show_click_entry:show:click", ss) + def spaese_layer(self): prog = fluid.Program() scope = fluid.core.Scope() @@ -97,6 +102,9 @@ class TestEntryAttrs(EntryAttrChecks): def test_counter(self): self.countfilter_entry() + def test_showclick(self): + self.showclick_entry() + def test_spaese_embedding_layer(self): self.spaese_layer() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py index 88e5ea2044..3078e5b3d1 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py @@ -76,6 +76,7 @@ class TestFleetBase(unittest.TestCase): fleet.fleet.save(dirname="/tmp") fleet.load_model(path="/tmp", mode=0) + fleet.load_model(path="/tmp", mode=1) self.assertRaises( Exception, @@ -94,6 +95,15 @@ class TestFleetBase(unittest.TestCase): executor=exe, main_program=compiled_prog) + self.assertRaises( + Exception, + fleet.save_inference_model, + dirname='afs:/tmp/', + feeded_var_names=['x', 'y'], + target_vars=[avg_cost], + executor=exe, + main_program=compiled_prog) + self.assertRaises( Exception, fleet.save_persistables, executor=pe, dirname='/tmp/') diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index 11f5293f7c..9cf3eb251b 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -255,6 +255,25 @@ class TestStrategyConfig(unittest.TestCase): strategy.a_sync_configs = configs self.assertEqual(strategy.a_sync_configs["k_steps"], 1000) + def test_sparse_table_configs(self): + strategy = paddle.distributed.fleet.DistributedStrategy() + configs = { + "table_parameters.accessor.embed_sgd_param.adagrad.learning_rate": + 0.05 + } + strategy.sparse_table_configs = configs + self.assertEqual(strategy.sparse_table_configs.accessor.embed_sgd_param. + adagrad.learning_rate, 0.05) + strategy.adam_d2sum = True + self.assertEqual(strategy.adam_d2sum, True) + strategy.fs_client_param = { + "uri": "123", + "user": "456", + "passwd": "789", + "hadoop_bin": "hadoop" + } + self.assertEqual(strategy.fs_client_param.user, "456") + def test_trainer_desc_configs(self): strategy = paddle.distributed.fleet.DistributedStrategy() configs = { -- GitLab