Unverified commit 328cb289, authored by Chengmo, committed by GitHub

【paddle.fleet】fix sparse load (#27680)

* add sparse tensor load method
Parent commit: cf70d5b3
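The patch adds a sparse_tensor_load operator plus fleet runtime support so that each parameter server can reload its shard of a saved sparse (embedding) table. A minimal sketch of the round-robin row-sharding rule the op implements, using only NumPy (an illustration, not part of the patch):

import numpy as np

def shard_rows(full_tensor, node_index, node_num):
    # Row i of the saved tensor is kept by server node_index iff i % node_num == node_index.
    return full_tensor[node_index::node_num]

full = np.arange(10, dtype="float32").repeat(4).reshape(10, 4)
assert (shard_rows(full, 0, 2) == full[0::2]).all()  # server 0 keeps rows 0, 2, 4, ...
assert (shard_rows(full, 1, 2) == full[1::2]).all()  # server 1 keeps rows 1, 3, 5, ...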
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
struct DeserializedDataFunctor {
DeserializedDataFunctor(void **buf, Tensor *tensor,
const platform::Place &place)
: buf_(buf), tensor_(tensor), place_(place) {}
template <typename T>
void apply() {
*buf_ = tensor_->mutable_data<T>(place_);
}
void **buf_;
Tensor *tensor_;
platform::Place place_;
};
template <typename DeviceContext, typename T>
class SparseTensorLoadKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
auto name = ctx.OutputNames("Out")[0];
VLOG(4) << "Sparse Load Var name: " << name;
auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.", name));
PADDLE_ENFORCE_EQ(out_var->IsType<paddle::framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"SparseLoad OP only support LoDTensor"));
LoadLodTensor(fin, place, out_var, ctx);
}
void LoadLodTensor(std::istream &is, const platform::Place &place,
paddle::framework::Variable *var,
const paddle::framework::ExecutionContext &ctx) const {
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
auto node_index = ctx.Attr<int64_t>("node_index");
auto node_num = ctx.Attr<int64_t>("node_num");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
VLOG(4) << "Sparse LoadLodTensor node_num" << node_num;
VLOG(4) << "Sparse LoadLodTensor node_index" << node_index;
VLOG(4) << "Sparse LoadLodTensor shape[0]" << shape[0];
PADDLE_ENFORCE_GE(node_index, 0, platform::errors::InvalidArgument(
"node_index must be greater than or equal to 0"));
PADDLE_ENFORCE_GE(node_num, 1, platform::errors::InvalidArgument(
"node_num must be greater than or equal to 1"));
{
// the 1st field, uint32_t version for LoDTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
true,
platform::errors::InvalidArgument(
"Tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ(version, 0U, platform::errors::InvalidArgument(
"Tensor version %u is not supported, "
"only version 0 is supported.",
version));
}
{
// the 2nd field, LoD information
// TODO: sparse load needs to change LoDTensor's lod level
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
}
// the 3rd field, Tensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"tensor version %u is not supported, Only version 0 is supported",
version));
paddle::framework::proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE_EQ(
desc.ParseFromArray(buf.get(), size), true,
platform::errors::InvalidArgument("Cannot parse tensor desc"));
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(),
std::back_inserter(dims));
int64_t line_numel = 1;
for (size_t dim = 1; dim < dims.size(); dim++) {
line_numel *= dims[dim];
}
auto total_line = dims[0];
tensor->Resize(paddle::framework::make_ddim(shape));
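// The Resize above sets the server-local shard shape from the "shape"
// attribute; the functor below allocates the tensor's buffer for the dtype
// recorded in the proto TensorDesc and points `buf` at it.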
void *buf;
auto ctx = platform::CPUDeviceContext();
paddle::framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
auto line_size =
line_numel * paddle::framework::SizeOfType(desc.data_type());
char *cur_buf = static_cast<char *>(buf);
std::vector<char> temp_row(line_size);
VLOG(4) << "TensorFromStream: line_size " << line_size;
VLOG(4) << "TensorFromStream: total_line " << total_line;
for (size_t line_index = 0; line_index < static_cast<size_t>(total_line);
++line_index) {
is.read(temp_row.data(), line_size);
if (static_cast<int64_t>(line_index) % node_num == node_index) {
memcpy(cur_buf, temp_row.data(), line_size);
cur_buf += line_size;
}
}
}
}
};
class SparseTensorLoadOp : public paddle::framework::OperatorWithKernel {
public:
using paddle::framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(paddle::framework::InferShapeContext *ctx) const override {}
protected:
paddle::framework::OpKernelType GetExpectedKernelType(
const paddle::framework::ExecutionContext &ctx) const override {
paddle::framework::OpKernelType kt = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::FP32, ctx.GetPlace());
return kt;
}
};
class SparseTensorLoadOpMaker
: public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddAttr<std::string>("file_path",
R"(Variable will be loaded from "file_path")")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddAttr<int64_t>("node_index", "role id from 0 ~ node_num.").SetDefault(0);
AddAttr<int64_t>("node_num", "role nums which need load current varibale.")
.SetDefault(0);
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddComment(R"DOC(
SparseTensorLoad OP, load a sparse tensor shard on the parameter server
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_tensor_load, ops::SparseTensorLoadOp,
ops::SparseTensorLoadOpMaker);
REGISTER_OP_CPU_KERNEL(
sparse_tensor_load,
ops::SparseTensorLoadKernel<paddle::platform::CPUDeviceContext, float>)
......@@ -31,6 +31,10 @@ class DGCOptimizer(MetaOptimizerBase):
loss, role_maker, user_defined_optimizer, user_defined_strategy)
opt = self.inner_opt
if not self.role_maker._is_collective:
return
if not isinstance(opt, Momentum):
return
......
......@@ -21,6 +21,7 @@ from paddle.fluid.framework import Program
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable, Parameter
from .runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready
......@@ -69,7 +70,52 @@ class ParameterServerRuntime(RuntimeBase):
self.async_strategy, self.role_maker)
return compiled_config
- def _load_sparse_params(self, dirname, varnames):
+ def _load_sparse_params(self,
+                         executor,
+                         dirname,
+                         varnames,
+                         main_program=None):
assert varnames is not None
check_vars = []
load_prog = Program()
load_block = load_prog.global_block()
def _in_varnames(var):
return var.name in varnames
load_vars = list(
filter(_in_varnames, fluid.default_main_program().list_vars()))
if main_program is None:
main_program = self.origin_main_program
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_varname_parts
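# One sparse_tensor_load op is appended per sparse variable: each server reads
# the full saved tensor from dirname and keeps only its own rows, selected
# round-robin by node_index / node_num inside the op.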
for each_var in load_vars:
assert isinstance(each_var, Variable)
origin_varname, _, _ = _get_varname_parts(each_var.name)
new_var = fluid.io._clone_var_in_block_(load_block, each_var)
var_path = os.path.join(dirname, origin_varname)
if not os.path.exists(var_path):
raise ValueError("SelectedRows var {} can not find at {}".
format(new_var.name, var_path))
if os.path.isfile(var_path):
load_block.append_op(
type='sparse_tensor_load',
inputs={},
outputs={'Out': [new_var]},
attrs={
'file_path': os.path.join(dirname, origin_varname),
'node_index': self.role_maker._server_index(),
'node_num': self.role_maker._server_num(),
'shape': each_var.shape
})
check_vars.append(each_var)
executor.run(load_prog)
def _load_distributed_params(self, dirname, varnames):
from paddle.fluid.communicator import LargeScaleKV
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_varname_parts
......@@ -248,34 +294,54 @@ class ParameterServerRuntime(RuntimeBase):
self._init_worker()
return
if not model_dirname:
return
if not os.path.isdir(model_dirname):
raise ValueError("There is no directory named '%s'", model_dirname)
- sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(True)
+ sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False)
sparse_related_optimize_varnames = []
for var_name in sparse_varnames:
sparse_related_optimize_varnames += self.compiled_strategy.get_optimize_varname_on_ps(
var_name)
sparse_related_optimize_varnames = list(
set(sparse_related_optimize_varnames))
distribtued_varnames = self.compiled_strategy.get_sparse_varname_on_ps(
- False)
+ True)
distributed_related_optimize_varnames = []
for var_name in distribtued_varnames:
distributed_related_optimize_varnames += self.compiled_strategy.get_optimize_varname_on_ps(
var_name)
distributed_related_optimize_varnames = list(
set(distributed_related_optimize_varnames))
remaining_vars = list(
filter(
- ParameterServerRuntime.__exclude_vars(sparse_varnames +
-                                       distribtued_varnames),
+ ParameterServerRuntime.__exclude_vars(
+     sparse_varnames + distribtued_varnames +
+     sparse_related_optimize_varnames +
+     distributed_related_optimize_varnames),
fluid.default_main_program().list_vars()))
if not model_dirname:
return
if not os.path.isdir(model_dirname):
raise ValueError("There is no directory named '%s'", model_dirname)
# load dense
fluid.io.load_vars(
executor,
main_program=fluid.default_main_program(),
dirname=model_dirname,
vars=remaining_vars)
# load sparse
self._load_sparse_params(
- dirname=model_dirname, varnames=sparse_varnames)
+ executor=executor,
+ dirname=model_dirname,
+ varnames=sparse_varnames + sparse_related_optimize_varnames)
# todo(tangwei12) load distributed vars
# self._load_sparse_params(dirname=model_dir, varnames=distribtued_varnames)
# load large scale
self._load_distributed_params(
dirname=model_dirname,
varnames=distribtued_varnames +
distributed_related_optimize_varnames)
def _run_server(self):
executor = self._get_executor()
......
......@@ -236,9 +236,9 @@ class CompileTimeStrategy(object):
def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
if not endpoint:
endpoint = self.get_ps_endpoint()
varnames = get_sparse_tablenames(self.get_origin_main_program(),
is_distributed)
ps_sparse_varnames = []
for varname in varnames:
tables = self.get_var_distributed(varname, True)
......@@ -248,6 +248,55 @@ class CompileTimeStrategy(object):
ps_sparse_varnames.append(table)
return ps_sparse_varnames
def get_optimize_varname_on_ps(self, param_name):
origin_param_name, _, _ = _get_varname_parts(param_name)
optimize_var_names = []
for op in self.get_origin_main_program().global_block().ops:
# check all optimizer ops (op_role attr value 2 marks the Optimize role)
if int(op.all_attrs()["op_role"]) == 2:
# check param name
if op.input("Param")[0] != origin_param_name:
continue
# check all input
for key in op.input_names:
if key in [
"Param", "Grad", "LearningRate", "Beta1Tensor",
"Beta2Tensor"
]:
continue
# collect the shape-related optimizer variables for this param, e.g. Moment1
optimize_var_names += self._get_optimizer_param_related_var_name(
op, op.type, key)
return optimize_var_names
def _get_optimizer_param_related_var_name(self, op, op_type, varkey):
"""
Return the names of the optimizer accumulator inputs that need to be loaded,
e.g. Adam's Moment1/Moment2 for the given parameter.
"""
related_var_names = []
if op_type == "adam":
if varkey in ["Moment1", "Moment2"]:
related_var_names.append(op.input(varkey)[0])
elif op_type == "adagrad":
if varkey == "Moment":
related_var_names.append(op.input(varkey)[0])
elif op_type in ["momentum", "lars_momentum"]:
if varkey == "Velocity":
related_var_names.append(op.input(varkey)[0])
elif op_type == "rmsprop":
if varkey in ["Moment", "MeanSquare"]:
related_var_names.append(op.input(varkey)[0])
elif op_type == "ftrl":
if varkey in ["SquaredAccumulator", "LinearAccumulator"]:
related_var_names.append(op.input(varkey)[0])
elif op_type == "sgd":
pass
else:
raise ValueError(
"Not supported optimizer for distributed training: %s" %
op_type)
return related_var_names
def build_ctx(self,
vars,
mapping,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
class SparseLoadOp(unittest.TestCase):
""" Test load operator.
"""
def net(self, emb_array, fc_array):
with fluid.unique_name.guard():
dense_input = fluid.data('input', shape=[None, 1], dtype="int64")
emb = fluid.layers.embedding(
input=dense_input,
is_sparse=True,
size=[10, 10],
param_attr=fluid.ParamAttr(
name="embedding",
initializer=fluid.initializer.NumpyArrayInitializer(
emb_array)), )
fc1 = fluid.layers.fc(
input=emb,
size=10,
act="relu",
param_attr=fluid.ParamAttr(
name='fc',
initializer=fluid.initializer.NumpyArrayInitializer(
fc_array)))
loss = fluid.layers.reduce_mean(fc1)
return loss
def save_origin_model(self, emb_array, fc_array):
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
with fluid.framework.program_guard(test_program, startup_program):
with fluid.unique_name.guard():
loss = self.net(emb_array, fc_array)
optimizer = fluid.optimizer.Adam(1e-3)
optimizer.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
model_path = tempfile.mkdtemp()
fluid.io.save_persistables(executor=exe, dirname=model_path)
return model_path
class TestSparseLoadOpCase1(SparseLoadOp):
def test_2ps_0_load(self):
# init No.0 server env
env = {}
env["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002"
env["PADDLE_TRAINERS_NUM"] = str(2)
env["TRAINING_ROLE"] = "PSERVER"
env["PADDLE_PORT"] = "4001"
env["POD_IP"] = "127.0.0.1"
for k, v in env.items():
os.environ[k] = str(v)
"""
array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
[0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
[0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
[0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7],
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8],
[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]])
"""
emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
fc_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
model_path = self.save_origin_model(emb_array, fc_array)
role = role_maker.PaddleCloudRoleMaker()
fleet.init(role)
loss = self.net(emb_array, fc_array)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
optimizer = fluid.optimizer.Adam(1e-3)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)
fleet.init_server(model_path)
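# With two pservers, server 0 holds embedding.block0, which should contain the
# even-indexed rows of the saved embedding (emb_array[::2]).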
fc_w = np.array(fluid.global_scope().find_var("fc").get_tensor())
emb = np.array(fluid.global_scope().find_var("embedding.block0")
.get_tensor())
assert np.allclose(fc_w, fc_array)
assert np.allclose(emb, emb_array[::2])
shutil.rmtree(model_path)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_load_ps0 import SparseLoadOp
class TestSparseLoadOpCase2(SparseLoadOp):
def test_2ps_0_load(self):
# init No.1 server env
env = {}
env["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002"
env["PADDLE_TRAINERS_NUM"] = str(2)
env["TRAINING_ROLE"] = "PSERVER"
env["PADDLE_PORT"] = "4002"
env["POD_IP"] = "127.0.0.1"
for k, v in env.items():
os.environ[k] = str(v)
"""
array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
[0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
[0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6],
[0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7],
[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8],
[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]])
"""
emb_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
fc_array = np.arange(0, 1, 0.1).repeat(10).reshape(10, 10)
model_path = self.save_origin_model(emb_array, fc_array)
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
role = role_maker.PaddleCloudRoleMaker()
fleet.init(role)
loss = self.net(emb_array, fc_array)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
optimizer = fluid.optimizer.Adam(1e-3)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)
fleet.init_server(model_path)
emb = np.array(fluid.global_scope().find_var("embedding.block1")
.get_tensor())
assert np.allclose(emb, emb_array[1::2])
shutil.rmtree(model_path)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram
class TestSparseLoadProgramAdagrad(TestSparseLoadProgram):
"""
Test Sparse load operator.
"""
def test_server_init(self):
scope, train_program, startup_program, loss = self.net()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
optimizer = fluid.optimizer.Adagrad(1e-3)
optimizer = fleet.distributed_optimizer(optimizer,
self.strategy)
optimizer.minimize(loss)
fleet.init_server()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram
class TestSparseLoadProgramAdam(TestSparseLoadProgram):
"""
Test Sparse load operator.
"""
def test_server_init(self):
scope, train_program, startup_program, loss = self.net()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
optimizer = fluid.optimizer.Adam(1e-3)
optimizer = fleet.distributed_optimizer(optimizer,
self.strategy)
optimizer.minimize(loss)
fleet.init_server()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram
class TestSparseLoadProgramFtrl(TestSparseLoadProgram):
"""
Test Sparse load operator.
"""
def test_server_init(self):
scope, train_program, startup_program, loss = self.net()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
optimizer = fluid.optimizer.Ftrl(1e-3)
optimizer = fleet.distributed_optimizer(optimizer,
self.strategy)
optimizer.minimize(loss)
fleet.init_server()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram
class TestSparseLoadProgramMomentum(TestSparseLoadProgram):
"""
Test Sparse load operator.
"""
def test_server_init(self):
scope, train_program, startup_program, loss = self.net()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
optimizer = fluid.optimizer.Momentum(1e-3, 0.9)
optimizer = fleet.distributed_optimizer(optimizer,
self.strategy)
optimizer.minimize(loss)
fleet.init_server()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
from test_dist_sparse_tensor_load_sgd import TestSparseLoadProgram
class TestSparseLoadProgramRmsprop(TestSparseLoadProgram):
"""
Test Sparse load operator.
"""
def test_server_init(self):
scope, train_program, startup_program, loss = self.net()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
optimizer = fluid.optimizer.RMSProp(1e-3)
optimizer = fleet.distributed_optimizer(optimizer,
self.strategy)
optimizer.minimize(loss)
fleet.init_server()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import tempfile
import shutil
from op_test import OpTest, randomize_probability
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import fleet
class TestSparseLoadProgram(unittest.TestCase):
"""
Test Sparse load operator.
"""
def setUp(self):
os.environ[
"PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:4001,127.0.0.1:4002"
os.environ["PADDLE_TRAINERS_NUM"] = str(2)
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PORT"] = "4001"
os.environ["POD_IP"] = "127.0.0.1"
role = role_maker.PaddleCloudRoleMaker()
fleet.init(role)
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = True
def net(self):
train_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
inputs = fluid.data('input', shape=[None, 1], dtype="int64")
emb = fluid.layers.embedding(
inputs, is_sparse=True, size=[10000, 128])
fc1 = fluid.layers.fc(input=emb, size=128, act="relu")
fc2 = fluid.layers.fc(input=fc1, size=64, act="relu")
loss = fluid.layers.reduce_mean(fc2)
return scope, train_program, startup_program, loss
class TestSparseLoadProgramSGD(TestSparseLoadProgram):
def test_server_init(self):
scope, train_program, startup_program, loss = self.net()
with fluid.scope_guard(scope):
with fluid.program_guard(train_program, startup_program):
optimizer = fluid.optimizer.SGD(1e-3)
optimizer = fleet.distributed_optimizer(optimizer,
self.strategy)
optimizer.minimize(loss)
fleet.init_server()
if __name__ == "__main__":
paddle.enable_static()
unittest.main()