提交 61d6db56 编写于 作者: Y yangyaming

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-7555

...@@ -211,3 +211,49 @@ decoder_inputs = paddle.layer.fc( ...@@ -211,3 +211,49 @@ decoder_inputs = paddle.layer.fc(
* list 中元素的个数等于网络中输出层的个数; * list 中元素的个数等于网络中输出层的个数;
* list 中每个元素是一个layer的输出结果矩阵,类型是numpy的ndarray; * list 中每个元素是一个layer的输出结果矩阵,类型是numpy的ndarray;
* 每一个layer输出矩阵的高度,在非序列输入时:等于样本数;序列输入时等于:输入序列中元素的总数;宽度等于配置中layer的size; * 每一个layer输出矩阵的高度,在非序列输入时:等于样本数;序列输入时等于:输入序列中元素的总数;宽度等于配置中layer的size;
6. 如何在训练过程中获得某一个layer的output
-----------------------------------------------
可以在event_handler中,通过 :code:`event.gm.getLayerOutputs("layer_name")` 获得在模型配置中某一层的name :code:`layer_name` 在当前
mini-batch forward的output的值。获得的值类型均为 :code:`numpy.ndarray` ,可以通过这个输出来完成自定义的评估指标计算等功能。例如下面代码:
.. code-block:: python
def score_diff(right_score, left_score):
return np.average(np.abs(right_score - left_score))
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 25 == 0:
diff = score_diff(
event.gm.getLayerOutputs("right_score")["right_score"][
"value"],
event.gm.getLayerOutputs("left_score")["left_score"][
"value"])
logger.info(("Pass %d Batch %d : Cost %.6f, "
"average absolute diff scores: %.6f") %
(event.pass_id, event.batch_id, event.cost, diff))
注意:此方法不能获取 :code:`paddle.layer.recurrent_group` 里step的内容,但可以获取 :code:`paddle.layer.recurrent_group` 的输出。
7. 如何在训练过程中获得参数的权重和梯度
-----------------------------------------------
在某些情况下,获得当前mini-batch的权重(或称作weights, parameters)有助于在训练时观察具体数值,方便排查以及快速定位问题。
可以通过在 :code:`event_handler` 中打印其值(注意,需要使用 :code:`paddle.event.EndForwardBackward` 保证使用GPU训练时也可以获得),
示例代码如下:
.. code-block:: python
...
parameters = paddle.parameters.create(cost)
...
def event_handler(event):
if isinstance(event, paddle.event.EndForwardBackward):
if event.batch_id % 25 == 0:
for p in parameters.keys():
logger.info("Param %s, Grad %s",
parameters.get(p), parameters.get_grad(p))
注意:“在训练过程中获得某一个layer的output”和“在训练过程中获得参数的权重和梯度”都会造成训练中的数据从C++拷贝到numpy,会对训练性能造成影响。不要在注重性能的训练场景下使用。
\ No newline at end of file
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
DECLARE_bool(do_memory_benchmark);
DEFINE_bool(check_nan_inf, false, DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be " "Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely."); "extremely slow so please use this flag wisely.");
...@@ -117,6 +118,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -117,6 +118,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugStringEx(local_scope); VLOG(3) << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (FLAGS_do_memory_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
}
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) { for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname); auto* var = local_scope->FindVar(vname);
...@@ -130,6 +135,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -130,6 +135,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
if (create_vars && create_local_scope) { if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} }
if (FLAGS_do_memory_benchmark) {
VLOG(2) << "-------------------------------------------------------";
VLOG(2) << "Memory used after deleting local scope: "
<< memory::memory_usage(place_);
VLOG(2) << "-------------------------------------------------------";
}
} }
} // namespace framework } // namespace framework
......
...@@ -20,6 +20,10 @@ limitations under the License. */ ...@@ -20,6 +20,10 @@ limitations under the License. */
#include "paddle/framework/threadpool.h" #include "paddle/framework/threadpool.h"
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
DEFINE_bool(do_memory_benchmark, false,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -88,8 +92,12 @@ void Scope::DeleteScope(Scope* scope) { ...@@ -88,8 +92,12 @@ void Scope::DeleteScope(Scope* scope) {
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it); this->kids_.erase(it);
// Make delete async. // When making memory benchmark on Fluid, we have to delete scope sync.
if (FLAGS_do_memory_benchmark) {
delete scope;
} else {
Async([scope] { delete scope; }); Async([scope] { delete scope; });
}
} }
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
......
...@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is ...@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s calculated by %s
)DOC", )DOC",
comment.type, comment.equation)); comment.type, comment.equation));
AddAttr<int>("axis",
"(int, default -1). The start dimension index "
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
} }
}; };
...@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); ...@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
...@@ -16,8 +16,4 @@ limitations under the License. */ ...@@ -16,8 +16,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <math.h> #include <math.h>
#include <type_traits> #include <type_traits>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/platform/transform.h" #include "paddle/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -33,18 +34,6 @@ struct LessEqualFunctor { ...@@ -33,18 +34,6 @@ struct LessEqualFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
}; };
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T> template <typename T>
struct EqualFunctor { struct EqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
...@@ -65,14 +54,7 @@ class CompareOpKernel ...@@ -65,14 +54,7 @@ class CompareOpKernel
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE; using T = typename Functor::ELEM_TYPE;
auto* x = context.Input<framework::Tensor>("X"); ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context);
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), y->data<T>(),
out->mutable_data<bool>(context.GetPlace()), binary_func);
} }
}; };
......
...@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext> ...@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
}; };
#endif #endif
template <typename Functor, typename T, typename DeviceContext> template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor { class TransformFunctor {
public: public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const DeviceContext& ctx, Functor func) framework::Tensor* z, const DeviceContext& ctx, Functor func)
: x_(x->data<T>()), : x_(x->data<T>()),
y_(y->data<T>()), y_(y->data<T>()),
z_(z->mutable_data<T>(ctx.GetPlace())), z_(z->mutable_data<OutType>(ctx.GetPlace())),
nx_(x->numel()), nx_(x->numel()),
ctx_(ctx), ctx_(ctx),
func_(func) {} func_(func) {}
...@@ -208,7 +209,7 @@ class TransformFunctor { ...@@ -208,7 +209,7 @@ class TransformFunctor {
private: private:
const T* x_; const T* x_;
const T* y_; const T* y_;
T* z_; OutType* z_;
int64_t nx_; int64_t nx_;
const DeviceContext& ctx_; const DeviceContext& ctx_;
Functor func_; Functor func_;
...@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
} }
} }
template <typename Functor, typename DeviceContext, typename T> template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<OutType>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext> functor( TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor()); x, y, z, ctx.template device_context<DeviceContext>(), Functor());
auto x_dims = x->dims(); auto x_dims = x->dims();
......
...@@ -24,11 +24,23 @@ import conll05 ...@@ -24,11 +24,23 @@ import conll05
import uci_housing import uci_housing
import sentiment import sentiment
import wmt14 import wmt14
import wmt16
import mq2007 import mq2007
import flowers import flowers
import voc2012 import voc2012
__all__ = [ __all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' 'mnist',
'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012' 'imikolov',
'imdb',
'cifar',
'movielens',
'conll05',
'sentiment'
'uci_housing',
'wmt14',
'wmt16',
'mq2007',
'flowers',
'voc2012',
] ]
...@@ -25,8 +25,12 @@ import glob ...@@ -25,8 +25,12 @@ import glob
import cPickle as pickle import cPickle as pickle
__all__ = [ __all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader', 'DATA_HOME',
'convert' 'download',
'md5file',
'split',
'cluster_files_reader',
'convert',
] ]
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
...@@ -58,12 +62,15 @@ def md5file(fname): ...@@ -58,12 +62,15 @@ def md5file(fname):
return hash_md5.hexdigest() return hash_md5.hexdigest()
def download(url, module_name, md5sum): def download(url, module_name, md5sum, save_name=None):
dirname = os.path.join(DATA_HOME, module_name) dirname = os.path.join(DATA_HOME, module_name)
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1]) filename = os.path.join(dirname,
url.split('/')[-1]
if save_name is None else save_name)
retry = 0 retry = 0
retry_limit = 3 retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum): while not (os.path.exists(filename) and md5file(filename) == md5sum):
...@@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix): ...@@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix):
Convert data from reader to recordio format files. Convert data from reader to recordio format files.
:param output_path: directory in which output files will be saved. :param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances. :param reader: a data reader, from which the convert program will read
data instances.
:param name_prefix: the name prefix of generated files. :param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing. :param max_lines_to_shuffle: the max lines numbers to shuffle before
writing.
""" """
assert line_count >= 1 assert line_count >= 1
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.wmt16
import unittest
class TestWMT16(unittest.TestCase):
    """Smoke tests for the readers and dictionaries of the WMT16 dataset."""

    def checkout_one_sample(self, sample):
        """Check the field count and the boundary marks of one sample."""
        # Every sample carries three fields: source language word indices,
        # target language word indices, and target next-word indices.
        self.assertEqual(len(sample), 3)

        src_ids = sample[0]
        trg_ids = sample[1]
        trg_ids_next = sample[2]

        # The source sequence begins with the start mark and finishes with
        # the end mark.
        self.assertEqual(src_ids[0], 0)
        self.assertEqual(src_ids[-1], 1)

        # The target sequence begins with the start mark.
        self.assertEqual(trg_ids[0], 0)

        # The target next-word sequence finishes with the end mark.
        self.assertEqual(trg_ids_next[-1], 1)

    def test_train(self):
        reader = paddle.v2.dataset.wmt16.train(
            src_dict_size=100000, trg_dict_size=100000)
        # Only the first ten samples are inspected.
        for sample_no, sample in enumerate(reader()):
            if sample_no >= 10:
                break
            self.checkout_one_sample(sample)

    def test_test(self):
        reader = paddle.v2.dataset.wmt16.test(
            src_dict_size=1000, trg_dict_size=1000)
        # Only the first ten samples are inspected.
        for sample_no, sample in enumerate(reader()):
            if sample_no >= 10:
                break
            self.checkout_one_sample(sample)

    def test_val(self):
        reader = paddle.v2.dataset.wmt16.validation(
            src_dict_size=1000, trg_dict_size=1000)
        # Only the first ten samples are inspected.
        for sample_no, sample in enumerate(reader()):
            if sample_no >= 10:
                break
            self.checkout_one_sample(sample)

    def test_get_dict(self):
        expected_size = 1000
        # reverse=True yields an index -> word mapping.
        idx_to_word = paddle.v2.dataset.wmt16.get_dict("en", expected_size,
                                                       True)
        self.assertEqual(len(idx_to_word), expected_size)
        # The three special tokens occupy the first three indices.
        self.assertEqual(idx_to_word[0], "<s>")
        self.assertEqual(idx_to_word[1], "<e>")
        self.assertEqual(idx_to_word[2], "<unk>")
