diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index c0418c9266e257bd7567861543e557f354451b17..1382bfca19a674a404916a5c709276ce41219d2f 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/framework/op_registry.h" #include "paddle/platform/place.h" +DECLARE_bool(do_memory_benchmark); DEFINE_bool(check_nan_inf, false, "Checking whether operator produce NAN/INF or not. It will be " "extremely slow so please use this flag wisely."); @@ -117,6 +118,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(3) << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); + if (FLAGS_do_memory_benchmark) { + VLOG(2) << "Memory used after operator " + op->Type() + " running: " + << memory::memory_usage(place_); + } if (FLAGS_check_nan_inf) { for (auto& vname : op->OutputVars(true)) { auto* var = local_scope->FindVar(vname); @@ -130,6 +135,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, if (create_vars && create_local_scope) { scope->DeleteScope(local_scope); } + if (FLAGS_do_memory_benchmark) { + VLOG(2) << "-------------------------------------------------------"; + VLOG(2) << "Memory used after deleting local scope: " + << memory::memory_usage(place_); + VLOG(2) << "-------------------------------------------------------"; + } } } // namespace framework diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 2bd0ac8f5a9eb6439a4196dd9c61e13797c1a8e3..a67ff910093d93060d07d849f6e968e5f4ce21cd 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -20,6 +20,10 @@ limitations under the License. */ #include "paddle/framework/threadpool.h" #include "paddle/string/printf.h" +DEFINE_bool(do_memory_benchmark, false, + "Doing memory benchmark. It will make deleting scope synchronized, " + "and add some memory usage logs"); + namespace paddle { namespace framework { @@ -88,8 +92,12 @@ void Scope::DeleteScope(Scope* scope) { auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); this->kids_.erase(it); - // Make delete async. - Async([scope] { delete scope; }); + // When making memory benchmark on Fluid, we have to delete scope sync. + if (FLAGS_do_memory_benchmark) { + delete scope; + } else { + Async([scope] { delete scope; }); + } } void Scope::Rename(const std::string& origin_name, diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index 331bc7672ec0d39a7317c39f1d14e8dcadea471a..337b9ba7bc0fc4e4bb80ee7b248d934f111379d5 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -65,14 +65,19 @@ bool PriorBoxLayer::init(const LayerMap& layerMap, std::copy(pbConf.aspect_ratio().begin(), pbConf.aspect_ratio().end(), std::back_inserter(tmp)); - // flip - int inputRatioLength = tmp.size(); - for (int index = 0; index < inputRatioLength; index++) { - aspectRatio_.push_back(tmp[index]); - aspectRatio_.push_back(1 / tmp[index]); + + if (maxSize_.size() > 0) CHECK_EQ(minSize_.size(), maxSize_.size()); + + // flip aspect ratios + for (int index = 0; index < tmp.size(); index++) { + real ar = tmp[index]; + if (fabs(ar - 1.) < 1e-6) continue; + aspectRatio_.push_back(ar); + aspectRatio_.push_back(1. / ar); } - numPriors_ = aspectRatio_.size(); - if (maxSize_.size() > 0) numPriors_++; + + numPriors_ = aspectRatio_.size() * minSize_.size() + maxSize_.size(); + return true; } @@ -99,50 +104,39 @@ void PriorBoxLayer::forward(PassType passType) { for (int w = 0; w < layerWidth; ++w) { real centerX = (w + 0.5) * stepW; real centerY = (h + 0.5) * stepH; - real minSize = 0; for (size_t s = 0; s < minSize_.size(); s++) { - // first prior. - minSize = minSize_[s]; + real minSize = minSize_[s]; real boxWidth = minSize; real boxHeight = minSize; - // xmin, ymin, xmax, ymax. - tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; - tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; - tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; - tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; - // set the variance. - for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; + + // priors with different aspect ratios + for (size_t r = 0; r < aspectRatio_.size(); r++) { + real ar = aspectRatio_[r]; + boxWidth = minSize * sqrt(ar); + boxHeight = minSize / sqrt(ar); + tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; + tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; + // set the variance. + for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; + } if (maxSize_.size() > 0) { - CHECK_EQ(minSize_.size(), maxSize_.size()); - // second prior. - for (size_t s = 0; s < maxSize_.size(); s++) { - real maxSize = maxSize_[s]; - boxWidth = boxHeight = sqrt(minSize * maxSize); - tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; - tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; - tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; - tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; - // set the variance. - for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; - } + // square prior with size sqrt(minSize * maxSize) + real maxSize = maxSize_[s]; + boxWidth = boxHeight = sqrt(minSize * maxSize); + tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; + tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; + tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; + // set the variance. + for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; } } - // rest of priors. - for (size_t r = 0; r < aspectRatio_.size(); r++) { - real ar = aspectRatio_[r]; - if (fabs(ar - 1.) < 1e-6) continue; - real boxWidth = minSize * sqrt(ar); - real boxHeight = minSize / sqrt(ar); - tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; - tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; - tmpPtr[idx++] = (centerX + boxWidth / 2.) / imageWidth; - tmpPtr[idx++] = (centerY + boxHeight / 2.) / imageHeight; - // set the variance. - for (int t = 0; t < 4; t++) tmpPtr[idx++] = variance_[t]; - } } } + // clip the prior's coordidate such that it is within [0, 1] for (int d = 0; d < dim * 2; ++d) if ((d % 8) < 4) diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index daa2c193b48fe216ff284169a3dce1b4cd40a791..930c295a9cb31238954efeb87ff5ac2d3ca7bdc6 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is calculated by %s )DOC", comment.type, comment.equation)); + AddAttr("axis", + "(int, default -1). The start dimension index " + "for broadcasting Y onto X.") + .SetDefault(-1) + .EqualGreaterThan(-1); } }; @@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); -REGISTER_LOGICAL_OP(greater_than, "Out = X > Y"); -REGISTER_LOGICAL_KERNEL(greater_than, CPU, - paddle::operators::GreaterThanFunctor); -REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y"); -REGISTER_LOGICAL_KERNEL(greater_equal, CPU, - paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.cu b/paddle/operators/compare_op.cu index 26049271befd1fe57001659d1a406e73de0004a7..f625824dbc99d603f1e92700b4ad3d7fa25b471d 100644 --- a/paddle/operators/compare_op.cu +++ b/paddle/operators/compare_op.cu @@ -16,8 +16,4 @@ limitations under the License. */ REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); -REGISTER_LOGICAL_KERNEL(greater_than, CUDA, - paddle::operators::GreaterThanFunctor); -REGISTER_LOGICAL_KERNEL(greater_equal, CUDA, - paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h index 567e89c0a727ad0cdd2add8ec8b2a42c86a58007..9c655d6c0d8e5fe04ee6d85f7e9d9da68105230c 100644 --- a/paddle/operators/compare_op.h +++ b/paddle/operators/compare_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include "paddle/framework/op_registry.h" +#include "paddle/operators/elementwise_op_function.h" #include "paddle/platform/transform.h" namespace paddle { @@ -33,18 +34,6 @@ struct LessEqualFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } }; -template -struct GreaterThanFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } -}; - -template -struct GreaterEqualFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } -}; - template struct EqualFunctor { using ELEM_TYPE = T; @@ -65,14 +54,7 @@ class CompareOpKernel public: void Compute(const framework::ExecutionContext& context) const override { using T = typename Functor::ELEM_TYPE; - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* out = context.Output("Out"); - Functor binary_func; - platform::Transform trans; - trans(context.template device_context(), x->data(), - x->data() + x->numel(), y->data(), - out->mutable_data(context.GetPlace()), binary_func); + ElementwiseComputeEx(context); } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index db5d30c1af286913f8decd7ab74058fd732ead65..d749b8e8757d0d433be05876779ccc22b95ca80b 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -176,14 +176,15 @@ class MidWiseTransformIterator }; #endif -template +template class TransformFunctor { public: TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, framework::Tensor* z, const DeviceContext& ctx, Functor func) : x_(x->data()), y_(y->data()), - z_(z->mutable_data(ctx.GetPlace())), + z_(z->mutable_data(ctx.GetPlace())), nx_(x->numel()), ctx_(ctx), func_(func) {} @@ -208,7 +209,7 @@ class TransformFunctor { private: const T* x_; const T* y_; - T* z_; + OutType* z_; int64_t nx_; const DeviceContext& ctx_; Functor func_; @@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { } } -template +template void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { using Tensor = framework::Tensor; auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor functor( + z->mutable_data(ctx.GetPlace()); + TransformFunctor functor( x, y, z, ctx.template device_context(), Functor()); auto x_dims = x->dims(); diff --git a/paddle/operators/math/sampler.cc b/paddle/operators/math/sampler.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f1cbfe31ac68499a51eda600b38b879f7ca055f --- /dev/null +++ b/paddle/operators/math/sampler.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "sampler.h" + +namespace paddle { +namespace random { + +Sampler::~Sampler() {} + +UniformSampler::UniformSampler(int64 range) + : Sampler(range), inv_range_(1.0 / range) { + random_engine_ = std::make_shared(seed_); + dist_ = std::make_shared>(0, range); +} + +UniformSampler::UniformSampler(int64 range, unsigned int seed) + : Sampler(range, seed), inv_range_(1.0 / range) { + random_engine_ = std::make_shared(seed_); + dist_ = std::make_shared>(0, range); +} + +int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); } + +float UniformSampler::Probability(int64 value) const { return inv_range_; } + +LogUniformSampler::LogUniformSampler(int64 range) + : Sampler(range), log_range_(log(range + 1)) { + random_engine_ = std::make_shared(seed_); + dist_ = std::make_shared>(0, 1); +} + +LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed) + : Sampler(range, seed), log_range_(log(range + 1)) { + random_engine_ = std::make_shared(seed_); + dist_ = std::make_shared>(0, 1); +} +int64 LogUniformSampler::Sample() const { + // Got Log Uniform distribution from uniform distribution by + // inverse_transform_sampling method + // More details: + // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ + const int64 value = + static_cast(exp((*dist_)(*random_engine_) * log_range_)) - 1; + // Mathematically, value should be <= range_, but might not be due to some + // floating point roundoff, so we mod by range_. + return value % range_; +} + +float LogUniformSampler::Probability(int64 value) const { + // Given f(x) = 1/[(x+1) * log_range_] + // The value's probability is integral of f(x) from value to (value + 1) + // More details: + // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler + return (log((value + 2.0) / (value + 1.0))) / log_range_; +} + +} // namespace random +} // namespace paddle diff --git a/paddle/operators/math/sampler.h b/paddle/operators/math/sampler.h new file mode 100644 index 0000000000000000000000000000000000000000..8f82089e7bd9e0ae6282459b650c225d6765faee --- /dev/null +++ b/paddle/operators/math/sampler.h @@ -0,0 +1,100 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +typedef long int64; +namespace paddle { +namespace operators { +namespace math { + +// TODO(wanghaoshuang): Support for GPU + +/** +* Sample integers from [0, range). +*/ +class Sampler { + public: + explicit Sampler(int64 range) : range_(range) { + PADDLE_ENFORCE_GT(range, 0); + std::random_device r; + seed_ = r(); + } + explicit Sampler(int64 range, unsigned int seed) + : range_(range), seed_(seed) { + PADDLE_ENFORCE_GT(range, 0); + } + virtual ~Sampler(); + // Sample a single value + virtual int64 Sample() const = 0; + // The probability that a single call to Sample() returns the given value. + virtual float Probability(int64 value) const = 0; + + int64 range() { return range_; }; + + protected: + const int64 range_; + unsigned int seed_; +}; + +/** + * Sample integers from [0, range). + * And the distribution function is: + * P(x) = 1 / range + */ +class UniformSampler : public Sampler { + public: + explicit UniformSampler(int64 range); + + explicit UniformSampler(int64 range, unsigned int seed); + + ~UniformSampler() override {} + + int64 Sample() const override; + + float Probability(int64 value) const override; + + private: + const float inv_range_; + std::shared_ptr random_engine_; + std::shared_ptr> dist_; +}; + +/** + * Sample integers from [0, range). + * And the distribution function is: + * P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1)) + */ +class LogUniformSampler : public Sampler { + public: + explicit LogUniformSampler(int64 range); + + explicit LogUniformSampler(int64 range, unsigned int seed); + + ~LogUniformSampler() override {} + + int64 Sample() const override; + + float Probability(int64 value) const override; + + private: + const float log_range_; + std::shared_ptr random_engine_; + std::shared_ptr> dist_; +}; + +} // math +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 90830515c1e8e6f5260cfca631e02a3a52cedbe5..c1acbecd9c313b02d6d33d2d04fd33fc1a8b026e 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -24,11 +24,23 @@ import conll05 import uci_housing import sentiment import wmt14 +import wmt16 import mq2007 import flowers import voc2012 __all__ = [ - 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012' + 'mnist', + 'imikolov', + 'imdb', + 'cifar', + 'movielens', + 'conll05', + 'sentiment' + 'uci_housing', + 'wmt14', + 'wmt16', + 'mq2007', + 'flowers', + 'voc2012', ] diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index fab8a68b0beee8b813bee2a05047e2da526a9c9b..9aba35a6481e3ad3ab37c8d4de0f998c9f0a1f07 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -25,8 +25,12 @@ import glob import cPickle as pickle __all__ = [ - 'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader', - 'convert' + 'DATA_HOME', + 'download', + 'md5file', + 'split', + 'cluster_files_reader', + 'convert', ] DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') @@ -58,12 +62,15 @@ def md5file(fname): return hash_md5.hexdigest() -def download(url, module_name, md5sum): +def download(url, module_name, md5sum, save_name=None): dirname = os.path.join(DATA_HOME, module_name) if not os.path.exists(dirname): os.makedirs(dirname) - filename = os.path.join(dirname, url.split('/')[-1]) + filename = os.path.join(dirname, + url.split('/')[-1] + if save_name is None else save_name) + retry = 0 retry_limit = 3 while not (os.path.exists(filename) and md5file(filename) == md5sum): @@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix): Convert data from reader to recordio format files. :param output_path: directory in which output files will be saved. - :param reader: a data reader, from which the convert program will read data instances. + :param reader: a data reader, from which the convert program will read + data instances. :param name_prefix: the name prefix of generated files. - :param max_lines_to_shuffle: the max lines numbers to shuffle before writing. + :param max_lines_to_shuffle: the max lines numbers to shuffle before + writing. """ assert line_count >= 1 diff --git a/python/paddle/v2/dataset/tests/wmt16_test.py b/python/paddle/v2/dataset/tests/wmt16_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cef6c3216e7de8d9785a063976e63f88d90b24df --- /dev/null +++ b/python/paddle/v2/dataset/tests/wmt16_test.py @@ -0,0 +1,66 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2.dataset.wmt16 +import unittest + + +class TestWMT16(unittest.TestCase): + def checkout_one_sample(self, sample): + # train data has 3 field: source language word indices, + # target language word indices, and target next word indices. + self.assertEqual(len(sample), 3) + + # test start mark and end mark in source word indices. + self.assertEqual(sample[0][0], 0) + self.assertEqual(sample[0][-1], 1) + + # test start mask in target word indices + self.assertEqual(sample[1][0], 0) + + # test en mask in target next word indices + self.assertEqual(sample[2][-1], 1) + + def test_train(self): + for idx, sample in enumerate( + paddle.v2.dataset.wmt16.train( + src_dict_size=100000, trg_dict_size=100000)()): + if idx >= 10: break + self.checkout_one_sample(sample) + + def test_test(self): + for idx, sample in enumerate( + paddle.v2.dataset.wmt16.test( + src_dict_size=1000, trg_dict_size=1000)()): + if idx >= 10: break + self.checkout_one_sample(sample) + + def test_val(self): + for idx, sample in enumerate( + paddle.v2.dataset.wmt16.validation( + src_dict_size=1000, trg_dict_size=1000)()): + if idx >= 10: break + self.checkout_one_sample(sample) + + def test_get_dict(self): + dict_size = 1000 + word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True) + self.assertEqual(len(word_dict), dict_size) + self.assertEqual(word_dict[0], "") + self.assertEqual(word_dict[1], "") + self.assertEqual(word_dict[2], "") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index 95a35d97ce9d9503153974cc167ee60829244d5f..5104e29051e4480f3a7eb18421f1b519841b009b 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -25,12 +25,20 @@ import gzip import paddle.v2.dataset.common from paddle.v2.parameters import Parameters -__all__ = ['train', 'test', 'build_dict', 'convert'] - -URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' +__all__ = [ + 'train', + 'test', + 'get_dict', + 'convert', +] + +URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/' + 'cslm_joint_paper/data/dev+test.tgz') MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' -# this is a small set of data for test. The original data is too large and will be add later. -URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' +# this is a small set of data for test. The original data is too large and +# will be add later. +URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/' + 'wmt_shrinked_data/wmt14.tgz') MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c' # BLEU of this trained model is 26.92 URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' @@ -42,8 +50,8 @@ UNK = "" UNK_IDX = 2 -def __read_to_dict__(tar_file, dict_size): - def __to_dict__(fd, size): +def __read_to_dict(tar_file, dict_size): + def __to_dict(fd, size): out_dict = dict() for line_count, line in enumerate(fd): if line_count < size: @@ -58,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size): if each_item.name.endswith("src.dict") ] assert len(names) == 1 - src_dict = __to_dict__(f.extractfile(names[0]), dict_size) + src_dict = __to_dict(f.extractfile(names[0]), dict_size) names = [ each_item.name for each_item in f if each_item.name.endswith("trg.dict") ] assert len(names) == 1 - trg_dict = __to_dict__(f.extractfile(names[0]), dict_size) + trg_dict = __to_dict(f.extractfile(names[0]), dict_size) return src_dict, trg_dict def reader_creator(tar_file, file_name, dict_size): def reader(): - src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) + src_dict, trg_dict = __read_to_dict(tar_file, dict_size) with tarfile.open(tar_file, mode='r') as f: names = [ each_item.name for each_item in f @@ -152,7 +160,7 @@ def get_dict(dict_size, reverse=True): # if reverse = False, return dict = {'a':'001', 'b':'002', ...} # else reverse = true, return dict = {'001':'a', '002':'b', ...} tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) - src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) + src_dict, trg_dict = __read_to_dict(tar_file, dict_size) if reverse: src_dict = {v: k for k, v in src_dict.items()} trg_dict = {v: k for k, v in trg_dict.items()} diff --git a/python/paddle/v2/dataset/wmt16.py b/python/paddle/v2/dataset/wmt16.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc28a2da99052308471931122946d0d96b54da5 --- /dev/null +++ b/python/paddle/v2/dataset/wmt16.py @@ -0,0 +1,348 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +ACL2016 Multimodal Machine Translation. Please see this website for more +details: http://www.statmt.org/wmt16/multimodal-task.html#task1 + +If you use the dataset created for your task, please cite the following paper: +Multi30K: Multilingual English-German Image Descriptions. + +@article{elliott-EtAl:2016:VL16, + author = {{Elliott}, D. and {Frank}, S. and {Sima"an}, K. and {Specia}, L.}, + title = {Multi30K: Multilingual English-German Image Descriptions}, + booktitle = {Proceedings of the 6th Workshop on Vision and Language}, + year = {2016}, + pages = {70--74}, + year = 2016 +} +""" + +import os +import tarfile +import gzip +from collections import defaultdict + +import paddle.v2.dataset.common + +__all__ = [ + "train", + "test", + "validation", + "convert", + "fetch", + "get_dict", +] + +DATA_URL = ("http://cloud.dlnel.org/filepub/" + "?uuid=46a0808e-ddd8-427c-bacd-0dbc6d045fed") +DATA_MD5 = "0c38be43600334966403524a40dcd81e" + +TOTAL_EN_WORDS = 11250 +TOTAL_DE_WORDS = 19220 + +START_MARK = "" +END_MARK = "" +UNK_MARK = "" + + +def __build_dict(tar_file, dict_size, save_path, lang): + word_dict = defaultdict(int) + with tarfile.open(tar_file, mode="r") as f: + for line in f.extractfile("wmt16/train"): + line_split = line.strip().split("\t") + if len(line_split) != 2: continue + sen = line_split[0] if lang == "en" else line_split[1] + for w in sen.split(): + word_dict[w] += 1 + + with open(save_path, "w") as fout: + fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) + for idx, word in enumerate( + sorted( + word_dict.iteritems(), key=lambda x: x[1], reverse=True)): + if idx + 3 == dict_size: break + fout.write("%s\n" % (word[0])) + + +def __load_dict(tar_file, dict_size, lang, reverse=False): + dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME, + "wmt16/%s_%d.dict" % (lang, dict_size)) + if not os.path.exists(dict_path) or ( + len(open(dict_path, "r").readlines()) != dict_size): + __build_dict(tar_file, dict_size, dict_path, lang) + + word_dict = {} + with open(dict_path, "r") as fdict: + for idx, line in enumerate(fdict): + if reverse: + word_dict[idx] = line.strip() + else: + word_dict[line.strip()] = idx + return word_dict + + +def __get_dict_size(src_dict_size, trg_dict_size, src_lang): + src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else + TOTAL_DE_WORDS)) + trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else + TOTAL_ENG_WORDS)) + return src_dict_size, trg_dict_size + + +def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang): + def reader(): + src_dict = __load_dict(tar_file, src_dict_size, src_lang) + trg_dict = __load_dict(tar_file, trg_dict_size, + ("de" if src_lang == "en" else "en")) + + # the indice for start mark, end mark, and unk are the same in source + # language and target language. Here uses the source language + # dictionary to determine their indices. + start_id = src_dict[START_MARK] + end_id = src_dict[END_MARK] + unk_id = src_dict[UNK_MARK] + + src_col = 0 if src_lang == "en" else 1 + trg_col = 1 - src_col + + with tarfile.open(tar_file, mode="r") as f: + for line in f.extractfile(file_name): + line_split = line.strip().split("\t") + if len(line_split) != 2: + continue + src_words = line_split[src_col].split() + src_ids = [start_id] + [ + src_dict.get(w, unk_id) for w in src_words + ] + [end_id] + + trg_words = line_split[trg_col].split() + trg_ids = [trg_dict.get(w, unk_id) for w in trg_words] + + trg_ids_next = trg_ids + [end_id] + trg_ids = [start_id] + trg_ids + + yield src_ids, trg_ids, trg_ids_next + + return reader + + +def train(src_dict_size, trg_dict_size, src_lang="en"): + """ + WMT16 train set reader. + + This function returns the reader for train data. Each sample the reader + returns is made up of three fields: the source language word index sequence, + target language word index sequence and next word index sequence. + + + NOTE: + The original like for training data is: + http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz + + paddle.dataset.wmt16 provides a tokenized version of the original dataset by + using moses's tokenization script: + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl + + Args: + src_dict_size(int): Size of the source language dictionary. Three + special tokens will be added into the dictionary: + for start mark, for end mark, and for + unknown word. + trg_dict_size(int): Size of the target language dictionary. Three + special tokens will be added into the dictionary: + for start mark, for end mark, and for + unknown word. + src_lang(string): A string indicating which language is the source + language. Available options are: "en" for English + and "de" for Germany. + + Returns: + callable: The train reader. + """ + + assert (src_lang in ["en", "de"], ("An error language type. Only support: " + "en (for English); de(for Germany)")) + src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, + src_lang) + + return reader_creator( + tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, + "wmt16.tar.gz"), + file_name="wmt16/train", + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang) + + +def test(src_dict_size, trg_dict_size, src_lang="en"): + """ + WMT16 test set reader. + + This function returns the reader for test data. Each sample the reader + returns is made up of three fields: the source language word index sequence, + target language word index sequence and next word index sequence. + + NOTE: + The original like for test data is: + http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz + + paddle.dataset.wmt16 provides a tokenized version of the original dataset by + using moses's tokenization script: + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl + + Args: + src_dict_size(int): Size of the source language dictionary. Three + special tokens will be added into the dictionary: + for start mark, for end mark, and for + unknown word. + trg_dict_size(int): Size of the target language dictionary. Three + special tokens will be added into the dictionary: + for start mark, for end mark, and for + unknown word. + src_lang(string): A string indicating which language is the source + language. Available options are: "en" for English + and "de" for Germany. + + Returns: + callable: The test reader. + """ + + assert (src_lang in ["en", "de"], + ("An error language type. " + "Only support: en (for English); de(for Germany)")) + + src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, + src_lang) + + return reader_creator( + tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, + "wmt16.tar.gz"), + file_name="wmt16/test", + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang) + + +def validation(src_dict_size, trg_dict_size, src_lang="en"): + """ + WMT16 validation set reader. + + This function returns the reader for validation data. Each sample the reader + returns is made up of three fields: the source language word index sequence, + target language word index sequence and next word index sequence. + + NOTE: + The original like for validation data is: + http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz + + paddle.dataset.wmt16 provides a tokenized version of the original dataset by + using moses's tokenization script: + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl + + Args: + src_dict_size(int): Size of the source language dictionary. Three + special tokens will be added into the dictionary: + for start mark, for end mark, and for + unknown word. + trg_dict_size(int): Size of the target language dictionary. Three + special tokens will be added into the dictionary: + for start mark, for end mark, and for + unknown word. + src_lang(string): A string indicating which language is the source + language. Available options are: "en" for English + and "de" for Germany. + + Returns: + callable: The validation reader. + """ + assert (src_lang in ["en", "de"], + ("An error language type. " + "Only support: en (for English); de(for Germany)")) + src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, + src_lang) + + return reader_creator( + tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, + "wmt16.tar.gz"), + file_name="wmt16/val", + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang) + + +def get_dict(lang, dict_size, reverse=False): + """ + return the word dictionary for the specified language. + + Args: + lang(string): A string indicating which language is the source + language. Available options are: "en" for English + and "de" for Germany. + dict_size(int): Size of the specified language dictionary. + reverse(bool): If reverse is set to False, the returned python + dictionary will use word as key and use index as value. + If reverse is set to True, the returned python + dictionary will use index as key and word as value. + + Returns: + dict: The word dictionary for the specific language. + """ + + if lang == "en": dict_size = min(dict_size, TOTAL_EN_WORDS) + else: dict_size = min(dict_size, TOTAL_DE_WORDS) + + dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME, + "wmt16/%s_%d.dict" % (lang, dict_size)) + assert (os.path.exists(dict_path), "Word dictionary does not exist. " + "Please invoke paddle.dataset.wmt16.train/test/validation " + "first to build the dictionary.") + tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz") + return __load_dict(tar_file, dict_size, lang, reverse) + + +def fetch(): + """download the entire dataset. + """ + paddle.v4.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, + "wmt16.tar.gz") + + +def convert(path, src_dict_size, trg_dict_size, src_lang): + """Converts dataset to recordio format. + """ + + paddle.v2.dataset.common.convert( + path, + train( + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang), + 1000, + "wmt16_train") + paddle.v2.dataset.common.convert( + path, + test( + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang), + 1000, + "wmt16_test") + paddle.v2.dataset.common.convert( + path, + validation( + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang), + 1000, + "wmt16_validation") diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 4ae4165ba4078374c44113371ec4cd84c95c40ee..1f041c74597637a7b74e9690a60b6cd8fdd21cf8 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -37,6 +37,7 @@ import clip from memory_optimization_transpiler import memory_optimize Tensor = LoDTensor + __all__ = framework.__all__ + executor.__all__ + [ 'io', 'initializer', @@ -85,7 +86,9 @@ def __bootstrap__(): os.environ['OMP_NUM_THREADS'] = str(num_threads) - read_env_flags = ['use_pinned_memory', 'check_nan_inf'] + read_env_flags = [ + 'use_pinned_memory', 'check_nan_inf', 'do_memory_benchmark' + ] if core.is_compile_gpu(): read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync'] core.init_gflags([sys.argv[0]] + @@ -94,4 +97,5 @@ def __bootstrap__(): core.init_devices() +layers.monkey_patch_variable() __bootstrap__() diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 5241f4843c0df0314eba6168f8f36335c3f19d0a..386df9823de9119287abf87569eab0b283ecc802 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -14,6 +14,7 @@ import functools import layers +import framework from . import core __all__ = [ @@ -66,7 +67,7 @@ def error_clip_callback(block, context): class BaseGradientClipAttr(object): - def process_context(self, context, p_g): + def process_context(self, context, param, grad): raise NotImplementedError() def create_operators(self, param, grad): @@ -74,7 +75,7 @@ class BaseGradientClipAttr(object): class NullGradientClipAttr(BaseGradientClipAttr): - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -91,7 +92,7 @@ class GradientClipByValue(BaseGradientClipAttr): self.max = max self.min = min - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -99,19 +100,93 @@ class GradientClipByValue(BaseGradientClipAttr): return param, new_grad +class GradientClipByNorm(BaseGradientClipAttr): + def __init__(self, clip_norm): + self.clip_norm = clip_norm + + def process_context(self, context, param, grad): + pass + + def create_operators(self, param, grad): + new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm) + return param, new_grad + + +class GradientClipByGlobalNorm(BaseGradientClipAttr): + def __init__(self, clip_norm, group_name="default_group"): + if not isinstance(group_name, basestring): + raise TypeError("'group_name' must be a basestring.") + + self.clip_norm = clip_norm + self.group_name = group_name + + def process_context(self, context, param, grad): + if self.group_name not in context: + context[self.group_name] = [] + context[self.group_name + "_clip_value"] = self.clip_norm + context[self.group_name + "_clip"] = layers.fill_constant( + shape=[1], dtype="float32", value=self.clip_norm) + else: + if not self.clip_norm == context[self.group_name + "_clip_value"]: + raise ValueError( + "All parameters' 'clip_norm' of a same group should be the same" + ) + + local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) + context[self.group_name].append(local_norm_var) + + self.context = context + + def create_operators(self, param, grad): + group_scale_name = self.group_name + "_scale" + if group_scale_name not in self.context: + group_norm_var = layers.sums(input=self.context[self.group_name]) + layers.sqrt(x=group_norm_var, out=group_norm_var) + clip_var = self.context[self.group_name + "_clip"] + group_scale_var = layers.elementwise_div( + x=clip_var, + y=layers.elementwise_max( + x=clip_var, y=group_norm_var)) + assert group_scale_var.shape == (1L, ) + self.context[group_scale_name] = group_scale_var + + new_grad = layers.elementwise_mul( + x=grad, y=self.context[group_scale_name]) + return param, new_grad + + +def gradient_clip_by_global_norm(clip_norm, + param_list=None, + group_name="default_group", + program=None): + if program is None: + program = framework.default_main_program() + if param_list is None: + param_list = program.block(0).all_parameters() + if all(isinstance(elem, basestring) for elem in param_list): + param_list = [program.block(0).var(elem) for elem in param_list] + if not all(isinstance(elem, framework.Parameter) for elem in param_list): + raise TypeError( + "'param_list' should be a list of Parameter or basestring(parameter's name)." + ) + + for param in param_list: + param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, + group_name) + + def append_gradient_clip_ops(param_grad): context = dict() create_op_callbacks = [] for p, g in param_grad: - clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) + clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr()) if clip_attr is None: clip_attr = NullGradientClipAttr() if not isinstance(clip_attr, BaseGradientClipAttr): raise TypeError( - "clip attribute should be an instance of BaseGradientClippingAttr" - ) + "clip attribute should be an instance of BaseGradientClipAttr") - clip_attr.process_context(context=context, p_g=param_grad) + clip_attr.process_context(context=context, param=p, grad=g) create_op_callbacks.append( functools.partial( clip_attr.create_operators, param=p, grad=g)) diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index f87666545893884b2664b52be2d1a11b5aa4a244..4d8343e7de9526d527ebe93f334b59108d5ace8e 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -780,7 +780,7 @@ class Block(object): trainable=p.trainable, optimize_attr=p.optimize_attr, regularizer=p.regularizer, - clip_attr=p.clip_attr, + gradient_clip_attr=p.gradient_clip_attr, error_clip=p.error_clip, name=v.name) self.vars[new_p.name] = new_p @@ -948,7 +948,7 @@ class Parameter(Variable): self.regularizer = kwargs.get('regularizer', None) - self.clip_attr = kwargs.get('clip_attr', None) + self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None) # program is a global instance. diff --git a/python/paddle/v2/fluid/layers/__init__.py b/python/paddle/v2/fluid/layers/__init__.py index cc8a1b1ce5677c06afcd742ff94325bec20cdedd..a83dd3db74aed548a324a1c605723c957fca8604 100644 --- a/python/paddle/v2/fluid/layers/__init__.py +++ b/python/paddle/v2/fluid/layers/__init__.py @@ -24,6 +24,8 @@ import control_flow from control_flow import * import device from device import * +import math_op_patch +from math_op_patch import * __all__ = [] __all__ += nn.__all__ @@ -32,3 +34,4 @@ __all__ += tensor.__all__ __all__ += control_flow.__all__ __all__ += ops.__all__ __all__ += device.__all__ +__all__ += math_op_patch.__all__ diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index 2f1188c542cc0208a189511a1eef1eddc411007c..5f01fdb076d3bf7d060a805d1431f4973993a843 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -11,22 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib -from ..layer_helper import LayerHelper, unique_name -from ..framework import Program, Variable, Operator -from .. import core +from layer_function_generator import autodoc from tensor import assign, fill_constant -import contextlib -from ..registry import autodoc +from .. import core +from ..framework import Program, Variable, Operator +from ..layer_helper import LayerHelper, unique_name __all__ = [ - 'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', - 'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While', - 'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array', - 'array_to_lod_tensor', 'increment', 'array_write', 'create_array', - 'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse', - 'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank', - 'ParallelDo', 'Print' + 'split_lod_tensor', + 'merge_lod_tensor', + 'BlockGuard', + 'BlockGuardWithCompletion', + 'StaticRNNMemoryLink', + 'WhileGuard', + 'While', + 'lod_rank_table', + 'max_sequence_len', + 'topk', + 'lod_tensor_to_array', + 'array_to_lod_tensor', + 'increment', + 'array_write', + 'create_array', + 'less_than', + 'array_read', + 'shrink_memory', + 'array_length', + 'IfElse', + 'DynamicRNN', + 'ConditionalBlock', + 'StaticRNN', + 'reorder_lod_tensor_by_rank', + 'ParallelDo', + 'Print', ] @@ -1458,7 +1477,7 @@ class DynamicRNN(object): method)) -@autodoc +@autodoc() def reorder_lod_tensor_by_rank(x, rank_table): helper = LayerHelper('reorder_lod_tensor_by_rank', **locals()) helper.is_instance('x', Variable) diff --git a/python/paddle/v2/fluid/layers/device.py b/python/paddle/v2/fluid/layers/device.py index 736813d1b109087da367666d90be9e88dad1860e..107511b5f4ab1108610bc1326f30e5d9ab407853 100644 --- a/python/paddle/v2/fluid/layers/device.py +++ b/python/paddle/v2/fluid/layers/device.py @@ -15,14 +15,14 @@ All util layers. """ -from ..layer_helper import LayerHelper +from layer_function_generator import autodoc from ..framework import unique_name -from ..registry import autodoc +from ..layer_helper import LayerHelper __all__ = ['get_places'] -@autodoc +@autodoc() def get_places(device_count=None, device_type=None): helper = LayerHelper('get_places', **locals()) out_places = helper.create_variable(name=unique_name(helper.name + ".out")) diff --git a/python/paddle/v2/fluid/registry.py b/python/paddle/v2/fluid/layers/layer_function_generator.py similarity index 94% rename from python/paddle/v2/fluid/registry.py rename to python/paddle/v2/fluid/layers/layer_function_generator.py index ff10542d40aabaf31897842754d38b7868472b21..b0e4d1635f7b5d0afdfa677e6ec1e8f9245a9d54 100644 --- a/python/paddle/v2/fluid/registry.py +++ b/python/paddle/v2/fluid/layers/layer_function_generator.py @@ -13,17 +13,19 @@ # limitations under the License. import re import cStringIO -import warnings import functools -import inspect +import warnings + +from .. import proto -import proto.framework_pb2 as framework_pb2 -from framework import OpProtoHolder, Variable, Program, Operator -from paddle.v2.fluid.layer_helper import LayerHelper, unique_name +framework_pb2 = proto.framework_pb2 + +from ..framework import OpProtoHolder, Variable +from ..layer_helper import LayerHelper __all__ = [ 'deprecated', - 'register_layer', + 'generate_layer_fn', 'autodoc', ] @@ -96,7 +98,7 @@ def _generate_doc_string_(op_proto): return buf.getvalue() -def register_layer(op_type): +def generate_layer_fn(op_type): """Register the Python layer for an Operator. Args: @@ -207,7 +209,10 @@ def deprecated(func_or_class): return func_wrapper -def autodoc(func): - func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto( - func.__name__)) - return func +def autodoc(comment=""): + def __impl__(func): + func.__doc__ = _generate_doc_string_(OpProtoHolder.instance( + ).get_op_proto(func.__name__)) + comment + return func + + return __impl__ diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..11197b70a3d4cae08afbb49ad31013ab40e4dad2 --- /dev/null +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -0,0 +1,152 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..framework import Variable, unique_name +from ..registry import OpProtoHolder + +__all__ = ['monkey_patch_variable'] + + +def monkey_patch_variable(): + def unique_tmp_name(): + return unique_name("tmp") + + def safe_get_dtype(var): + try: + dtype = var.dtype + except: + raise ValueError("Cannot get data type from %s", var.name) + return dtype + + def create_tensor(block, value, dtype, shape): + value = float(value) + tmp_name = unique_tmp_name() + var = block.create_var(name=tmp_name, shape=shape, dtype=dtype) + block.append_op( + type="fill_constant", + outputs={'Out': [var]}, + attrs={'dtype': var.dtype, + 'shape': shape, + 'value': value}) + return var + + def create_scalar(block, value, dtype): + return create_tensor(block, value, dtype, shape=[1]) + + def create_tensor_with_batchsize(ref_var, value, dtype): + assert isinstance(ref_var, Variable) + value = float(value) + tmp_name = unique_tmp_name() + var = ref_var.block.create_var(name=tmp_name, dtype=dtype) + ref_var.block.append_op( + type='fill_constant_batch_size_like', + outputs={'Out': [var]}, + inputs={'Input': [ref_var]}, + attrs={'shape': ref_var.shape, + 'value': value}) + return var + + def astype(self, dtype): + """ + Cast a variable to a specified data type. + NOTE: The variable must be a Tensor + Args: + self(Variable): The source variable + dtype: The target dtype + + Returns: + Variable with new dtype + """ + tmp_name = unique_tmp_name() + out = self.block.create_var(name=tmp_name, dtype=dtype) + self.block.append_op( + type="cast", + inputs={"X": [self]}, + outputs={"Out": [out]}, + attrs={"in_dtype": self.dtype, + "out_dtype": out.dtype}) + return out + + def _elemwise_method_creator_(method_name, op_type, reverse=False): + def __impl__(self, other_var): + lhs_dtype = safe_get_dtype(self) + + if not isinstance(other_var, Variable): + if reverse: + has_batch_size = False + for elem in self.shape: + if elem < 0: + has_batch_size = True + break + if not has_batch_size: + other_var = create_tensor( + self.block, + other_var, + dtype=lhs_dtype, + shape=self.shape) + else: + other_var = create_tensor_with_batchsize( + self, other_var, lhs_dtype) + else: + # add fill_op to self.block + other_var = create_scalar( + self.block, value=other_var, dtype=lhs_dtype) + + rhs_dtype = safe_get_dtype(other_var) + if lhs_dtype != rhs_dtype: + other_var = astype(other_var, lhs_dtype) + if reverse: + tmp = self + self = other_var + other_var = tmp + + tmp_name = unique_tmp_name() + out = self.block.create_var(name=tmp_name, dtype=lhs_dtype) + self.block.append_op( + type=op_type, + inputs={'X': [self], + 'Y': [other_var]}, + outputs={'Out': out}) + return out + + comment = OpProtoHolder.instance().get_op_proto(op_type).comment + + __impl__.__doc__ = """ + {0} + Args: + self(Variable): left hand variable + other_var(Variable|float|int): right hand variable + + Returns: + Variable + """.format(comment) + __impl__.__name__ = method_name + return __impl__ + + # inject methods + for method_name, op_type, reverse in ( + ("__add__", "elementwise_add", False), + # a+b == b+a. Do not need to reverse explicitly + ("__radd__", "elementwise_add", False), + ("__sub__", "elementwise_sub", False), + ("__rsub__", "elementwise_sub", True), + ("__mul__", "elementwise_mul", False), + # a*b == b*a. Do not need to reverse explicitly + ("__rmul__", "elementwise_mul", False), + ("__div__", "elementwise_div", False), + ("__rdiv__", "elementwise_div", True)): + setattr(Variable, method_name, + _elemwise_method_creator_(method_name, op_type, reverse)) + + Variable.astype = astype diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 7716052a5cae0d8883dd9996f34d39e97f4399ea..b517f8be6a3e5558dd01afe094fb3989cfb3af44 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from ..registry import register_layer +from layer_function_generator import generate_layer_fn __activations__ = [ 'sigmoid', @@ -46,21 +45,11 @@ __activations__ = [ ] __all__ = [ - 'mean', - 'mul', - 'reshape', - 'scale', - 'transpose', - 'sigmoid_cross_entropy_with_logits', - 'elementwise_add', - 'elementwise_div', - 'elementwise_sub', - 'elementwise_mul', - 'elementwise_max', - 'elementwise_min', - 'clip', - 'sequence_softmax', + 'mean', 'mul', 'reshape', 'scale', 'transpose', + 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', + 'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min', + 'clip', 'clip_by_norm', 'sequence_softmax' ] + __activations__ for _OP in set(__all__): - globals()[_OP] = register_layer(_OP) + globals()[_OP] = generate_layer_fn(_OP) diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 26e9111f6f31019ab14780cc4bea01e617561fb7..dcca8b6c547d10864ff4cd0af1c217d89e3b522f 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -25,13 +25,13 @@ class ParamAttr(object): learning_rate=1.0, regularizer=None, trainable=True, - clip=None): + gradient_clip=None): self.name = name self.initializer = initializer self.learning_rate = learning_rate self.regularizer = regularizer self.trainable = trainable - self.clip = clip + self.gradient_clip = gradient_clip def set_default_initializer(self, initializer): if initializer is None: @@ -77,7 +77,7 @@ class ParamAttr(object): }, 'regularizer': self.regularizer, 'trainable': self.trainable, - 'clip_attr': self.clip + 'gradient_clip_attr': self.gradient_clip } if with_initializer: kwargs['initializer'] = self.initializer diff --git a/python/paddle/v2/fluid/tests/CMakeLists.txt b/python/paddle/v2/fluid/tests/CMakeLists.txt index 9a0240cbf65c7a79e29babc2abcb157ada684c5e..83053160820a70bb5e54f721c0d7b881c5765004 100644 --- a/python/paddle/v2/fluid/tests/CMakeLists.txt +++ b/python/paddle/v2/fluid/tests/CMakeLists.txt @@ -6,3 +6,4 @@ endforeach() add_subdirectory(book) add_subdirectory(book_distribute) +add_subdirectory(book_memory_optimization) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index be22e97054a16d69cf9e2d1e88629497e519c778..8776a65bf804e93dfeb295ecca34fac0840b0a90 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -27,7 +27,7 @@ hidden1 = fluid.layers.fc(input=image, act='relu', param_attr=fluid.ParamAttr( regularizer=regularizer, - clip=fluid.clip.ClipByValue(10))) + gradient_clip=fluid.clip.ClipByValue(10))) hidden2 = fluid.layers.fc(input=hidden1, size=64, diff --git a/python/paddle/v2/fluid/tests/book_memory_optimization/CMakeLists.txt b/python/paddle/v2/fluid/tests/book_memory_optimization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..213af5d27f711214feda3d200ced57bf71fbf6c2 --- /dev/null +++ b/python/paddle/v2/fluid/tests/book_memory_optimization/CMakeLists.txt @@ -0,0 +1,11 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +list(REMOVE_ITEM TEST_OPS test_memopt_image_classification_train) +py_test(test_memopt_image_classification_train_resnet SRCS test_memopt_image_classification_train.py ARGS resnet) +py_test(test_memopt_image_classification_train_vgg SRCS test_memopt_image_classification_train.py ARGS vgg) + +# default test +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py new file mode 100644 index 0000000000000000000000000000000000000000..cf054bb0fe778d34add4ac456f672a8b47483e84 --- /dev/null +++ b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + +x = fluid.layers.data(name='x', shape=[13], dtype='float32') + +y_predict = fluid.layers.fc(input=x, size=1, act=None) + +y = fluid.layers.data(name='y', shape=[1], dtype='float32') + +cost = fluid.layers.square_error_cost(input=y_predict, label=y) +avg_cost = fluid.layers.mean(x=cost) + +sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) +sgd_optimizer.minimize(avg_cost) + +# memopt_program = fluid.default_main_program() +memopt_program = fluid.memory_optimize(fluid.default_main_program()) + +BATCH_SIZE = 200 + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) +exe = fluid.Executor(place) + +exe.run(fluid.default_startup_program()) + +PASS_NUM = 100 +for pass_id in range(PASS_NUM): + fluid.io.save_persistables(exe, "./fit_a_line.model/") + fluid.io.load_persistables(exe, "./fit_a_line.model/") + for data in train_reader(): + avg_loss_value, = exe.run(memopt_program, + feed=feeder.feed(data), + fetch_list=[avg_cost]) + + if avg_loss_value[0] < 10.0: + exit(0) # if avg cost less than 10.0, we think our code is good. +exit(1) diff --git a/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py new file mode 100644 index 0000000000000000000000000000000000000000..42b3cb81ce67d38494677f3ecbfb1e07f7c0c3ad --- /dev/null +++ b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -0,0 +1,147 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys + +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + + +def resnet_cifar10(input, depth=32): + def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=False) + return fluid.layers.batch_norm(input=tmp, act=act) + + def shortcut(input, ch_in, ch_out, stride): + if ch_in != ch_out: + return conv_bn_layer(input, ch_out, 1, stride, 0, None) + else: + return input + + def basicblock(input, ch_in, ch_out, stride): + tmp = conv_bn_layer(input, ch_out, 3, stride, 1) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None) + short = shortcut(input, ch_in, ch_out, stride) + return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') + + def layer_warp(block_func, input, ch_in, ch_out, count, stride): + tmp = block_func(input, ch_in, ch_out, stride) + for i in range(1, count): + tmp = block_func(tmp, ch_out, ch_out, 1) + return tmp + + assert (depth - 2) % 6 == 0 + n = (depth - 2) / 6 + conv1 = conv_bn_layer( + input=input, ch_out=16, filter_size=3, stride=1, padding=1) + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) + pool = fluid.layers.pool2d( + input=res3, pool_size=8, pool_type='avg', pool_stride=1) + return pool + + +def vgg16_bn_drop(input): + def conv_block(input, num_filter, groups, dropouts): + return fluid.nets.img_conv_group( + input=input, + pool_size=2, + pool_stride=2, + conv_num_filter=[num_filter] * groups, + conv_filter_size=3, + conv_act='relu', + conv_with_batchnorm=True, + conv_batchnorm_drop_rate=dropouts, + pool_type='max') + + conv1 = conv_block(input, 64, 2, [0.3, 0]) + conv2 = conv_block(conv1, 128, 2, [0.4, 0]) + conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) + conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) + conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) + + drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5) + fc1 = fluid.layers.fc(input=drop, size=512, act=None) + bn = fluid.layers.batch_norm(input=fc1, act='relu') + drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) + fc2 = fluid.layers.fc(input=drop2, size=512, act=None) + return fc2 + + +classdim = 10 +data_shape = [3, 32, 32] + +images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') +label = fluid.layers.data(name='label', shape=[1], dtype='int64') + +net_type = "vgg" +if len(sys.argv) >= 2: + net_type = sys.argv[1] + +if net_type == "vgg": + print("train vgg net") + net = vgg16_bn_drop(images) +elif net_type == "resnet": + print("train resnet") + net = resnet_cifar10(images, 32) +else: + raise ValueError("%s network is not supported" % net_type) + +predict = fluid.layers.fc(input=net, size=classdim, act='softmax') +cost = fluid.layers.cross_entropy(input=predict, label=label) +avg_cost = fluid.layers.mean(x=cost) + +optimizer = fluid.optimizer.Adam(learning_rate=0.001) +opts = optimizer.minimize(avg_cost) + +accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + +# memopt_program = fluid.default_main_program() +memopt_program = fluid.memory_optimize(fluid.default_main_program()) + +BATCH_SIZE = 128 +PASS_NUM = 1 + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.cifar.train10(), buf_size=128 * 10), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) +feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) +exe.run(fluid.default_startup_program()) + +for pass_id in range(PASS_NUM): + accuracy.reset(exe) + for data in train_reader(): + loss, acc = exe.run(memopt_program, + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) + pass_acc = accuracy.eval(exe) + print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( + pass_acc)) + # this model is slow, so if we can train two mini batch, we think it works properly. + exit(0) +exit(1) diff --git a/python/paddle/v2/fluid/tests/test_compare_op.py b/python/paddle/v2/fluid/tests/test_compare_op.py index 08ef90b10eb70eb380db2821415184709602b848..c9be80fc45cd3428937998357b9dd9cbde1547cc 100644 --- a/python/paddle/v2/fluid/tests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/test_compare_op.py @@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) - create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) - create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) if __name__ == '__main__': diff --git a/python/paddle/v2/fluid/tests/test_clip.py b/python/paddle/v2/fluid/tests/test_error_clip.py similarity index 100% rename from python/paddle/v2/fluid/tests/test_clip.py rename to python/paddle/v2/fluid/tests/test_error_clip.py diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6e6a1ef6961d8f087dfc1ac5a4c4a8ad90032e --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + +BATCH_SIZE = 128 +CLIP = 1 + +prog = fluid.framework.Program() +with fluid.program_guard(main_program=prog): + image = fluid.layers.data(name='x', shape=[784], dtype='float32') + + hidden1 = fluid.layers.fc(input=image, size=128, act='relu') + hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') + predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') + + label = fluid.layers.data(name='y', shape=[1], dtype='int64') + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + +prog_clip = prog.clone() + +avg_cost_clip = prog_clip.block(0).var(avg_cost.name) + +p_g = fluid.backward.append_backward(loss=avg_cost) +p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) + +with fluid.program_guard(main_program=prog_clip): + fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) + p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) + +grad_list = [elem[1] for elem in p_g] +grad_clip_list = [elem[1] for elem in p_g_clip] + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) +feeder = fluid.DataFeeder(feed_list=[image, label], place=place) +exe.run(fluid.default_startup_program()) + +count = 0 +for data in train_reader(): + count += 1 + if count > 5: + break + out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) + out_clip = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=grad_clip_list) + global_norm = 0 + for v in out[1:]: + global_norm += np.sum(np.power(v, 2)) + global_norm = np.sqrt(global_norm) + + global_norm_clip = 0 + for v in out_clip[1:]: + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + + if not np.isclose( + a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3): + exit(1) +exit(0) diff --git a/python/paddle/v2/fluid/tests/test_math_op_patch.py b/python/paddle/v2/fluid/tests/test_math_op_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..2e77639a4c886327cc8dc7053fc6c0f6c6e9dcc9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_math_op_patch.py @@ -0,0 +1,181 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import decorators +import paddle.v2.fluid as fluid +import numpy + + +class TestMathOpPatches(unittest.TestCase): + @decorators.prog_scope() + def test_add_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = a + 10 + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(a_np + 10, b_np)) + + @decorators.prog_scope() + def test_radd_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = 10 + a + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(a_np + 10, b_np)) + + @decorators.prog_scope() + def test_sub_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = a - 10 + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(a_np - 10, b_np)) + + @decorators.prog_scope() + def test_radd_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = 10 - a + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(10 - a_np, b_np)) + + @decorators.prog_scope() + def test_mul_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = a * 10 + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(a_np * 10, b_np)) + + @decorators.prog_scope() + def test_rmul_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = 10 * a + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(10 * a_np, b_np)) + + @decorators.prog_scope() + def test_div_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = a / 10 + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(a_np / 10, b_np)) + + @decorators.prog_scope() + def test_rdiv_scalar(self): + a = fluid.layers.data(name="a", shape=[1]) + b = 10 / a + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + 1e-2 + + b_np = exe.run(fluid.default_main_program(), + feed={"a": a_np}, + fetch_list=[b]) + self.assertTrue(numpy.allclose(10 / a_np, b_np)) + + @decorators.prog_scope() + def test_div_two_tensor(self): + a = fluid.layers.data(name="a", shape=[1]) + b = fluid.layers.data(name="b", shape=[1]) + c = a / b + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = numpy.random.random(size=[10, 1]).astype('float32') + 1e-2 + c_np = exe.run(fluid.default_main_program(), + feed={"a": a_np, + 'b': b_np}, + fetch_list=[c]) + self.assertTrue(numpy.allclose(a_np / b_np, c_np)) + + @decorators.prog_scope() + def test_mul_two_tensor(self): + a = fluid.layers.data(name="a", shape=[1]) + b = fluid.layers.data(name="b", shape=[1]) + c = a * b + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = numpy.random.random(size=[10, 1]).astype('float32') + c_np = exe.run(fluid.default_main_program(), + feed={"a": a_np, + 'b': b_np}, + fetch_list=[c]) + self.assertTrue(numpy.allclose(a_np * b_np, c_np)) + + @decorators.prog_scope() + def test_add_two_tensor(self): + a = fluid.layers.data(name="a", shape=[1]) + b = fluid.layers.data(name="b", shape=[1]) + c = a + b + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = numpy.random.random(size=[10, 1]).astype('float32') + c_np = exe.run(fluid.default_main_program(), + feed={"a": a_np, + 'b': b_np}, + fetch_list=[c]) + self.assertTrue(numpy.allclose(a_np + b_np, c_np)) + + @decorators.prog_scope() + def test_sub_two_tensor(self): + a = fluid.layers.data(name="a", shape=[1]) + b = fluid.layers.data(name="b", shape=[1]) + c = a - b + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.random.random(size=[10, 1]).astype('float32') + b_np = numpy.random.random(size=[10, 1]).astype('float32') + c_np = exe.run(fluid.default_main_program(), + feed={"a": a_np, + 'b': b_np}, + fetch_list=[c]) + self.assertTrue(numpy.allclose(a_np - b_np, c_np)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_registry.py b/python/paddle/v2/fluid/tests/test_registry.py index 6435e7e243d4e7fa10c99fda48a011523d8cc588..44e50ca55ac609ed2e0a145ff12248fa18479668 100644 --- a/python/paddle/v2/fluid/tests/test_registry.py +++ b/python/paddle/v2/fluid/tests/test_registry.py @@ -11,26 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import unittest -import warnings import paddle.v2.fluid as fluid -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -import paddle.v2.fluid.registry as registry +import numpy as np +import decorators class TestRegistry(unittest.TestCase): + @decorators.prog_scope() def test_registry_layer(self): - self.layer_type = "mean" - program = framework.Program() - x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32') - output = layers.mean(x) + output = fluid.layers.mean(x=x) + place = fluid.CPUPlace() exe = fluid.Executor(place) - X = np.random.random((10, 10)).astype("float32") - mean_out = exe.run(program, feed={"X": X}, fetch_list=[output]) - self.assertAlmostEqual(np.mean(X), mean_out) + mean_out = exe.run(feed={"X": X}, fetch_list=[output]) + self.assertAlmostEqual(np.mean(X), mean_out[0])