未验证 提交 319d2ba9 编写于 作者: X xujiaqi01 提交者: GitHub

fix fs_client_param bug (#21212)

* Fix fs_client_param bug: users can now set this config through fleet_desc_file or the fleet config
* test=develop
上级 0d17c1b8
...@@ -349,14 +349,22 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -349,14 +349,22 @@ class DistributedAdam(DistributedOptimizerImplBase):
tp = ps_param.trainer_param.add() tp = ps_param.trainer_param.add()
tp.CopyFrom(prog_id_to_worker[k].get_desc()) tp.CopyFrom(prog_id_to_worker[k].get_desc())
ps_param.fs_client_param.uri = \ if strategy.get("fs_uri") is not None:
strategy.get("fs_uri", "hdfs://your_hdfs_uri") ps_param.fs_client_param.uri = strategy["fs_uri"]
ps_param.fs_client_param.user = \ elif ps_param.fs_client_param.uri == "":
strategy.get("fs_user", "your_hdfs_user") ps_param.fs_client_param.uri = "hdfs://your_hdfs_uri"
ps_param.fs_client_param.passwd = \ if strategy.get("fs_user") is not None:
strategy.get("fs_passwd", "your_hdfs_passwd") ps_param.fs_client_param.user = strategy["fs_user"]
ps_param.fs_client_param.hadoop_bin = \ elif ps_param.fs_client_param.user == "":
strategy.get("fs_hadoop_bin", "$HADOOP_HOME/bin/hadoop") ps_param.fs_client_param.user = "your_hdfs_user"
if strategy.get("fs_passwd") is not None:
ps_param.fs_client_param.passwd = strategy["fs_passwd"]
elif ps_param.fs_client_param.passwd == "":
ps_param.fs_client_param.passwd = "your_hdfs_passwd"
if strategy.get("fs_hadoop_bin") is not None:
ps_param.fs_client_param.hadoop_bin = strategy["fs_hadoop_bin"]
elif ps_param.fs_client_param.hadoop_bin == "":
ps_param.fs_client_param.hadoop_bin = "$HADOOP_HOME/bin/hadoop"
opt_info = {} opt_info = {}
opt_info["program_id_to_worker"] = prog_id_to_worker opt_info["program_id_to_worker"] = prog_id_to_worker
......
...@@ -479,14 +479,28 @@ class TestDataset(unittest.TestCase): ...@@ -479,14 +479,28 @@ class TestDataset(unittest.TestCase):
class TestDatasetWithDataLoader(TestDataset): class TestDatasetWithDataLoader(TestDataset):
"""
Test Dataset With Data Loader class. TestCases.
"""
def setUp(self): def setUp(self):
"""
Test Dataset With Data Loader, setUp.
"""
self.use_data_loader = True self.use_data_loader = True
self.epoch_num = 10 self.epoch_num = 10
self.drop_last = False self.drop_last = False
class TestDatasetWithFetchHandler(unittest.TestCase): class TestDatasetWithFetchHandler(unittest.TestCase):
"""
Test Dataset With Fetch Handler. TestCases.
"""
def net(self): def net(self):
"""
Test Dataset With Fetch Handler. TestCases.
"""
slots = ["slot1", "slot2", "slot3", "slot4"] slots = ["slot1", "slot2", "slot3", "slot4"]
slots_vars = [] slots_vars = []
poolings = [] poolings = []
...@@ -504,6 +518,13 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -504,6 +518,13 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
return slots_vars, fc return slots_vars, fc
def get_dataset(self, inputs, files): def get_dataset(self, inputs, files):
"""
Test Dataset With Fetch Handler. TestCases.
Args:
inputs(list): inputs of get_dataset
files(list): files of get_dataset
"""
dataset = fluid.DatasetFactory().create_dataset("QueueDataset") dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_batch_size(32) dataset.set_batch_size(32)
dataset.set_thread(3) dataset.set_thread(3)
...@@ -513,6 +534,9 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -513,6 +534,9 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
return dataset return dataset
def setUp(self): def setUp(self):
"""
Test Dataset With Fetch Handler. TestCases.
"""
with open("test_queue_dataset_run_a.txt", "w") as f: with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
...@@ -526,10 +550,16 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -526,10 +550,16 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
f.write(data) f.write(data)
def tearDown(self): def tearDown(self):
"""
Test Dataset With Fetch Handler. TestCases.
"""
os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt") os.remove("./test_queue_dataset_run_b.txt")
def test_dataset_none(self): def test_dataset_none(self):
"""
Test Dataset With Fetch Handler. TestCases.
"""
slots_vars, out = self.net() slots_vars, out = self.net()
files = ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"] files = ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"]
dataset = self.get_dataset(slots_vars, files) dataset = self.get_dataset(slots_vars, files)
...@@ -549,6 +579,9 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -549,6 +579,9 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
self.assertTrue(False) self.assertTrue(False)
def test_infer_from_dataset(self): def test_infer_from_dataset(self):
"""
Test Dataset With Fetch Handler. TestCases.
"""
slots_vars, out = self.net() slots_vars, out = self.net()
files = ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"] files = ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"]
dataset = self.get_dataset(slots_vars, files) dataset = self.get_dataset(slots_vars, files)
...@@ -564,6 +597,9 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -564,6 +597,9 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
self.assertTrue(False) self.assertTrue(False)
def test_fetch_handler(self): def test_fetch_handler(self):
"""
Test Dataset With Fetch Handler. TestCases.
"""
slots_vars, out = self.net() slots_vars, out = self.net()
files = ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"] files = ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"]
dataset = self.get_dataset(slots_vars, files) dataset = self.get_dataset(slots_vars, files)
...@@ -588,5 +624,146 @@ class TestDatasetWithFetchHandler(unittest.TestCase): ...@@ -588,5 +624,146 @@ class TestDatasetWithFetchHandler(unittest.TestCase):
self.assertTrue(False) self.assertTrue(False)
class TestDataset2(unittest.TestCase):
    """Smoke tests running InMemoryDataset with the pslib fleet optimizer.

    The tests make no assertions about results; they pass as long as the
    program can be built, minimized and the dataset loaded without raising.
    """

    def setUp(self):
        """Set the dataset options shared by the test cases."""
        self.use_data_loader = False
        self.epoch_num = 10
        self.drop_last = False

    def _run_fleet_dataset(self, file_prefix, strategy=None):
        """Build a tiny program, minimize it through fleet and load data.

        Two text files named ``<file_prefix>_a.txt`` and
        ``<file_prefix>_b.txt`` are written to the working directory, used
        as the file list of an InMemoryDataset, and removed again.

        Args:
            file_prefix(str): prefix of the two temporary data files.
            strategy(dict|None): strategy forwarded to
                ``fleet.distributed_optimizer``. When None, the optimizer
                is wrapped without passing a strategy argument.
        """
        from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

        filelist = [file_prefix + "_a.txt", file_prefix + "_b.txt"]
        contents = [
            "1 1 2 3 3 4 5 5 5 5 1 1\n"
            "1 2 2 3 4 4 6 6 6 6 1 2\n"
            "1 3 2 3 5 4 7 7 7 7 1 3\n",
            "1 4 2 3 3 4 5 5 5 5 1 4\n"
            "1 5 2 3 4 4 6 6 6 6 1 5\n"
            "1 6 2 3 5 4 7 7 7 7 1 6\n"
            "1 7 2 3 6 4 8 8 8 8 1 7\n",
        ]

        try:
            for path, data in zip(filelist, contents):
                with open(path, "w") as f:
                    f.write(data)

            train_program = fluid.Program()
            startup_program = fluid.Program()
            scope = fluid.Scope()

            with fluid.program_guard(train_program, startup_program):
                slots = ["slot1_ff", "slot2_ff", "slot3_ff", "slot4_ff"]
                slots_vars = [
                    fluid.layers.data(
                        name=slot, shape=[1], dtype="float32", lod_level=1)
                    for slot in slots
                ]
                fake_cost = fluid.layers.elementwise_sub(slots_vars[0],
                                                         slots_vars[-1])
                fake_cost = fluid.layers.mean(fake_cost)

            # NOTE(review): the extent of this scope_guard block was
            # reconstructed from a listing without indentation -- confirm
            # that everything up to load_into_memory() belongs inside it.
            with fluid.scope_guard(scope):
                place = fluid.CPUPlace()
                exe = fluid.Executor(place)
                # The MPI-related failures below are tolerated on purpose:
                # a warning is printed and the test carries on.
                try:
                    fleet.init(exe)
                except ImportError:
                    print("warning: no mpi4py")
                adam = fluid.optimizer.Adam(learning_rate=0.000005)
                try:
                    if strategy is None:
                        adam = fleet.distributed_optimizer(adam)
                    else:
                        adam = fleet.distributed_optimizer(
                            adam, strategy=strategy)
                    adam.minimize([fake_cost], [scope])
                except AttributeError:
                    print("warning: no mpi")
                except ImportError:
                    print("warning: no mpi4py")
                exe.run(startup_program)
                dataset = fluid.DatasetFactory().create_dataset(
                    "InMemoryDataset")
                dataset.set_batch_size(32)
                dataset.set_thread(3)
                dataset.set_filelist(filelist)
                dataset.set_pipe_command("cat")
                dataset.set_use_var(slots_vars)
                dataset.load_into_memory()
        finally:
            # Always reset the fleet object and delete the data files, even
            # when the body raised, so a failure here cannot leak state or
            # files into other test cases.
            fleet._opt_info = None
            fleet._fleet_ptr = None
            for path in filelist:
                if os.path.exists(path):
                    os.remove(path)

    def test_dataset_fleet(self):
        """
        Testcase for InMemoryDataset from create to run, wrapping the
        optimizer without an explicit strategy.
        """
        self._run_fleet_dataset("test_in_memory_dataset2_run")

    def test_dataset_fleet2(self):
        """
        Testcase for InMemoryDataset from create to run, passing the fs_*
        settings through the optimizer strategy.
        """
        self._run_fleet_dataset(
            "test_in_memory_dataset2_run2",
            strategy={
                "fs_uri": "fs_uri_xxx",
                "fs_user": "fs_user_xxx",
                "fs_passwd": "fs_passwd_xxx",
                "fs_hadoop_bin": "fs_hadoop_bin_xxx"
            })
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册