if __name__ == "__main__":
unittest.main()
...@@ -25,12 +25,20 @@ import gzip ...@@ -25,12 +25,20 @@ import gzip
import paddle.v2.dataset.common import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters from paddle.v2.parameters import Parameters
__all__ = ['train', 'test', 'build_dict', 'convert'] __all__ = [
'train',
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' 'test',
'get_dict',
'convert',
]
URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
'cslm_joint_paper/data/dev+test.tgz')
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later. # this is a small set of data for test. The original data is too large and
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' # will be add later.
URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c' MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92 # BLEU of this trained model is 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
...@@ -42,8 +50,8 @@ UNK = "<unk>" ...@@ -42,8 +50,8 @@ UNK = "<unk>"
UNK_IDX = 2 UNK_IDX = 2
def __read_to_dict__(tar_file, dict_size): def __read_to_dict(tar_file, dict_size):
def __to_dict__(fd, size): def __to_dict(fd, size):
out_dict = dict() out_dict = dict()
for line_count, line in enumerate(fd): for line_count, line in enumerate(fd):
if line_count < size: if line_count < size:
...@@ -58,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size): ...@@ -58,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size):
if each_item.name.endswith("src.dict") if each_item.name.endswith("src.dict")
] ]
assert len(names) == 1 assert len(names) == 1
src_dict = __to_dict__(f.extractfile(names[0]), dict_size) src_dict = __to_dict(f.extractfile(names[0]), dict_size)
names = [ names = [
each_item.name for each_item in f each_item.name for each_item in f
if each_item.name.endswith("trg.dict") if each_item.name.endswith("trg.dict")
] ]
assert len(names) == 1 assert len(names) == 1
trg_dict = __to_dict__(f.extractfile(names[0]), dict_size) trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
return src_dict, trg_dict return src_dict, trg_dict
def reader_creator(tar_file, file_name, dict_size): def reader_creator(tar_file, file_name, dict_size):
def reader(): def reader():
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
with tarfile.open(tar_file, mode='r') as f: with tarfile.open(tar_file, mode='r') as f:
names = [ names = [
each_item.name for each_item in f each_item.name for each_item in f
...@@ -152,7 +160,7 @@ def get_dict(dict_size, reverse=True): ...@@ -152,7 +160,7 @@ def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...} # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...} # else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in src_dict.items()} src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()} trg_dict = {v: k for k, v in trg_dict.items()}
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ACL2016 Multimodal Machine Translation. Please see this website for more
details: http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions.
@inproceedings{elliott-EtAl:2016:VL16,
author = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
title = {Multi30K: Multilingual English-German Image Descriptions},
booktitle = {Proceedings of the 6th Workshop on Vision and Language},
year = {2016},
pages = {70--74}
}
"""
import os
import tarfile
import gzip
from collections import defaultdict
import paddle.v2.dataset.common
__all__ = [
"train",
"test",
"validation",
"convert",
"fetch",
"get_dict",
]
DATA_URL = ("http://cloud.dlnel.org/filepub/"
"?uuid=46a0808e-ddd8-427c-bacd-0dbc6d045fed")
DATA_MD5 = "0c38be43600334966403524a40dcd81e"
TOTAL_EN_WORDS = 11250
TOTAL_DE_WORDS = 19220
START_MARK = "<s>"
END_MARK = "<e>"
UNK_MARK = "<unk>"
def __build_dict(tar_file, dict_size, save_path, lang):
    """Count word frequencies in the training split and save a dictionary.

    The training file ``wmt16/train`` inside ``tar_file`` is read line by
    line. Each usable line holds two tab-separated sentences; the first is
    used when ``lang`` is "en", the second otherwise. The dictionary written
    to ``save_path`` has one word per line: the start, end and unknown marks
    first, then words in order of decreasing frequency.
    """
    # Column of the tab-separated line that holds the requested language.
    column = 0 if lang == "en" else 1

    frequency = defaultdict(int)
    with tarfile.open(tar_file, mode="r") as archive:
        for raw_line in archive.extractfile("wmt16/train"):
            columns = raw_line.strip().split("\t")
            # Lines without exactly two columns are ignored.
            if len(columns) == 2:
                for token in columns[column].split():
                    frequency[token] += 1

    with open(save_path, "w") as fout:
        fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
        ranked = sorted(
            frequency.iteritems(), key=lambda item: item[1], reverse=True)
        for rank, (token, _) in enumerate(ranked):
            # Three lines are already taken by the special marks.
            if rank + 3 == dict_size:
                break
            fout.write("%s\n" % token)
def __load_dict(tar_file, dict_size, lang, reverse=False):
    """Load the word dictionary for ``lang``, building it first if needed.

    The dictionary file lives at ``DATA_HOME/wmt16/<lang>_<dict_size>.dict``.
    It is (re)built from ``tar_file`` when it does not exist or when its
    number of lines differs from ``dict_size``.

    Args:
        tar_file(string): Path of the dataset tarball, used only when the
            dictionary has to be built.
        dict_size(int): Expected number of entries in the dictionary.
        lang(string): Language of the dictionary.
        reverse(bool): If False, map word -> index; if True, map
            index -> word.

    Returns:
        dict: The loaded word dictionary.
    """
    dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                             "wmt16/%s_%d.dict" % (lang, dict_size))

    # Count the lines inside a ``with`` block so the file handle is always
    # closed instead of being left to the garbage collector.
    needs_build = True
    if os.path.exists(dict_path):
        with open(dict_path, "r") as fdict:
            needs_build = len(fdict.readlines()) != dict_size
    if needs_build:
        __build_dict(tar_file, dict_size, dict_path, lang)

    word_dict = {}
    with open(dict_path, "r") as fdict:
        for idx, line in enumerate(fdict):
            if reverse:
                word_dict[idx] = line.strip()
            else:
                word_dict[line.strip()] = idx
    return word_dict
def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
    """Clamp the requested dictionary sizes to the available vocabulary.

    When ``src_lang`` is "en" the source size is capped by TOTAL_EN_WORDS and
    the target size by TOTAL_DE_WORDS; otherwise the caps are swapped.

    Returns:
        tuple: (src_dict_size, trg_dict_size) after clamping.
    """
    if src_lang == "en":
        src_limit, trg_limit = TOTAL_EN_WORDS, TOTAL_DE_WORDS
    else:
        # The English cap is TOTAL_EN_WORDS; the previously referenced
        # TOTAL_ENG_WORDS is not defined and raised a NameError here.
        src_limit, trg_limit = TOTAL_DE_WORDS, TOTAL_EN_WORDS
    src_dict_size = min(src_dict_size, src_limit)
    trg_dict_size = min(trg_dict_size, trg_limit)
    return src_dict_size, trg_dict_size
def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
    """Create a reader over one split stored inside the dataset tarball.

    The returned reader yields, for every line of ``file_name`` that has
    exactly two tab-separated columns, a triple of index lists:
    (source ids wrapped in start/end marks, target ids prefixed with the
    start mark, target ids followed by the end mark).
    """

    def reader():
        trg_lang = "de" if src_lang == "en" else "en"
        src_dict = __load_dict(tar_file, src_dict_size, src_lang)
        trg_dict = __load_dict(tar_file, trg_dict_size, trg_lang)

        # The indices of the start mark, end mark and unknown mark are the
        # same in both languages, so the source dictionary is used to look
        # them up.
        start_id = src_dict[START_MARK]
        end_id = src_dict[END_MARK]
        unk_id = src_dict[UNK_MARK]

        # English is the first column of each line, the other language the
        # second one.
        src_col = 0 if src_lang == "en" else 1
        trg_col = 1 - src_col

        with tarfile.open(tar_file, mode="r") as archive:
            for raw_line in archive.extractfile(file_name):
                columns = raw_line.strip().split("\t")
                if len(columns) == 2:
                    src_ids = [start_id]
                    src_ids.extend(
                        src_dict.get(token, unk_id)
                        for token in columns[src_col].split())
                    src_ids.append(end_id)

                    trg_body = [
                        trg_dict.get(token, unk_id)
                        for token in columns[trg_col].split()
                    ]
                    yield src_ids, [start_id] + trg_body, trg_body + [end_id]

    return reader
def train(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 train set reader.

    This function returns the reader for train data. Each sample the reader
    returns is made up of three fields: the source language word index sequence,
    target language word index sequence and next word index sequence.

    NOTE:
    The original link for training data is:
    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz

    paddle.dataset.wmt16 provides a tokenized version of the original dataset by
    using moses's tokenization script:
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk> for
                            unknown word.
        trg_dict_size(int): Size of the target language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk> for
                            unknown word.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for German.

    Returns:
        callable: The train reader.

    Raises:
        AssertionError: If src_lang is neither "en" nor "de".
    """
    # The condition and the message must not be wrapped in one pair of
    # parentheses: that asserts a non-empty tuple, which is always true.
    assert src_lang in ["en", "de"], ("An error language type. Only support: "
                                      "en (for English); de(for Germany)")
    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
                                                   src_lang)

    return reader_creator(
        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                                   "wmt16.tar.gz"),
        file_name="wmt16/train",
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        src_lang=src_lang)
def test(src_dict_size, trg_dict_size, src_lang="en"):
"""
WMT16 test set reader.
This function returns the reader for test data. Each sample the reader
returns is made up of three fields: the source language word index sequence,
target language word index sequence and next word index sequence.
NOTE:
The original like for test data is:
http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz
paddle.dataset.wmt16 provides a tokenized version of the original dataset by
using moses's tokenization script:
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl
Args:
src_dict_size(int): Size of the source language dictionary. Three
special tokens will be added into the dictionary:
<s> for start mark, <e> for end mark, and <unk> for
unknown word.
trg_dict_size(int): Size of the target language dictionary. Three
special tokens will be added into the dictionary:
<s> for start mark, <e> for end mark, and <unk> for
unknown word.
src_lang(string): A string indicating which language is the source
language. Available options are: "en" for English
and "de" for Germany.
Returns:
callable: The test reader.
"""
assert (src_lang in ["en", "de"],
("An error language type. "
"Only support: en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)
return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
"wmt16.tar.gz"),
file_name="wmt16/test",
src_dict_size=src_dict_size,
trg_dict_size=trg_dict_size,
src_lang=src_lang)
def validation(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 validation set reader.

    This function returns the reader for validation data. Each sample the reader
    returns is made up of three fields: the source language word index sequence,
    target language word index sequence and next word index sequence.

    NOTE:
    The original link for validation data is:
    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz

    paddle.dataset.wmt16 provides a tokenized version of the original dataset by
    using moses's tokenization script:
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk> for
                            unknown word.
        trg_dict_size(int): Size of the target language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk> for
                            unknown word.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for German.

    Returns:
        callable: The validation reader.

    Raises:
        AssertionError: If src_lang is neither "en" nor "de".
    """
    # The condition and the message must not be wrapped in one pair of
    # parentheses: that asserts a non-empty tuple, which is always true.
    assert src_lang in ["en", "de"], (
        "An error language type. "
        "Only support: en (for English); de(for Germany)")
    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
                                                   src_lang)

    return reader_creator(
        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                                   "wmt16.tar.gz"),
        file_name="wmt16/val",
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        src_lang=src_lang)
def get_dict(lang, dict_size, reverse=False):
    """
    return the word dictionary for the specified language.

    Args:
        lang(string): A string indicating which language is the source
                      language. Available options are: "en" for English
                      and "de" for German.
        dict_size(int): Size of the specified language dictionary.
        reverse(bool): If reverse is set to False, the returned python
                       dictionary will use word as key and use index as value.
                       If reverse is set to True, the returned python
                       dictionary will use index as key and word as value.

    Returns:
        dict: The word dictionary for the specific language.

    Raises:
        AssertionError: If the dictionary file has not been built yet.
    """
    # The requested size cannot exceed the vocabulary of the language.
    if lang == "en":
        dict_size = min(dict_size, TOTAL_EN_WORDS)
    else:
        dict_size = min(dict_size, TOTAL_DE_WORDS)

    dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                             "wmt16/%s_%d.dict" % (lang, dict_size))
    # The condition and the message must not be wrapped in one pair of
    # parentheses: that asserts a non-empty tuple, which is always true.
    assert os.path.exists(dict_path), (
        "Word dictionary does not exist. "
        "Please invoke paddle.dataset.wmt16.train/test/validation "
        "first to build the dictionary.")
    # The readers download the tarball with module name "wmt16" and save name
    # "wmt16.tar.gz", so the path has to include the "wmt16" directory.
    tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16",
                            "wmt16.tar.gz")
    return __load_dict(tar_file, dict_size, lang, reverse)
def fetch():
    """download the entire dataset.
    """
    # The dataset helpers live in ``paddle.v2``; there is no ``paddle.v4``
    # package, so the previous call failed with an AttributeError.
    paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                      "wmt16.tar.gz")
def convert(path, src_dict_size, trg_dict_size, src_lang):
    """Converts dataset to recordio format.

    The train, test and validation readers are converted in that order, each
    with a line count of 1000 and its own file name prefix.
    """
    # (reader factory, name prefix of the generated files)
    splits = (
        (train, "wmt16_train"),
        (test, "wmt16_test"),
        (validation, "wmt16_validation"),
    )
    for make_reader, name_prefix in splits:
        split_reader = make_reader(
            src_dict_size=src_dict_size,
            trg_dict_size=trg_dict_size,
            src_lang=src_lang)
        paddle.v2.dataset.common.convert(path, split_reader, 1000, name_prefix)
...@@ -86,7 +86,9 @@ def __bootstrap__(): ...@@ -86,7 +86,9 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads) os.environ['OMP_NUM_THREADS'] = str(num_threads)
read_env_flags = ['use_pinned_memory', 'check_nan_inf'] read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'do_memory_benchmark'
]
if core.is_compile_gpu(): if core.is_compile_gpu():
read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync'] read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync']
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import functools import functools
import layers import layers
import framework
from . import core from . import core
__all__ = [ __all__ = [
...@@ -66,7 +67,7 @@ def error_clip_callback(block, context): ...@@ -66,7 +67,7 @@ def error_clip_callback(block, context):
class BaseGradientClipAttr(object): class BaseGradientClipAttr(object):
def process_context(self, context, p_g): def process_context(self, context, param, grad):
raise NotImplementedError() raise NotImplementedError()
def create_operators(self, param, grad): def create_operators(self, param, grad):
...@@ -74,7 +75,7 @@ class BaseGradientClipAttr(object): ...@@ -74,7 +75,7 @@ class BaseGradientClipAttr(object):
class NullGradientClipAttr(BaseGradientClipAttr): class NullGradientClipAttr(BaseGradientClipAttr):
def process_context(self, context, p_g): def process_context(self, context, param, grad):
pass pass
def create_operators(self, param, grad): def create_operators(self, param, grad):
...@@ -91,7 +92,7 @@ class GradientClipByValue(BaseGradientClipAttr): ...@@ -91,7 +92,7 @@ class GradientClipByValue(BaseGradientClipAttr):
self.max = max self.max = max
self.min = min self.min = min
def process_context(self, context, p_g): def process_context(self, context, param, grad):
pass pass
def create_operators(self, param, grad): def create_operators(self, param, grad):
...@@ -99,19 +100,93 @@ class GradientClipByValue(BaseGradientClipAttr): ...@@ -99,19 +100,93 @@ class GradientClipByValue(BaseGradientClipAttr):
return param, new_grad return param, new_grad
class GradientClipByNorm(BaseGradientClipAttr):
    """Gradient clip attribute backed by ``layers.clip_by_norm``.

    Each gradient is clipped on its own; nothing is shared between
    parameters through the context dictionary.
    """

    def __init__(self, clip_norm):
        # Forwarded as ``max_norm`` to ``layers.clip_by_norm``.
        self.clip_norm = clip_norm

    def process_context(self, context, param, grad):
        # Nothing to collect for per-gradient clipping.
        return None

    def create_operators(self, param, grad):
        """Return ``(param, clipped_grad)`` for one parameter/gradient pair."""
        clipped_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
        return param, clipped_grad
class GradientClipByGlobalNorm(BaseGradientClipAttr):
def __init__(self, clip_norm, group_name="default_group"):
if not isinstance(group_name, basestring):
raise TypeError("'group_name' must be a basestring.")
self.clip_norm = clip_norm
self.group_name = group_name
def process_context(self, context, param, grad):
if self.group_name not in context:
context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm
context[self.group_name + "_clip"] = layers.fill_constant(
shape=[1], dtype="float32", value=self.clip_norm)
else:
if not self.clip_norm == context[self.group_name + "_clip_value"]:
raise ValueError(
"All parameters' 'clip_norm' of a same group should be the same"
)
local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
context[self.group_name].append(local_norm_var)
self.context = context
def create_operators(self, param, grad):
group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
layers.sqrt(x=group_norm_var, out=group_norm_var)
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div(
x=clip_var,
y=layers.elementwise_max(
x=clip_var, y=group_norm_var))
assert group_scale_var.shape == (1L, )
self.context[group_scale_name] = group_scale_var
new_grad = layers.elementwise_mul(
x=grad, y=self.context[group_scale_name])
return param, new_grad
def gradient_clip_by_global_norm(clip_norm,
                                 param_list=None,
                                 group_name="default_group",
                                 program=None):
    """Attach a ``GradientClipByGlobalNorm`` attribute to parameters.

    Args:
        clip_norm: clip value given to every created attribute.
        param_list: parameters to tag, either ``framework.Parameter``
            objects or parameter names. When ``None``, all parameters of
            block 0 of ``program`` are used.
        group_name: group name given to every created attribute.
        program: program to look parameters up in. When ``None``,
            ``framework.default_main_program()`` is used.

    Raises:
        TypeError: if, after resolving names, any element is not a
            ``framework.Parameter``.
    """
    if program is None:
        program = framework.default_main_program()

    targets = param_list
    if targets is None:
        targets = program.block(0).all_parameters()

    # A list made only of names is resolved against block 0.
    if all(isinstance(item, basestring) for item in targets):
        targets = [program.block(0).var(item) for item in targets]

    if any(not isinstance(item, framework.Parameter) for item in targets):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    for target in targets:
        target.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
                                                             group_name)
def append_gradient_clip_ops(param_grad): def append_gradient_clip_ops(param_grad):
context = dict() context = dict()
create_op_callbacks = [] create_op_callbacks = []
for p, g in param_grad: for p, g in param_grad:
clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
if clip_attr is None: if clip_attr is None:
clip_attr = NullGradientClipAttr() clip_attr = NullGradientClipAttr()
if not isinstance(clip_attr, BaseGradientClipAttr): if not isinstance(clip_attr, BaseGradientClipAttr):
raise TypeError( raise TypeError(
"clip attribute should be an instance of BaseGradientClippingAttr" "clip attribute should be an instance of BaseGradientClipAttr")
)
clip_attr.process_context(context=context, p_g=param_grad) clip_attr.process_context(context=context, param=p, grad=g)
create_op_callbacks.append( create_op_callbacks.append(
functools.partial( functools.partial(
clip_attr.create_operators, param=p, grad=g)) clip_attr.create_operators, param=p, grad=g))
......
...@@ -780,7 +780,7 @@ class Block(object): ...@@ -780,7 +780,7 @@ class Block(object):
trainable=p.trainable, trainable=p.trainable,
optimize_attr=p.optimize_attr, optimize_attr=p.optimize_attr,
regularizer=p.regularizer, regularizer=p.regularizer,
clip_attr=p.clip_attr, gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip, error_clip=p.error_clip,
name=v.name) name=v.name)
self.vars[new_p.name] = new_p self.vars[new_p.name] = new_p
...@@ -948,7 +948,7 @@ class Parameter(Variable): ...@@ -948,7 +948,7 @@ class Parameter(Variable):
self.regularizer = kwargs.get('regularizer', None) self.regularizer = kwargs.get('regularizer', None)
self.clip_attr = kwargs.get('clip_attr', None) self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
# program is a global instance. # program is a global instance.
......
...@@ -11,22 +11,41 @@ ...@@ -11,22 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
from ..layer_helper import LayerHelper, unique_name from layer_function_generator import autodoc
from ..framework import Program, Variable, Operator
from .. import core
from tensor import assign, fill_constant from tensor import assign, fill_constant
import contextlib from .. import core
from ..registry import autodoc from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name
__all__ = [ __all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'split_lod_tensor',
'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While', 'merge_lod_tensor',
'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array', 'BlockGuard',
'array_to_lod_tensor', 'increment', 'array_write', 'create_array', 'BlockGuardWithCompletion',
'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse', 'StaticRNNMemoryLink',
'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank', 'WhileGuard',
'ParallelDo', 'Print' 'While',
'lod_rank_table',
'max_sequence_len',
'topk',
'lod_tensor_to_array',
'array_to_lod_tensor',
'increment',
'array_write',
'create_array',
'less_than',
'array_read',
'shrink_memory',
'array_length',
'IfElse',
'DynamicRNN',
'ConditionalBlock',
'StaticRNN',
'reorder_lod_tensor_by_rank',
'ParallelDo',
'Print',
] ]
...@@ -1458,7 +1477,7 @@ class DynamicRNN(object): ...@@ -1458,7 +1477,7 @@ class DynamicRNN(object):
method)) method))
@autodoc @autodoc()
def reorder_lod_tensor_by_rank(x, rank_table): def reorder_lod_tensor_by_rank(x, rank_table):
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals()) helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
helper.is_instance('x', Variable) helper.is_instance('x', Variable)
......
...@@ -15,14 +15,14 @@ ...@@ -15,14 +15,14 @@
All util layers. All util layers.
""" """
from ..layer_helper import LayerHelper from layer_function_generator import autodoc
from ..framework import unique_name from ..framework import unique_name
from ..registry import autodoc from ..layer_helper import LayerHelper
__all__ = ['get_places'] __all__ = ['get_places']
@autodoc @autodoc()
def get_places(device_count=None, device_type=None): def get_places(device_count=None, device_type=None):
helper = LayerHelper('get_places', **locals()) helper = LayerHelper('get_places', **locals())
out_places = helper.create_variable(name=unique_name(helper.name + ".out")) out_places = helper.create_variable(name=unique_name(helper.name + ".out"))
......
...@@ -13,17 +13,19 @@ ...@@ -13,17 +13,19 @@
# limitations under the License. # limitations under the License.
import re import re
import cStringIO import cStringIO
import warnings
import functools import functools
import inspect import warnings
from .. import proto
import proto.framework_pb2 as framework_pb2 framework_pb2 = proto.framework_pb2
from framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name from ..framework import OpProtoHolder, Variable
from ..layer_helper import LayerHelper
__all__ = [ __all__ = [
'deprecated', 'deprecated',
'register_layer', 'generate_layer_fn',
'autodoc', 'autodoc',
] ]
...@@ -96,7 +98,7 @@ def _generate_doc_string_(op_proto): ...@@ -96,7 +98,7 @@ def _generate_doc_string_(op_proto):
return buf.getvalue() return buf.getvalue()
def register_layer(op_type): def generate_layer_fn(op_type):
"""Register the Python layer for an Operator. """Register the Python layer for an Operator.
Args: Args:
...@@ -207,7 +209,10 @@ def deprecated(func_or_class): ...@@ -207,7 +209,10 @@ def deprecated(func_or_class):
return func_wrapper return func_wrapper
def autodoc(func): def autodoc(comment=""):
func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto( def __impl__(func):
func.__name__)) func.__doc__ = _generate_doc_string_(OpProtoHolder.instance(
).get_op_proto(func.__name__)) + comment
return func return func
return __impl__
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from layer_function_generator import generate_layer_fn
from ..registry import register_layer
__activations__ = [ __activations__ = [
'sigmoid', 'sigmoid',
...@@ -46,21 +45,11 @@ __activations__ = [ ...@@ -46,21 +45,11 @@ __activations__ = [
] ]
__all__ = [ __all__ = [
'mean', 'mean', 'mul', 'reshape', 'scale', 'transpose',
'mul', 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
'reshape', 'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min',
'scale', 'clip', 'clip_by_norm', 'sequence_softmax'
'transpose',
'sigmoid_cross_entropy_with_logits',
'elementwise_add',
'elementwise_div',
'elementwise_sub',
'elementwise_mul',
'elementwise_max',
'elementwise_min',
'clip',
'sequence_softmax',
] + __activations__ ] + __activations__
for _OP in set(__all__): for _OP in set(__all__):
globals()[_OP] = register_layer(_OP) globals()[_OP] = generate_layer_fn(_OP)
...@@ -25,13 +25,13 @@ class ParamAttr(object): ...@@ -25,13 +25,13 @@ class ParamAttr(object):
learning_rate=1.0, learning_rate=1.0,
regularizer=None, regularizer=None,
trainable=True, trainable=True,
clip=None): gradient_clip=None):
self.name = name self.name = name
self.initializer = initializer self.initializer = initializer
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.regularizer = regularizer self.regularizer = regularizer
self.trainable = trainable self.trainable = trainable
self.clip = clip self.gradient_clip = gradient_clip
def set_default_initializer(self, initializer): def set_default_initializer(self, initializer):
if initializer is None: if initializer is None:
...@@ -77,7 +77,7 @@ class ParamAttr(object): ...@@ -77,7 +77,7 @@ class ParamAttr(object):
}, },
'regularizer': self.regularizer, 'regularizer': self.regularizer,
'trainable': self.trainable, 'trainable': self.trainable,
'clip_attr': self.clip 'gradient_clip_attr': self.gradient_clip
} }
if with_initializer: if with_initializer:
kwargs['initializer'] = self.initializer kwargs['initializer'] = self.initializer
......
...@@ -6,3 +6,4 @@ endforeach() ...@@ -6,3 +6,4 @@ endforeach()
add_subdirectory(book) add_subdirectory(book)
add_subdirectory(book_distribute) add_subdirectory(book_distribute)
add_subdirectory(book_memory_optimization)
...@@ -27,7 +27,7 @@ hidden1 = fluid.layers.fc(input=image, ...@@ -27,7 +27,7 @@ hidden1 = fluid.layers.fc(input=image,
act='relu', act='relu',
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
regularizer=regularizer, regularizer=regularizer,
clip=fluid.clip.ClipByValue(10))) gradient_clip=fluid.clip.ClipByValue(10)))
hidden2 = fluid.layers.fc(input=hidden1, hidden2 = fluid.layers.fc(input=hidden1,
size=64, size=64,
......
# Collect every test_*.py in this directory (paths relative to it) and strip
# the ".py" suffix so TEST_OPS holds bare test names.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
# The image classification script is registered explicitly below, once per
# network argument, so it is removed from the generic list.
list(REMOVE_ITEM TEST_OPS test_memopt_image_classification_train)
py_test(test_memopt_image_classification_train_resnet SRCS test_memopt_image_classification_train.py ARGS resnet)
py_test(test_memopt_image_classification_train_vgg SRCS test_memopt_image_classification_train.py ARGS vgg)
# default test: every remaining script is registered under its own name
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
# Model: one fully connected layer mapping 13 input features to a single
# value, trained with the mean of the squared error cost.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')

y_predict = fluid.layers.fc(input=x, size=1, act=None)

y = fluid.layers.data(name='y', shape=[1], dtype='float32')

cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)

sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost)

# The program actually executed below is the result of memory_optimize.
# memopt_program = fluid.default_main_program()
memopt_program = fluid.memory_optimize(fluid.default_main_program())

BATCH_SIZE = 200

# Shuffled, batched reader over the uci_housing training set.
shuffled_samples = paddle.reader.shuffle(
    paddle.dataset.uci_housing.train(), buf_size=500)
train_reader = paddle.batch(shuffled_samples, batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)

exe.run(fluid.default_startup_program())

PASS_NUM = 100
for pass_id in range(PASS_NUM):
    # Persistables are saved and reloaded at the start of every pass.
    fluid.io.save_persistables(exe, "./fit_a_line.model/")
    fluid.io.load_persistables(exe, "./fit_a_line.model/")
    for data in train_reader():
        (avg_loss_value, ) = exe.run(
            memopt_program, feed=feeder.feed(data), fetch_list=[avg_cost])

        # if avg cost less than 10.0, we think our code is good.
        if avg_loss_value[0] < 10.0:
            exit(0)

# The loss never dropped below the threshold within PASS_NUM passes.
exit(1)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
def resnet_cifar10(input, depth=32):
    """Build a residual network on ``input`` and return its pooled output.

    Args:
        input: variable fed to the first convolution.
        depth: network depth; ``(depth - 2)`` must be divisible by 6. Each
            of the three stages stacks ``(depth - 2) // 6`` basic blocks.

    Returns:
        The output of the final average pooling layer (pool size 8,
        stride 1) applied to the last stage.

    Raises:
        AssertionError: if ``(depth - 2) % 6 != 0``.
    """

    def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
        # Convolution without activation or bias, followed by batch norm
        # that applies ``act``.
        tmp = fluid.layers.conv2d(
            input=input,
            filter_size=filter_size,
            num_filters=ch_out,
            stride=stride,
            padding=padding,
            act=None,
            bias_attr=False)
        return fluid.layers.batch_norm(input=tmp, act=act)

    def shortcut(input, ch_in, ch_out, stride):
        # A 1x1 conv-bn projection when the channel count changes,
        # otherwise the input itself.
        if ch_in != ch_out:
            return conv_bn_layer(input, ch_out, 1, stride, 0, None)
        else:
            return input

    def basicblock(input, ch_in, ch_out, stride):
        # Two 3x3 conv-bn layers added to the shortcut, then relu.
        tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
        tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
        short = shortcut(input, ch_in, ch_out, stride)
        return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')

    def layer_warp(block_func, input, ch_in, ch_out, count, stride):
        # First block uses ``stride`` and changes channels; the remaining
        # ``count - 1`` blocks keep ``ch_out`` channels with stride 1.
        tmp = block_func(input, ch_in, ch_out, stride)
        for i in range(1, count):
            tmp = block_func(tmp, ch_out, ch_out, 1)
        return tmp

    assert (depth - 2) % 6 == 0
    # Floor division keeps ``n`` an int on Python 3 as well; true division
    # would yield a float and make ``range(1, count)`` raise TypeError.
    n = (depth - 2) // 6
    conv1 = conv_bn_layer(
        input=input, ch_out=16, filter_size=3, stride=1, padding=1)
    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
    pool = fluid.layers.pool2d(
        input=res3, pool_size=8, pool_type='avg', pool_stride=1)
    return pool
def vgg16_bn_drop(input):
    """Build a VGG-style network on ``input`` and return the last fc output.

    Five convolution groups (each built by ``fluid.nets.img_conv_group``
    with batch norm and per-layer dropout rates) are followed by
    dropout -> fc(512) -> batch_norm(relu) -> dropout -> fc(512).
    """

    def conv_block(input, num_filter, groups, dropouts):
        # ``groups`` 3x3 relu convolutions with ``num_filter`` filters each,
        # batch norm enabled, then 2x2 max pooling with stride 2.
        return fluid.nets.img_conv_group(
            input=input,
            pool_size=2,
            pool_stride=2,
            conv_num_filter=[num_filter] * groups,
            conv_filter_size=3,
            conv_act='relu',
            conv_with_batchnorm=True,
            conv_batchnorm_drop_rate=dropouts,
            pool_type='max')

    # (num_filter, groups, dropouts) for the five groups, applied in order.
    block_specs = (
        (64, 2, [0.3, 0]),
        (128, 2, [0.4, 0]),
        (256, 3, [0.4, 0.4, 0]),
        (512, 3, [0.4, 0.4, 0]),
        (512, 3, [0.4, 0.4, 0]), )

    features = input
    for num_filter, groups, dropouts in block_specs:
        features = conv_block(features, num_filter, groups, dropouts)

    dropped = fluid.layers.dropout(x=features, dropout_prob=0.5)
    hidden = fluid.layers.fc(input=dropped, size=512, act=None)
    normalized = fluid.layers.batch_norm(input=hidden, act='relu')
    dropped_again = fluid.layers.dropout(x=normalized, dropout_prob=0.5)
    return fluid.layers.fc(input=dropped_again, size=512, act=None)
classdim = 10
data_shape = [3, 32, 32]

images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# The network is selected by the first command line argument; "vgg" is used
# when no argument is given.
net_type = sys.argv[1] if len(sys.argv) >= 2 else "vgg"

if net_type == "resnet":
    print("train resnet")
    net = resnet_cifar10(images, 32)
elif net_type == "vgg":
    print("train vgg net")
    net = vgg16_bn_drop(images)
else:
    raise ValueError("%s network is not supported" % net_type)

# Softmax classifier on top of the selected network, averaged cross entropy.
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
opts = optimizer.minimize(avg_cost)

accuracy = fluid.evaluator.Accuracy(input=predict, label=label)

# The program actually executed below is the result of memory_optimize.
# memopt_program = fluid.default_main_program()
memopt_program = fluid.memory_optimize(fluid.default_main_program())

BATCH_SIZE = 128
PASS_NUM = 1

shuffled_samples = paddle.reader.shuffle(
    paddle.dataset.cifar.train10(), buf_size=128 * 10)
train_reader = paddle.batch(shuffled_samples, batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program())

for pass_id in range(PASS_NUM):
    accuracy.reset(exe)
    for data in train_reader():
        fetch_targets = [avg_cost] + accuracy.metrics
        loss, acc = exe.run(
            memopt_program, feed=feeder.feed(data), fetch_list=fetch_targets)
        pass_acc = accuracy.eval(exe)
        report = "loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
            pass_acc)
        print(report)
        # this model is slow, so if we can train two mini batch, we think it works properly.
        exit(0)

# Only reached when the reader produced no batch at all.
exit(1)
...@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback): ...@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}: for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
BATCH_SIZE = 128
CLIP = 1


def _norm_without_first(grads):
    """Return sqrt of the summed squares of every array in ``grads[1:]``."""
    # NOTE(review): the first fetched gradient is skipped here, exactly as
    # in the comparison this helper replaces -- confirm that is intended.
    squared_total = 0
    for grad_value in grads[1:]:
        squared_total += np.sum(np.power(grad_value, 2))
    return np.sqrt(squared_total)


# Three-layer fully connected classifier built inside its own program.
prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
    image = fluid.layers.data(name='x', shape=[784], dtype='float32')

    hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
    hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
    predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')

    label = fluid.layers.data(name='y', shape=[1], dtype='int64')

    cost = fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = fluid.layers.mean(x=cost)

# A clone of the program receives the clipping operators; the original
# stays unclipped and serves as the reference.
prog_clip = prog.clone()

avg_cost_clip = prog_clip.block(0).var(avg_cost.name)

p_g = fluid.backward.append_backward(loss=avg_cost)
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)

with fluid.program_guard(main_program=prog_clip):
    fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
    p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)

# Second element of every (param, grad) pair.
grad_list = [pair[1] for pair in p_g]
grad_clip_list = [pair[1] for pair in p_g_clip]

shuffled_samples = paddle.reader.shuffle(
    paddle.dataset.mnist.train(), buf_size=8192)
train_reader = paddle.batch(shuffled_samples, batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(fluid.default_startup_program())

# Only the first five mini-batches are checked.
count = 0
for data in train_reader():
    count += 1
    if count > 5:
        break
    out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
    out_clip = exe.run(prog_clip,
                       feed=feeder.feed(data),
                       fetch_list=grad_clip_list)

    global_norm = _norm_without_first(out)
    global_norm_clip = _norm_without_first(out_clip)

    # The clipped norm is expected to equal min(unclipped norm, CLIP).
    expected_norm = np.minimum(global_norm, CLIP)
    if not np.isclose(a=global_norm_clip, b=expected_norm, rtol=5e-3):
        exit(1)
exit(0)
...@@ -11,26 +11,21 @@ ...@@ -11,26 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
import warnings
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import paddle.v2.fluid.framework as framework import numpy as np
import paddle.v2.fluid.layers as layers import decorators
import paddle.v2.fluid.registry as registry
class TestRegistry(unittest.TestCase): class TestRegistry(unittest.TestCase):
@decorators.prog_scope()
def test_registry_layer(self): def test_registry_layer(self):
self.layer_type = "mean"
program = framework.Program()
x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32') x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32')
output = layers.mean(x) output = fluid.layers.mean(x=x)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
X = np.random.random((10, 10)).astype("float32") X = np.random.random((10, 10)).astype("float32")
mean_out = exe.run(program, feed={"X": X}, fetch_list=[output]) mean_out = exe.run(feed={"X": X}, fetch_list=[output])
self.assertAlmostEqual(np.mean(X), mean_out) self.assertAlmostEqual(np.mean(X), mean_out[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册