diff --git a/doc/fluid/api/initializer.rst b/doc/fluid/api/initializer.rst index dc0b52b14fd242dfaded1cb9a8e0ab9eb66b0607..96682c8f9fbb8683ad690a7c5809865e37a1920e 100644 --- a/doc/fluid/api/initializer.rst +++ b/doc/fluid/api/initializer.rst @@ -32,6 +32,15 @@ Normal :members: :noindex: +.. _api_fluid_initializer_Normal: + +TruncatedNormal +------ + +.. autoclass:: paddle.fluid.initializer.TruncatedNormal + :members: + :noindex: + .. _api_fluid_initializer_Xavier: Xavier diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index ca12ec7cb73a37794b4facd8c2e175e9281cf4a2..534acffa7ffdc07bfcbfb328b540221c2b988de5 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -79,6 +79,7 @@ paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)) paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)) paddle.fluid.initializer.NormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)) +paddle.fluid.initializer.TruncatedNormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)) paddle.fluid.initializer.XavierInitializer.__init__ ArgSpec(args=['self', 'uniform', 'fan_in', 'fan_out', 'seed'], varargs=None, keywords=None, defaults=(True, None, None, 0)) paddle.fluid.initializer.BilinearInitializer.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.initializer.MSRAInitializer.__init__ ArgSpec(args=['self', 'uniform', 'fan_in', 'seed'], varargs=None, keywords=None, defaults=(True, None, 0)) diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d854e2803975543b51c50ea2bc173322d3c3ca5e --- /dev/null +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc @@ -0,0 +1,255 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e +template +T Erfinv(T x) { + if (x < -1 || x > 1) { + return std::numeric_limits::quiet_NaN(); + } else if (x == 1.0) { + return std::numeric_limits::infinity(); + } else if (x == -1.0) { + return -std::numeric_limits::infinity(); + } + + const T LN2 = 6.931471805599453094172321214581e-1; + + const T A0 = 1.1975323115670912564578e0; + const T A1 = 4.7072688112383978012285e1; + const T A2 = 6.9706266534389598238465e2; + const T A3 = 4.8548868893843886794648e3; + const T A4 = 1.6235862515167575384252e4; + const T A5 = 2.3782041382114385731252e4; + const T A6 = 1.1819493347062294404278e4; + const T A7 = 8.8709406962545514830200e2; + + const T B0 = 1.0000000000000000000e0; + const T B1 = 4.2313330701600911252e1; + const T B2 = 6.8718700749205790830e2; + const T B3 = 5.3941960214247511077e3; + const T B4 = 2.1213794301586595867e4; + const T B5 = 3.9307895800092710610e4; + const T B6 = 2.8729085735721942674e4; + const T B7 = 5.2264952788528545610e3; + + const T C0 = 1.42343711074968357734e0; + const T C1 = 4.63033784615654529590e0; + const T C2 = 5.76949722146069140550e0; + const T C3 = 3.64784832476320460504e0; + const T C4 = 1.27045825245236838258e0; + const T C5 = 2.41780725177450611770e-1; + const T C6 = 2.27238449892691845833e-2; + const T C7 = 7.74545014278341407640e-4; + + const T D0 = 1.4142135623730950488016887e0; + const T D1 = 2.9036514445419946173133295e0; + const T D2 = 2.3707661626024532365971225e0; + const T D3 = 9.7547832001787427186894837e-1; + const T D4 = 2.0945065210512749128288442e-1; + const T D5 = 2.1494160384252876777097297e-2; + const T D6 = 7.7441459065157709165577218e-4; + const T D7 = 1.4859850019840355905497876e-9; + + const T E0 = 6.65790464350110377720e0; + const T E1 = 5.46378491116411436990e0; + const T E2 = 1.78482653991729133580e0; + const T E3 = 2.96560571828504891230e-1; + const T E4 = 2.65321895265761230930e-2; + const T E5 = 1.24266094738807843860e-3; + const T E6 = 2.71155556874348757815e-5; + const T E7 = 2.01033439929228813265e-7; + + const T F0 = 1.414213562373095048801689e0; + const T F1 = 8.482908416595164588112026e-1; + const T F2 = 1.936480946950659106176712e-1; + const T F3 = 2.103693768272068968719679e-2; + const T F4 = 1.112800997078859844711555e-3; + const T F5 = 2.611088405080593625138020e-5; + const T F6 = 2.010321207683943062279931e-7; + const T F7 = 2.891024605872965461538222e-15; + + T abs_x = abs(x); + + if (abs_x <= 0.85) { + T r = 0.180625 - 0.25 * x * x; + T num = + (((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) * + r + + A0); + T den = + (((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) * + r + + B0); + return x * num / den; + } + + T r = sqrt(LN2 - log(1.0 - abs_x)); + + T num, den; + if (r <= 5.0) { + r = r - 1.6; + num = + (((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) * + r + + C0); + den = + (((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) * + r + + D0); + } else { + r = r - 5.0; + num = + (((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) * + r + + E0); + den = + (((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) * + r + + F0); + } + + if (x < 0) { + return -num / den; + } else { + return num / den; + } +} + +template +struct TruncatedNormal { + T mean, std; + T a_normal_cdf; + T b_normal_cdf; + TruncatedNormal(T mean, T std) : mean(mean), std(std) { + auto normal_cdf = [](T x) { + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; + }; + a_normal_cdf = normal_cdf(-2.0); + b_normal_cdf = normal_cdf(2.0); + } + + T operator()(T value) const { + auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; + return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std; + } +}; + +template +class CPUTruncatedGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.Attr("mean"); + float std = context.Attr("std"); + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + + unsigned int seed = static_cast(context.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist(std::numeric_limits::min(), + 1.0); + TruncatedNormal truncated_normal(mean, std); + int64_t size = tensor->numel(); + for (int64_t i = 0; i < size; ++i) { + data[i] = truncated_normal(dist(engine)); + } + } +}; + +class TruncatedGaussianRandomOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of TruncatedGaussianRandomOp should not be null."); + auto shape = ctx->Attrs().Get>("shape"); + std::vector out_dim; + out_dim.reserve(shape.size()); + for (auto dim : shape) { + out_dim.push_back(static_cast(dim)); + } + PADDLE_ENFORCE(shape.size() > 0UL, + "shape can be one int or array. shape must be set."); + ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library{framework::LibraryType::kPlain}; + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + return framework::OpKernelType( + static_cast(ctx.Attr("dtype")), + ctx.device_context(), layout, library); + } +}; + +class TruncatedGaussianRandomOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Out", "Output tensor of truncated gaussian random op."); + + AddAttr>("shape", + "(vector) " + "The dimension of random tensor."); + AddAttr("mean", + "(float, default 0.0) " + "mean of random tensor.") + .SetDefault(.0f); + AddAttr("std", + "(float, default 1.0) " + "std of random tensor.") + .SetDefault(1.0f); + AddAttr("seed", + "(int, default 0) " + "Random seed of generator." + "0 means use system wide seed." + "Note that if seed is not 0, this operator will always " + "generate the same random numbers every time.") + .SetDefault(0); + AddAttr("dtype", + "(int, default 5(FP32)) " + "Output data type.") + .SetDefault(framework::proto::VarType::FP32); + AddComment(R"DOC( +TruncatedGaussianRandom Operator. + +Used to initialize tensors with truncated gaussian random generator. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random, + ops::TruncatedGaussianRandomOp, + ops::TruncatedGaussianRandomOpMaker); +REGISTER_OP_CPU_KERNEL(truncated_gaussian_random, + ops::CPUTruncatedGaussianRandomKernel); diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cu b/paddle/fluid/operators/truncated_gaussian_random_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ad2a9021bfe344d838dff2040b3fb9371274e218 --- /dev/null +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cu @@ -0,0 +1,76 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct TruncatedNormal { + T mean, std; + T a_normal_cdf; + T b_normal_cdf; + unsigned int seed; + T numeric_min; + + __host__ __device__ TruncatedNormal(T mean, T std, T numeric_min, int seed) + : mean(mean), std(std), seed(seed), numeric_min(numeric_min) { + a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0; + b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0; + } + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed); + thrust::uniform_real_distribution dist(numeric_min, 1); + rng.discard(n); + T value = dist(rng); + auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; + return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std; + } +}; + +template +class GPUTruncatedGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = static_cast(context.Attr("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T mean = static_cast(context.Attr("mean")); + T std = static_cast(context.Attr("std")); + thrust::counting_iterator index_sequence_begin(0); + int64_t size = tensor->numel(); + thrust::transform( + index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(data), + TruncatedNormal(mean, std, std::numeric_limits::min(), seed)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL( + truncated_gaussian_random, + paddle::operators::GPUTruncatedGaussianRandomKernel); diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index bd46ed8e50c9344d471578eb0f89b7e214d62722..7a7a0078a557c47492a4396897aafabe6c9c5dcb 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -20,10 +20,10 @@ import contextlib from .core import VarDesc __all__ = [ - 'Constant', 'Uniform', 'Normal', 'Xavier', 'Bilinear', 'MSRA', - 'force_init_on_cpu', 'init_on_cpu', 'ConstantInitializer', - 'UniformInitializer', 'NormalInitializer', 'XavierInitializer', - 'BilinearInitializer', 'MSRAInitializer' + 'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear', + 'MSRA', 'force_init_on_cpu', 'init_on_cpu', 'ConstantInitializer', + 'UniformInitializer', 'NormalInitializer', 'TruncatedNormalInitializer', + 'XavierInitializer', 'BilinearInitializer', 'MSRAInitializer' ] _force_init_on_cpu_ = False @@ -33,6 +33,8 @@ def force_init_on_cpu(): """ The flag of whether force to init variables on CPU. + Returns:: + Examples: .. code-block:: python @@ -272,6 +274,60 @@ class NormalInitializer(Initializer): return op +class TruncatedNormalInitializer(Initializer): + """Implements the Random TruncatedNormal(Gaussian) distribution initializer + + Args: + loc (float): mean of the normal distribution + scale (float): standard deviation of the normal distribution + seed (int): random seed + + Examples: + .. code-block:: python + + fc = fluid.layers.fc(input=x, size=10, + param_attr=fluid.initializer.TruncatedNormal(loc=0.0, scale=2.0)) + """ + + def __init__(self, loc=0.0, scale=1.0, seed=0): + assert loc is not None + assert scale is not None + assert seed is not None + super(NormalInitializer, self).__init__() + self._mean = loc + self._std_dev = scale + self._seed = seed + + def __call__(self, var, block): + """Add truncated normal distribution initialization ops for a variable + + Args: + var: Variable that needs to be initialized + block: The block in which initialization ops + should be added + + Returns: + the initialization op + """ + assert isinstance(var, framework.Variable) + assert isinstance(block, framework.Block) + # Initialization Ops should be prepended and not appended + if self._seed == 0: + self._seed = block.program.random_seed + op = block._prepend_op( + type="truncated_gaussian_random", + outputs={"Out": var}, + attrs={ + "shape": var.shape, + "dtype": int(var.dtype), + "mean": self._mean, + "std": self._std_dev, + "seed": self._seed + }) + var.op = op + return op + + class XavierInitializer(Initializer): """ This class implements the Xavier weight initializer from the paper @@ -583,6 +639,7 @@ class BilinearInitializer(Initializer): Constant = ConstantInitializer Uniform = UniformInitializer Normal = NormalInitializer +TruncatedNormal = TruncatedNormalInitializer Xavier = XavierInitializer MSRA = MSRAInitializer Bilinear = BilinearInitializer diff --git a/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py new file mode 100644 index 0000000000000000000000000000000000000000..4abeae77d26e8def85596aefc6c2f89cd4e4d6f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_truncated_gaussian_random_op.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy + +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from paddle.fluid.executor import Executor + + +class TestTrunctedGaussianRandomOp(unittest.TestCase): + def setUp(self): + self.op_type = "truncated_gaussian_random" + self.inputs = {} + self.attrs = { + "shape": [10000], + "mean": .0, + "std": 1., + "seed": 10, + } + + self.outputs = ["Out"] + + def test_cpu(self): + self.gaussian_random_test(place=fluid.CPUPlace()) + + def test_gpu(self): + if core.is_compiled_with_cuda(): + self.gaussian_random_test(place=fluid.CUDAPlace(0)) + + def gaussian_random_test(self, place): + + program = fluid.Program() + block = program.global_block() + vout = block.create_var(name="Out") + op = block.append_op( + type=self.op_type, outputs={"Out": vout}, attrs=self.attrs) + + op.desc.infer_var_type(block.desc) + op.desc.infer_shape(block.desc) + + fetch_list = [] + for var_name in self.outputs: + fetch_list.append(block.var(var_name)) + + exe = Executor(place) + outs = exe.run(program, fetch_list=fetch_list) + tensor = outs[0] + self.assertAlmostEqual(numpy.mean(tensor), .0, delta=0.1) + self.assertAlmostEqual(numpy.var(tensor), 0.773, delta=0.1) + + +if __name__ == "__main__": + unittest.main()