diff --git a/paddle/fluid/operators/distributed_ops/sparse_tensor_load_op.cc b/paddle/fluid/operators/distributed_ops/sparse_tensor_load_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6cd01089f9bc22a510094d0e994c487525fa35bc
--- /dev/null
+++ b/paddle/fluid/operators/distributed_ops/sparse_tensor_load_op.cc
@@ -0,0 +1,217 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <fstream>  // NOLINT
+#include <string>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/version.h"
+#include "paddle/fluid/platform/profiler.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using SelectedRows = framework::SelectedRows;
+
+struct DeserializedDataFunctor {
+  DeserializedDataFunctor(void **buf, Tensor *tensor,
+                          const platform::Place &place)
+      : buf_(buf), tensor_(tensor), place_(place) {}
+
+  template <typename T>
+  void apply() {
+    *buf_ = tensor_->mutable_data<T>(place_);
+  }
+
+  void **buf_;
+  Tensor *tensor_;
+  platform::Place place_;
+};
+
+template <typename DeviceContext, typename T>
+class SparseTensorLoadKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
+    auto place = ctx.GetPlace();
+    auto filename = ctx.Attr<std::string>("file_path");
+    std::ifstream fin(filename, std::ios::binary);
+    PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
+                      platform::errors::Unavailable(
+                          "Load operator failed to open file %s, please check "
+                          "whether the model file is complete or damaged.",
+                          filename));
+    auto name = ctx.OutputNames("Out")[0];
+    VLOG(4) << "Sparse Load Var name: " << name;
+    auto *out_var = ctx.OutputVar("Out");
+    PADDLE_ENFORCE_NOT_NULL(
+        out_var, platform::errors::InvalidArgument(
+                     "The variable %s to be loaded cannot be found.", name));
+    PADDLE_ENFORCE_EQ(out_var->IsType<LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "SparseLoad OP only supports LoDTensor."));
+    LoadLodTensor(fin, place, out_var, ctx);
+  }
+
+  void LoadLodTensor(std::istream &is, const platform::Place &place,
+                     paddle::framework::Variable *var,
+                     const paddle::framework::ExecutionContext &ctx) const {
+    auto *tensor = var->GetMutable<LoDTensor>();
+
+    auto node_index = ctx.Attr<int64_t>("node_index");
+    auto node_num = ctx.Attr<int64_t>("node_num");
+    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
+    VLOG(4) << "Sparse LoadLodTensor node_num: " << node_num;
+    VLOG(4) << "Sparse LoadLodTensor node_index: " << node_index;
+    VLOG(4) << "Sparse LoadLodTensor shape[0]: " << shape[0];
+    PADDLE_ENFORCE_GE(node_index, 0,
+                      platform::errors::InvalidArgument(
+                          "node_index must be greater than or equal to 0."));
+    PADDLE_ENFORCE_GE(node_num, 1,
+                      platform::errors::InvalidArgument(
+                          "node_num must be greater than or equal to 1."));
+
+    {
+      // the 1st field, uint32_t version for LoDTensor
+      uint32_t version;
+      is.read(reinterpret_cast<char *>(&version), sizeof(version));
+      PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
+                        true,
+                        platform::errors::InvalidArgument(
+                            "Tensor version %u is not supported.", version));
+      PADDLE_ENFORCE_EQ(version, 0U,
+                        platform::errors::InvalidArgument(
+                            "Tensor version %u is not supported, "
+                            "only version 0 is supported.",
+                            version));
+    }
+
+    {
+      // the 2nd field, LoD information
+      // TODO(sparse): sparse load may need to change the LoDTensor's lod level
+      uint64_t lod_level;
+      is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
+      auto &lod = *tensor->mutable_lod();
+      lod.resize(lod_level);
+    }
+
+    // the 3rd field, Tensor
+
+    uint32_t version;
+    is.read(reinterpret_cast<char *>(&version), sizeof(version));
+
+    PADDLE_ENFORCE_EQ(
+        version, 0U,
+        platform::errors::InvalidArgument(
+            "Tensor version %u is not supported, only version 0 is supported.",
+            version));
+
+    paddle::framework::proto::VarType::TensorDesc desc;
+
+    {  // int32_t size
+      // proto buffer
+      int32_t size;
+      is.read(reinterpret_cast<char *>(&size), sizeof(size));
+      std::unique_ptr<char[]> buf(new char[size]);
+      is.read(reinterpret_cast<char *>(buf.get()), size);
+      PADDLE_ENFORCE_EQ(
+          desc.ParseFromArray(buf.get(), size), true,
+          platform::errors::InvalidArgument("Cannot parse tensor desc"));
+    }
+
+    {  // read tensor
+      std::vector<int64_t> dims;
+      dims.reserve(static_cast<size_t>(desc.dims().size()));
+      std::copy(desc.dims().begin(), desc.dims().end(),
+                std::back_inserter(dims));
+
+      // number of elements in one row of the saved table
+      int64_t line_numel = 1;
+      for (size_t dim = 1; dim < dims.size(); dim++) {
+        line_numel *= dims[dim];
+      }
+      auto total_line = dims[0];
+
+      tensor->Resize(paddle::framework::make_ddim(shape));
+
+      void *buf;
+      auto ctx = platform::CPUDeviceContext();
+
+      paddle::framework::VisitDataType(
+          desc.data_type(),
+          DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
+
+      auto line_size =
+          line_numel * paddle::framework::SizeOfType(desc.data_type());
+      char *cur_buf = static_cast<char *>(buf);
+      // staging buffer for one row, released automatically at end of scope
+      std::unique_ptr<char[]> temp_row(new char[line_size]);
+      VLOG(4) << "TensorFromStream: line_size " << line_size;
+      VLOG(4) << "TensorFromStream: total_line " << total_line;
+      // keep only the rows this node is responsible for:
+      // row i belongs to the node whose node_index == i % node_num
+      for (size_t line_index = 0; line_index < static_cast<size_t>(total_line);
+           ++line_index) {
+        is.read(temp_row.get(), line_size);
+        if (static_cast<int64_t>(line_index) % node_num == node_index) {
+          memcpy(cur_buf, temp_row.get(), line_size);
+          cur_buf += line_size;
+        }
+      }
+    }
+  }
+};
+
+class SparseTensorLoadOp : public paddle::framework::OperatorWithKernel {
+ public:
+  using paddle::framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(paddle::framework::InferShapeContext *ctx) const override {}
+
+ protected:
+  paddle::framework::OpKernelType GetExpectedKernelType(
+      const paddle::framework::ExecutionContext &ctx) const override {
+    paddle::framework::OpKernelType kt = paddle::framework::OpKernelType(
+        paddle::framework::proto::VarType::FP32, ctx.GetPlace());
+    return kt;
+  }
+};
+
+class SparseTensorLoadOpMaker
+    : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddOutput("Out", "The LoDTensor / SelectedRows to be loaded");
+    AddAttr<std::string>("file_path",
+                         R"(Variable will be loaded from "file_path")")
+        .AddCustomChecker(
+            [](const std::string &path) { return !path.empty(); });
+    AddAttr<int64_t>("node_index", "role id, from 0 to node_num - 1.")
+        .SetDefault(0);
+    AddAttr<int64_t>("node_num",
+                     "number of roles that need to load the current variable.")
+        .SetDefault(0);
+    AddAttr<std::vector<int64_t>>("shape",
+                                  "(vector<int64_t>) The shape of the output")
+        .SetDefault({});
+    AddComment(R"DOC(
+    SparseTensorLoad OP: loads a sparse tensor shard on the parameter server.
+    )DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(sparse_tensor_load, ops::SparseTensorLoadOp,
+                  ops::SparseTensorLoadOpMaker);
+
+REGISTER_OP_CPU_KERNEL(
+    sparse_tensor_load,
+    ops::SparseTensorLoadKernel<paddle::platform::CPUDeviceContext, float>)
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
index 9990021c8506a386d0084811ae73b97f2ac37ca4..be614a05147385fbd60aee4d9d3269ae0b98d2da 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
@@ -31,6 +31,10 @@ class DGCOptimizer(MetaOptimizerBase):
             loss, role_maker, user_defined_optimizer, user_defined_strategy)
         opt = self.inner_opt
 
+        if not self.role_maker._is_collective:
+            return
+
         if not isinstance(opt, Momentum):
             return
 
diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
index 266c7d0f405bfd4d1cd24fcf523d94819db4cc47..415e09168088fa4e71467279329c91c2ef207bf6 100644
--- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
+++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
@@ -21,6 +21,7 @@ from paddle.fluid.framework import Program
 from paddle.fluid.compiler import CompiledProgram
 from paddle.fluid.executor import Executor
 from paddle.fluid.parallel_executor import ParallelExecutor
+from paddle.fluid.framework import Variable, Parameter
 from .runtime_base import RuntimeBase
 from ..base.private_helper_function import wait_server_ready
 
@@ -69,7 +70,52 @@ class ParameterServerRuntime(RuntimeBase):
             self.async_strategy, self.role_maker)
         return compiled_config
 
-    def _load_sparse_params(self, dirname, varnames):
+    def _load_sparse_params(self,
+                            executor,
+                            dirname,
+                            varnames,
+                            main_program=None):
+        assert varnames is not None
+        check_vars = []
+        load_prog = Program()
+        load_block = load_prog.global_block()
+
+        def _in_varnames(var):
+            return var.name in varnames
+
+        load_vars = list(
+            filter(_in_varnames, fluid.default_main_program().list_vars()))
+        if main_program is None:
+            main_program = self.origin_main_program
+
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_varname_parts
+        for each_var in load_vars:
+            assert isinstance(each_var, Variable)
+
+            origin_varname, _, _ = _get_varname_parts(each_var.name)
+
+            new_var = fluid.io._clone_var_in_block_(load_block, each_var)
+            var_path = os.path.join(dirname, origin_varname)
+            if not os.path.exists(var_path):
+                raise ValueError("SelectedRows var {} can not find at {}".
+ format(new_var.name, var_path)) + + if os.path.isfile(var_path): + load_block.append_op( + type='sparse_tensor_load', + inputs={}, + outputs={'Out': [new_var]}, + attrs={ + 'file_path': os.path.join(dirname, origin_varname), + 'node_index': self.role_maker._server_index(), + 'node_num': self.role_maker._server_num(), + 'shape': each_var.shape + }) + check_vars.append(each_var) + + executor.run(load_prog) + + def _load_distributed_params(self, dirname, varnames): from paddle.fluid.communicator import LargeScaleKV from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_varname_parts @@ -248,34 +294,54 @@ class ParameterServerRuntime(RuntimeBase): self._init_worker() return - if not model_dirname: - return - - if not os.path.isdir(model_dirname): - raise ValueError("There is no directory named '%s'", model_dirname) - - sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(True) - + sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False) + sparse_related_optimize_varnames = [] + for var_name in sparse_varnames: + sparse_related_optimize_varnames += self.compiled_strategy.get_optimize_varname_on_ps( + var_name) + sparse_related_optimize_varnames = list( + set(sparse_related_optimize_varnames)) distribtued_varnames = self.compiled_strategy.get_sparse_varname_on_ps( - False) + True) + distributed_related_optimize_varnames = [] + for var_name in distribtued_varnames: + distributed_related_optimize_varnames += self.compiled_strategy.get_optimize_varname_on_ps( + var_name) + distributed_related_optimize_varnames = list( + set(distributed_related_optimize_varnames)) remaining_vars = list( filter( - ParameterServerRuntime.__exclude_vars(sparse_varnames + - distribtued_varnames), + ParameterServerRuntime.__exclude_vars( + sparse_varnames + distribtued_varnames + + sparse_related_optimize_varnames + + distributed_related_optimize_varnames), fluid.default_main_program().list_vars())) + if not model_dirname: + return + + if not os.path.isdir(model_dirname): + raise ValueError("There is no directory named '%s'", model_dirname) + + # load dense fluid.io.load_vars( executor, main_program=fluid.default_main_program(), dirname=model_dirname, vars=remaining_vars) + # load sparse self._load_sparse_params( - dirname=model_dirname, varnames=sparse_varnames) + executor=executor, + dirname=model_dirname, + varnames=sparse_varnames + sparse_related_optimize_varnames) - # todo(tangwei12) load distributed vars - # self._load_sparse_params(dirname=model_dir, varnames=distribtued_varnames) + # load large scale + self._load_distributed_params( + dirname=model_dirname, + varnames=distribtued_varnames + + distributed_related_optimize_varnames) def _run_server(self): executor = self._get_executor() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py index 90847382c86e1c0bfd2cd9fae33342cbdb38e5ce..fe2ba38ee00b6aaea382b6262d963cc8df8f0cdd 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py @@ -236,9 +236,9 @@ class CompileTimeStrategy(object): def get_sparse_varname_on_ps(self, is_distributed, endpoint=None): if not endpoint: endpoint = self.get_ps_endpoint() - varnames = get_sparse_tablenames(self.get_origin_main_program(), is_distributed) + ps_sparse_varnames = [] for varname in varnames: tables = self.get_var_distributed(varname, True) @@ -248,6 +248,55 @@ class CompileTimeStrategy(object): 
ps_sparse_varnames.append(table) return ps_sparse_varnames + def get_optimize_varname_on_ps(self, param_name): + origin_param_name, _, _ = _get_varname_parts(param_name) + optimize_var_names = [] + for op in self.get_origin_main_program().global_block().ops: + # check all optimizer op + if int(op.all_attrs()["op_role"]) == 2: + # check param name + if op.input("Param")[0] != origin_param_name: + continue + # check all input + for key in op.input_names: + if key in [ + "Param", "Grad", "LearningRate", "Beta1Tensor", + "Beta2Tensor" + ]: + continue + # check varibale shape related param, e.g: Moment1 + optimize_var_names += self._get_optimizer_param_related_var_name( + op, op.type, key) + return optimize_var_names + + def _get_optimizer_param_related_var_name(self, op, op_type, varkey): + """ + Returns the names for optimizer inputs that need to be load + """ + related_var_names = [] + if op_type == "adam": + if varkey in ["Moment1", "Moment2"]: + related_var_names.append(op.input(varkey)[0]) + elif op_type == "adagrad": + if varkey == "Moment": + related_var_names.append(op.input(varkey)[0]) + elif op_type in ["momentum", "lars_momentum"]: + if varkey == "Velocity": + related_var_names.append(op.input(varkey)[0]) + elif op_type == "rmsprop": + if varkey in ["Moment", "MeanSquare"]: + related_var_names.append(op.input(varkey)[0]) + elif op_type == "ftrl": + if varkey in ["SquaredAccumulator", "LinearAccumulator"]: + related_var_names.append(op.input(varkey)[0]) + elif op_type == "sgd": + pass + else: + raise ValueError( + "Not supported optimizer for distributed training: %s" % + op_type) + return related_var_names + def build_ctx(self, vars, mapping, diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_load_ps0.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_load_ps0.py new file mode 100644 index 0000000000000000000000000000000000000000..eddac64bab91b42608899634ed6ceb756d70dcc7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_load_ps0.py @@ -0,0 +1,122 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet + + +class SparseLoadOp(unittest.TestCase): + """ Test load operator. 
+ """ + + def net(self, emb_array, fc_array): + with fluid.unique_name.guard(): + dense_input = fluid.data('input', shape=[None, 1], dtype="int64") + + emb = fluid.layers.embedding( + input=dense_input, + is_sparse=True, + size=[10, 10], + param_attr=fluid.ParamAttr( + name="embedding", + initializer=fluid.initializer.NumpyArrayInitializer( + emb_array)), ) + + fc1 = fluid.layers.fc( + input=emb, + size=10, + act="relu", + param_attr=fluid.ParamAttr( + name='fc', + initializer=fluid.initializer.NumpyArrayInitializer( + fc_array))) + loss = fluid.layers.reduce_mean(fc1) + return loss + + def save_origin_model(self, emb_array, fc_array): + startup_program = fluid.framework.Program() + test_program = fluid.framework.Program() + with fluid.framework.program_guard(test_program, startup_program): + with fluid.unique_name.guard(): + loss = self.net(emb_array, fc_array) + optimizer = fluid.optimizer.Adam(1e-3) + optimizer.minimize(loss) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + model_path = tempfile.mkdtemp() + fluid.io.save_persistables(executor=exe, dirname=model_path) + return model_path + + +class TestSparseLoadOpCase1(SparseLoadOp): + def test_2ps_0_load(self): + # init No.0 server env + env = {} + env["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002" + env["PADDLE_TRAINERS_NUM"] = str(2) + env["TRAINING_ROLE"] = "PSERVER" + env["PADDLE_PORT"] = "4001" + env["POD_IP"] = "127.0.0.1" + for k, v in env.items(): + os.environ[k] = str(v) + """ + array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], + [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], + [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6], + [0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7], + [0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8], + [0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]]) + """ + emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10) + fc_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10) + model_path = self.save_origin_model(emb_array, fc_array) + + role = role_maker.PaddleCloudRoleMaker() + fleet.init(role) + loss = self.net(emb_array, fc_array) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.a_sync = True + optimizer = fluid.optimizer.Adam(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + optimizer.minimize(loss) + fleet.init_server(model_path) + + fc_w = np.array(fluid.global_scope().find_var("fc").get_tensor()) + + emb = np.array(fluid.global_scope().find_var("embedding.block0") + .get_tensor()) + + assert fc_w.all() == fc_array.all() + assert emb.all() == emb_array[::2].all() + shutil.rmtree(model_path) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_load_ps1.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_load_ps1.py new file mode 100644 index 0000000000000000000000000000000000000000..7d14a484f3442c49284618702656ec9985c849f1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_load_ps1.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet +from test_dist_sparse_load_ps0 import SparseLoadOp + + +class TestSparseLoadOpCase2(SparseLoadOp): + def test_2ps_0_load(self): + # init No.1 server env + env = {} + env["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002" + env["PADDLE_TRAINERS_NUM"] = str(2) + env["TRAINING_ROLE"] = "PSERVER" + env["PADDLE_PORT"] = "4002" + env["POD_IP"] = "127.0.0.1" + for k, v in env.items(): + os.environ[k] = str(v) + """ + array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], + [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], + [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6], + [0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7], + [0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8], + [0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]]) + """ + emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10) + fc_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10) + model_path = self.save_origin_model(emb_array, fc_array) + + startup_program = fluid.framework.Program() + test_program = fluid.framework.Program() + role = role_maker.PaddleCloudRoleMaker() + fleet.init(role) + loss = self.net(emb_array, fc_array) + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.a_sync = True + optimizer = fluid.optimizer.Adam(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + optimizer.minimize(loss) + fleet.init_server(model_path) + emb = np.array(fluid.global_scope().find_var("embedding.block1") + .get_tensor()) + assert emb.all() == emb_array[1::2].all() + shutil.rmtree(model_path) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_adagrad.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..ff545319ccd29caec865b0754b289f0565ad4db0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_adagrad.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet +from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram + + +class TestSparseLoadProgramAdagrad(TestSparseLoadProgram): + """ + Test Sparse load operator. + """ + + def test_server_init(self): + scope, train_program, startup_program, loss = self.net() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + optimizer = fluid.optimizer.Adagrad(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, + self.strategy) + optimizer.minimize(loss) + fleet.init_server() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_adam.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..60c3f7fc9f1264895450c50848730450cbc6425a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_adam.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet +from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram + + +class TestSparseLoadProgramAdam(TestSparseLoadProgram): + """ + Test Sparse load operator. + """ + + def test_server_init(self): + scope, train_program, startup_program, loss = self.net() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + optimizer = fluid.optimizer.Adam(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, + self.strategy) + optimizer.minimize(loss) + fleet.init_server() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_ftrl.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_ftrl.py new file mode 100644 index 0000000000000000000000000000000000000000..fbba08e4e0665789ae567ead5777d150ed32c682 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_ftrl.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet +from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram + + +class TestSparseLoadProgramFtrl(TestSparseLoadProgram): + """ + Test Sparse load operator. + """ + + def test_server_init(self): + scope, train_program, startup_program, loss = self.net() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + optimizer = fluid.optimizer.Ftrl(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, + self.strategy) + optimizer.minimize(loss) + fleet.init_server() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_momentum.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_momentum.py new file mode 100644 index 0000000000000000000000000000000000000000..31635ede6f5d6c6fed139d1d454b4f680b395322 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_momentum.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet +from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram + + +class TestSparseLoadProgramMomentum(TestSparseLoadProgram): + """ + Test Sparse load operator. 
+ """ + + def test_server_init(self): + scope, train_program, startup_program, loss = self.net() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + optimizer = fluid.optimizer.Momentum(1e-3, 0.9) + optimizer = fleet.distributed_optimizer(optimizer, + self.strategy) + optimizer.minimize(loss) + fleet.init_server() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_rmsprop.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_rmsprop.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb5f2a2ea4f18a27bb34d5baacf5890bcbf0ab3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_rmsprop.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet +from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram + + +class TestSparseLoadProgramRmsprop(TestSparseLoadProgram): + """ + Test Sparse load operator. + """ + + def test_server_init(self): + scope, train_program, startup_program, loss = self.net() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + optimizer = fluid.optimizer.RMSProp(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, + self.strategy) + optimizer.minimize(loss) + fleet.init_server() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_sgd.py b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..17bff651c4489df70cad489a212e2609eb9345a7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_sparse_tensor_load_sgd.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import os +import unittest +import numpy as np +import tempfile +import shutil +from op_test import OpTest, randomize_probability +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.distributed.fleet.base.role_maker as role_maker +from paddle.distributed.fleet import fleet + + +class TestSparseLoadProgram(unittest.TestCase): + """ + Test Sparse load operator. + """ + + def setUp(self): + os.environ[ + "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002" + os.environ["PADDLE_TRAINERS_NUM"] = str(2) + os.environ["TRAINING_ROLE"] = "PSERVER" + os.environ["PADDLE_PORT"] = "4001" + os.environ["POD_IP"] = "127.0.0.1" + role = role_maker.PaddleCloudRoleMaker() + fleet.init(role) + self.strategy = paddle.distributed.fleet.DistributedStrategy() + self.strategy.a_sync = True + + def net(self): + train_program = fluid.Program() + startup_program = fluid.Program() + scope = fluid.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + inputs = fluid.data('input', shape=[None, 1], dtype="int64") + emb = fluid.layers.embedding( + inputs, is_sparse=True, size=[10000, 128]) + fc1 = fluid.layers.fc(input=emb, size=128, act="relu") + fc2 = fluid.layers.fc(input=fc1, size=64, act="relu") + loss = fluid.layers.reduce_mean(fc2) + return scope, train_program, startup_program, loss + + +class TestSparseLoadProgramSGD(TestSparseLoadProgram): + def test_server_init(self): + scope, train_program, startup_program, loss = self.net() + with fluid.scope_guard(scope): + with fluid.program_guard(train_program, startup_program): + optimizer = fluid.optimizer.SGD(1e-3) + optimizer = fleet.distributed_optimizer(optimizer, + self.strategy) + optimizer.minimize(loss) + fleet.init_server() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()
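Note on the row-sharding rule: sparse_tensor_load keeps row i of the saved table on the server whose node_index equals i % node_num, which is exactly what the emb_array[::2] / emb_array[1::2] checks in the two load tests above rely on. The following is an illustrative NumPy sketch of that rule only; it is not part of this change, and the helper name shard_rows is hypothetical.

import numpy as np


def shard_rows(full_table, node_index, node_num):
    """Keep only the rows this server is responsible for.

    Mirrors the loop in LoadLodTensor: row i of the saved table is copied
    by the server whose node_index == i % node_num.
    """
    rows = [
        row for i, row in enumerate(full_table) if i % node_num == node_index
    ]
    return np.stack(rows)


# Same table as in the tests: row i holds the constant value i / 10.
emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)

# With two servers, server 0 gets the even rows and server 1 the odd rows.
assert np.allclose(shard_rows(emb_array, 0, 2), emb_array[::2])
assert np.allclose(shard_rows(emb_array, 1, 2), emb_array[1::2])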