diff --git a/doc/v2/howto/cluster/multi_cluster/index_en.rst b/doc/v2/howto/cluster/multi_cluster/index_en.rst
index b69bd5b2dbf1967d65558da06812d76f431c1d5a..9bc1eb2e3796d95dd69b165e916e263ea34b87f6 100644
--- a/doc/v2/howto/cluster/multi_cluster/index_en.rst
+++ b/doc/v2/howto/cluster/multi_cluster/index_en.rst
@@ -8,28 +8,28 @@ The user's cluster environment is not the same. To facilitate everyone's deploym
 .. toctree::
    :maxdepth: 1
 
-   k8s_cn.md
-   k8s_distributed_cn.md
+   k8s_en.md
+   k8s_distributed_en.md
 
 `OpenMPI <https://www.open-mpi.org>`_ is a mature high-performance parallel computing framework, which is widely used in the field of HPC. The following guide describes how to use OpenMPI to build PaddlePaddle's cluster training task:
 
 .. toctree::
    :maxdepth: 1
 
-   openmpi_cn.md
+   openmpi_en.md
 
 `Fabric <http://www.fabfile.org>`_ is a convenient tool for program deployment and management. We provide a way to deploy and manage with Fabric. If you want to know more about it, please read the following guidelines:
 
 .. toctree::
    :maxdepth: 1
 
-   fabric_cn.md
+   fabric_en.md
 
 We also support the deployment of PaddlePaddle on AWS. Learn more about:
 
 .. toctree::
    :maxdepth: 1
 
-   k8s_aws_cn.md
+   k8s_aws_en.md
 
-The examples can be found under `cluster_train_v2 `_ .
\ No newline at end of file
+The examples can be found under `cluster_train_v2 `_ .
diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h
index 7cfeaab35b8c52225ff6e6cc2cdb8296621b30d9..2405f33d4f0ad83611e57d07a47e787eab439285 100644
--- a/paddle/fluid/platform/cuda_device_function.h
+++ b/paddle/fluid/platform/cuda_device_function.h
@@ -35,6 +35,16 @@ __forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
 #define FULL_WARP_MASK 0xFFFFFFFF
 #define CREATE_SHFL_MASK(mask, predicate) \
   mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
+  return __shfl_down_sync(mask, val, delta);
+}
+
+template <typename T>
+__forceinline__ __device__ T __shfl_sync(unsigned mask, T val, int src_line,
+                                         int width) {
+  return __shfl_sync(mask, val, src_line, width);
+}
 #endif
 
 template <typename T>
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index bd325bd2574afa9c652a9613bdb3bf0b6a93b4a3..dcf4e2a8e013f8e4e70ac1335890e7df0a050b5f 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -21,14 +21,15 @@ import executor
 from executor import *
 
 import trainer
-from trainer import *
+from trainer import Trainer
+from trainer import BeginEpochEvent
+from trainer import EndEpochEvent
+from trainer import BeginStepEvent
+from trainer import EndStepEvent
 
 import inferencer
 from inferencer import Inferencer
 
-import params
-from params import Params
-
 import io
 import evaluator
 import initializer
@@ -57,7 +58,7 @@ from parallel_executor import ParallelExecutor
 Tensor = LoDTensor
 
 __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
-    trainer.__all__ + inferencer.__all__ + params.__all__ + [
+    trainer.__all__ + inferencer.__all__ + [
         'io',
         'initializer',
         'layers',
diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py
index 3ea50bf196d00152e6579623c981ecbfb57b8e3b..58e027695a7100245dd424583e2cedeed3d165e6 100644
--- a/python/paddle/fluid/inferencer.py
+++ b/python/paddle/fluid/inferencer.py
@@ -12,18 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import core
+
 __all__ = ['Inferencer', ]
 
 
 class Inferencer(object):
-    def __init__(self, network_func, params, place=None):
+    def __init__(self, network_func, param_path=None, place=None):
         # 1. we need to generate a framework.Program by calling
         # network_func. Reference: fluid.program_guard in test_word2vec.py
         # 2. move the default_main_program to self.program.
         # 3. run the default_startup program.
-        self.params = params
+
+        # 4. load params from param_path into scope
+        self.scope = core.Scope()
         self.place = place
 
     def infer(self, inputs):
diff --git a/python/paddle/fluid/params.py b/python/paddle/fluid/params.py
deleted file mode 100644
index a5d257e53a2958acd1b8f6ef29d0f9f531b36678..0000000000000000000000000000000000000000
--- a/python/paddle/fluid/params.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from . import core
-
-__all__ = ['Params', ]
-
-
-class Params(object):
-    def __init__(self, path=None):
-        self.scope = core.Scope()
-
-        if path:
-            self._load(path)
-
-    def _load(self, path):
-        # reference: load_persistables in io.py
-        pass
-
-    def save(self, path):
-        # reference: save_persistables in io.py
-        pass
-
-    def add_params(self, scope):
-        # take the keys from the scope,
-        # if not already exists in self.scope,
-        # add the key and value into self.scope.
-        pass
diff --git a/python/paddle/fluid/tests/book/understand_sentiment/notest_understand_sentiment_stacked_lstm.py b/python/paddle/fluid/tests/book/understand_sentiment/notest_understand_sentiment_stacked_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9948e5c0234ed78237c94f9a25d6401619267d0d
--- /dev/null
+++ b/python/paddle/fluid/tests/book/understand_sentiment/notest_understand_sentiment_stacked_lstm.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
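+
+# This "notest_" example exercises the experimental fluid.Trainer /
+# fluid.Inferencer API: it trains a stacked bidirectional LSTM sentiment
+# classifier on the IMDB dataset, then reloads the saved parameters to run
+# inference on randomly generated word-id sequences.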
+
+from __future__ import print_function
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from functools import partial
+
+CLASS_DIM = 2
+EMB_DIM = 128
+HID_DIM = 512
+STACKED_NUM = 3
+
+
+def stacked_lstm_net(data, input_dim, class_dim, emb_dim, hid_dim, stacked_num):
+    assert stacked_num % 2 == 1
+
+    emb = fluid.layers.embedding(
+        input=data, size=[input_dim, emb_dim], is_sparse=True)
+
+    fc1 = fluid.layers.fc(input=emb, size=hid_dim)
+    lstm1, cell1 = fluid.layers.dynamic_lstm(input=fc1, size=hid_dim)
+
+    inputs = [fc1, lstm1]
+
+    for i in range(2, stacked_num + 1):  # stack alternating-direction LSTMs
+        fc = fluid.layers.fc(input=inputs, size=hid_dim)
+        lstm, cell = fluid.layers.dynamic_lstm(
+            input=fc, size=hid_dim, is_reverse=(i % 2) == 0)
+        inputs = [fc, lstm]
+
+    fc_last = fluid.layers.sequence_pool(input=inputs[0], pool_type='max')
+    lstm_last = fluid.layers.sequence_pool(input=inputs[1], pool_type='max')
+
+    prediction = fluid.layers.fc(input=[fc_last, lstm_last],
+                                 size=class_dim,
+                                 act='softmax')
+    return prediction
+
+
+def inference_network(word_dict):
+    data = fluid.layers.data(
+        name="words", shape=[1], dtype="int64", lod_level=1)
+
+    dict_dim = len(word_dict)
+    net = stacked_lstm_net(data, dict_dim, CLASS_DIM, EMB_DIM, HID_DIM,
+                           STACKED_NUM)
+    return net
+
+
+def train_network(word_dict):
+    prediction = inference_network(word_dict)
+    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(cost)
+    accuracy = fluid.layers.accuracy(input=prediction, label=label)
+    return avg_cost, accuracy
+
+
+def train(use_cuda, save_path):
+    BATCH_SIZE = 128
+    EPOCH_NUM = 5
+
+    word_dict = paddle.dataset.imdb.word_dict()
+
+    train_data = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.imdb.train(word_dict), buf_size=1000),
+        batch_size=BATCH_SIZE)
+
+    test_data = paddle.batch(
+        paddle.dataset.imdb.test(word_dict), batch_size=BATCH_SIZE)
+
+    def event_handler(event):
+        if isinstance(event, fluid.EndStepEvent):
+            if (event.batch_id % 10) == 0:
+                avg_cost, accuracy = trainer.test(reader=test_data)
+
+                print('BatchID {0:04}, Loss {1:2.2}, Acc {2:2.2}'.format(
+                    event.batch_id + 1, avg_cost, accuracy))
+
+                if accuracy > 0.01:  # Low threshold for speeding up CI
+                    trainer.save_params(save_path)
+                    return
+
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    trainer = fluid.Trainer(
+        partial(train_network, word_dict),
+        optimizer=fluid.optimizer.Adagrad(learning_rate=0.002),
+        place=place)
+
+    trainer.train(train_data, EPOCH_NUM, event_handler=event_handler)
+
+
+def infer(use_cuda, save_path):
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    word_dict = paddle.dataset.imdb.word_dict()
+    inferencer = fluid.Inferencer(
+        partial(inference_network, word_dict),
+        param_path=save_path, place=place)
+
+    def create_random_lodtensor(lod, place, low, high):
+        data = np.random.random_integers(low, high,
+                                         [lod[-1], 1]).astype("int64")
+        res = fluid.LoDTensor()
+        res.set(data, place)
+        res.set_lod([lod])
+        return res
+
+    lod = [0, 4, 10]
+    tensor_words = create_random_lodtensor(
+        lod, place, low=0, high=len(word_dict) - 1)
+    results = inferencer.infer({'words': tensor_words})
+    print("infer results: ", results)
+
+
+def main(use_cuda):
+    if use_cuda and not fluid.core.is_compiled_with_cuda():
+        return
+    save_path = "understand_sentiment_stacked_lstm.inference.model"
+    train(use_cuda, save_path)
+    infer(use_cuda, save_path)
+
+
+if __name__ == '__main__':
+    for use_cuda in (False, True):
+        main(use_cuda=use_cuda)
diff --git a/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py b/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py
index 30939cae29ddf7070f3e5d39f33dc88e86c65450..35e163dc9df5a35ee5774b6b157366c4eabcb0f7 100644
--- a/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py
+++ b/python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py
@@ -39,7 +39,7 @@ word_dict = paddle.dataset.imikolov.build_dict()
 dict_size = len(word_dict)
 
 
-def inference_network(is_sparse):
+def inference_program(is_sparse):
     first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
     second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
     third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
@@ -79,9 +79,9 @@ def inference_network(is_sparse):
     return predict_word
 
 
-def train_network(is_sparse):
+def train_program(is_sparse):
     next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
-    predict_word = inference_network(is_sparse)
+    predict_word = inference_program(is_sparse)
     cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
     avg_cost = fluid.layers.mean(cost)
     return avg_cost
@@ -100,23 +100,25 @@ def train(use_cuda, is_sparse, save_path):
                                                    word_dict, N))
             if avg_cost < 5.0:
-                trainer.params.save(save_path)
+                trainer.save_params(save_path)
                 return
             if math.isnan(avg_cost):
                 sys.exit("got NaN loss, training failed.")
 
     trainer = fluid.Trainer(
-        partial(train_network, is_sparse),
+        partial(train_program, is_sparse),
         fluid.optimizer.SGD(learning_rate=0.001),
         place=place)
     trainer.train(
         reader=train_reader, num_epochs=100, event_handler=event_handler)
 
 
-def infer(use_cuda, save_path):
-    params = fluid.Params(save_path)
+def infer(use_cuda, is_sparse, save_path):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-    inferencer = fluid.Inferencer(inference_network, params, place=place)
+    inferencer = fluid.Inferencer(
+        partial(inference_program, is_sparse),
+        param_path=save_path,
+        place=place)
 
     lod = [0, 1]
     first_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
@@ -138,7 +140,7 @@ def main(use_cuda, is_sparse):
     save_path = "word2vec.inference.model"
 
     train(use_cuda, is_sparse, save_path)
-    infer(use_cuda, save_path)
+    infer(use_cuda, is_sparse, save_path)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 2362da370a39b6e59f486e2affbb02e21840784f..0aada3deb0f047d21701b64af022ebad372d505b 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -56,23 +56,22 @@ class Trainer(object):
     """
 
     Args:
-        network_func(callable): A function which will return loss. The loss must be a scaler.
+        program_func(callable): A function which will return loss. The loss must be a scalar.
         optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
-        params:
         place: The device place of this trainer.
     """
 
-    def __init__(self, network_func, optimizer, params=None, place=None):
+    def __init__(self, program_func, optimizer, param_path=None, place=None):
         # 1. we need to generate a framework.Program by calling
-        # network_func. Reference: fluid.program_guard in
+        # program_func. Reference: fluid.program_guard in
         # test_word2vec.py
-        self.scope = self._get_scope_from_params(params)
+        self.scope = core.Scope()
 
         self.startup_program = framework.Program()
         self.train_program = framework.Program()
 
         with framework.program_guard(self.train_program, self.startup_program):
-            loss = network_func()
+            loss = program_func()
             if not isinstance(optimizer, opt_module.Optimizer):
                 raise TypeError(
                     "The optimizer should be an instance of Optimizer")
@@ -84,14 +83,13 @@
         # 2. move the default_main_program to self.program and run the
         # default_startup program on an empty core.Scope()
         # Run startup program
-        if params is None:
-            exe = executor.Executor(place)
-            exe.run(self.startup_program, scope=self.scope)
+        exe = executor.Executor(place)
+        exe.run(self.startup_program, scope=self.scope)
 
-        # 3. call self.params.add_vars with the initialized scope, it
-        # will add the new vars of the initialized scope into
-        # self.params.
-        # TODO(yuyang): This depends on parameters implementation.
+        if param_path:
+            # load params from param_path into scope
+            # TODO(yuyang): This depends on parameters implementation.
+            pass
 
         # TODO(helin): support distributed training
 
@@ -124,19 +122,9 @@
     def test(self, reader):
         pass
 
-    def _get_scope_from_params(self, params):
-        """
-        Get Scope from parameter object.
-        Args:
-            params(Parameter|None): The parameter object instance. Could be None.
-
-        Returns: New scope if params is None. Or params.scope()
-        NOTE: This method is WIP. Not fully implemented.
-        """
-        if params is None:
-            return core.Scope()  # new scope when params is None
-        else:
-            raise NotImplementedError("Not implemented right now.")
+    def save_params(self, param_path):
+        # reference: save_persistables in io.py
+        pass
 
     @staticmethod
     def _check_and_get_place(place):