diff --git a/CMakeLists.txt b/CMakeLists.txt index b174831109372cb014741d63032fa6a470e74042..c7d743e193e7d32dbc0b56f3bcb05b6c61f85f1d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,8 +36,8 @@ include(simd) ################################ Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) -option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND}) -option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND}) +option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) +option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index a48c5bd8cc0b02a0ed1e5d0b6640e56f8fb83753..c82811bfa56f423da4eb4646685be1fad262bf7f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -46,5 +46,6 @@ cc_library(paddle_pybind SHARED cross_entropy_op recurrent_op uniform_random_op + gaussian_random_op fill_zeros_like_op) endif(WITH_PYTHON) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index d0419f07ba4dca32236128a2fa7afdc1460f9468..08b47cabd4c2225c50022bd35734dcc2663324d6 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include #include #include @@ -23,6 +22,7 @@ limitations under the License. */ #include "paddle/framework/framework.pb.h" #include "paddle/platform/enforce.h" +#include "paddle/platform/variant.h" namespace paddle { namespace framework { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 3ea3b499e5b1015a5b0b107850ac0608735ba729..e46cf1fc344446c86e8b50d7285b0e99eceea8e1 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -14,12 +14,12 @@ limitations under the License. */ #pragma once -#include #include #include #include #include "paddle/framework/dim.h" #include "paddle/platform/enforce.h" +#include "paddle/platform/variant.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index fc5db7ce286d5ebe70b2de734f2fac454f1a0794..7a79229f540ccf087c8dc4b664fa62400f38cd08 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include -#include #include #include #include @@ -26,6 +25,7 @@ limitations under the License. */ #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/place.h" +#include "paddle/platform/variant.h" #include "paddle/utils/Error.h" namespace paddle { diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index d6ddd5deab1700db8961c771d375bb8ef7709af9..c7ee76500312aa1c0d3f219f000250127e815a96 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -40,7 +40,9 @@ USE_OP(softmax); USE_OP(rowwise_add); USE_OP(fill_zeros_like); USE_OP_WITHOUT_KERNEL(recurrent_op); +USE_OP(gaussian_random); USE_OP(uniform_random); + namespace paddle { namespace framework { diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index b57958591fb752132407c35958db0781d0e023f0..cd1b4de426a49fa66dbbf8cf7d09990ac8d21227 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -79,11 +79,11 @@ class Tensor { inline const DDim& dims() const; /*! Resize the dimensions of the memory block. */ - inline void Resize(const DDim& dims); + inline Tensor& Resize(const DDim& dims); /*! The internal of two tensors share the same memory block. */ template - inline void ShareDataWith(const Tensor& src); + inline Tensor& ShareDataWith(const Tensor& src); /** * @brief Copy the content of external tensor to a new place. diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 8d9bec6dc9c3f0af822a0d8cd8588dc932970652..7d7263b899afb7a2128548f264065a8013b6f0c9 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -23,9 +23,11 @@ template inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tenosr holds no memory. Call Tensor::mutable_data first."); - PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_, - "Tensor's dims_ is out of bound. Call Tensor::mutable_data " - "first to re-allocate memory."); + PADDLE_ENFORCE_GE( + holder_->size(), product(dims_) * sizeof(T) + offset_, + "Tensor's dims_ is out of bound. Call Tensor::mutable_data " + "first to re-allocate memory.\n" + "or maybe the required data-type mismatches the data already stored."); } template @@ -78,9 +80,10 @@ inline T* Tensor::mutable_data(platform::Place place) { } template -inline void Tensor::ShareDataWith(const Tensor& src) { +inline Tensor& Tensor::ShareDataWith(const Tensor& src) { src.check_memory_size(); *this = src; + return *this; } template @@ -136,7 +139,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { return dst; } -inline void Tensor::Resize(const DDim& dims) { dims_ = dims; } +inline Tensor& Tensor::Resize(const DDim& dims) { + dims_ = dims; + return *this; +} inline const DDim& Tensor::dims() const { return dims_; } diff --git a/paddle/gserver/tests/test_KmaxSeqScore.cpp b/paddle/gserver/tests/test_KmaxSeqScore.cpp index a51fe390c74d74cd5f3d07df62b715b239335548..308abe6816428bc0f98ec32e892622fa4a23b1ae 100644 --- a/paddle/gserver/tests/test_KmaxSeqScore.cpp +++ b/paddle/gserver/tests/test_KmaxSeqScore.cpp @@ -96,6 +96,11 @@ TEST(Layer, kmaxSeqScoreLayer) { MatrixPtr inValue = Matrix::create(subSeqStartPosition.back(), 1, false, false); + std::vector mode = {false}; +#ifndef PADDLE_ONLY_CPU + mode.push_back(true); +#endif + for (auto hasSubseq : {false, true}) { vector> groundTruth; inValue->randomizeUniform(); @@ -104,7 +109,7 @@ TEST(Layer, kmaxSeqScoreLayer) { hasSubseq ? subSeqStartPosition : seqStartPosition, beamSize); - for (auto useGpu : {false, true}) { + for (auto useGpu : mode) { TestConfig config; config.layerConfig.set_type("kmax_seq_score"); config.layerConfig.set_beam_size(beamSize); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 1fd51219b169516d9fb99c1ecf1c2119f2c8c358..8f5c46be07911163d7346f55dd264d696850ab0c 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,6 +41,8 @@ function(op_library TARGET) endif() endfunction() +cc_test(gather_test SRCS gather_test.cc DEPS tensor) + cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) @@ -53,6 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h new file mode 100644 index 0000000000000000000000000000000000000000..0c73717d38aca9f3430e66cafc3ecccdd2eec776 --- /dev/null +++ b/paddle/operators/gather.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include + +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace operators { + +// Implementation of CPU copy +template +void CPUGather(const T* params, const int* indices, const int slice_size, + const int index_size, T* output) { + const size_t slice_bytes = slice_size * sizeof(T); + + for (size_t i = 0; i < index_size; ++i) { + int index_ = indices[i]; + memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); + } +} + +// Implementation of GPU copy: +template +void GPUGather(const T* src, const int* index, const int slice_size, + const int index_size, T* output); + +/** + * Return a new tensor from source tensor, gathered according to index + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ +template +void Gather(const platform::Place& place, const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + paddle::framework::DDim output_dims(src_dims); + output_dims[0] = index_size; + + // slice size + int slice_size = 1; + for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + // Gathering + if (platform::is_cpu_place(place)) { + CPUGather(src->data(), index->data(), slice_size, index_size, + output->data()); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5de748ec461e4b1a34b75b57c9cd7d5bc9326059 --- /dev/null +++ b/paddle/operators/gather_test.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/gather.h" +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +#include +#include +#include + +TEST(Gather, GatherData) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + Tensor* src = new Tensor(); + Tensor* index = new Tensor(); + Tensor* output = new Tensor(); + + int* p_src = nullptr; + int* p_index = nullptr; + p_src = src->mutable_data(make_ddim({3, 4}), CPUPlace()); + p_index = index->mutable_data(make_ddim({2}), CPUPlace()); + + for (size_t i = 0; i < 12; ++i) p_src[i] = i; + p_index[0] = 1; + p_index[1] = 0; + + int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); + + Gather(CPUPlace(), src, index, output); + + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); + for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); +} diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ef417ae2f06e8a9f10aed80674015e2ee448f4a3 --- /dev/null +++ b/paddle/operators/gaussian_random_op.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class GaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.op_.GetAttr("mean"); + float std = context.op_.GetAttr("std"); + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + + // TODO(dzh): attribute does not support unsigned int. + // And we need a global random seed configuration. + int seed = context.op_.GetAttr("seed"); + if (seed == 0) { + seed = std::random_device()(); + } + std::mt19937 g(seed); + std::normal_distribution distribution(mean, std); + ssize_t size = framework::product(tensor->dims()); + for (int i = 0; i < size; ++i) { + data[i] = distribution(g); + } + } +}; + +class GaussianRandomOp : public framework::OperatorWithKernel { + protected: + void InferShape(const framework::InferShapeContext& context) const override { + auto* tensor = context.Output(0); + auto dims = GetAttr>("dims"); + PADDLE_ENFORCE(dims.size() > 0UL, + "dims can be one int or array. dims must be set."); + tensor->Resize(framework::make_ddim(dims)); + } +}; + +class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GaussianRandomOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "output matrix of random op"); + AddComment(R"DOC( +GaussianRandom operator. +Use to initialize tensor with gaussian random generator. +)DOC"); + + AddAttr>("dims", "The dimension of random tensor."); + AddAttr("mean", "mean value of random.").SetDefault(.0f); + AddAttr("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr("seed", + "Random seed of generator." + "0 means use system wide seed") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..0dd26f6df8a4cb5d00632560dbca452ade36495e --- /dev/null +++ b/paddle/operators/gaussian_random_op.cu @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "paddle/platform/dynload/curand.h" +#include "paddle/platform/gpu_info.h" + +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class GaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.op_.GetAttr("mean"); + float std = context.op_.GetAttr("std"); + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + + int seed = context.op_.GetAttr("seed"); + if (seed == 0) { + seed = std::random_device()(); + } + curandGenerator_t g; + PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( + &g, CURAND_RNG_PSEUDO_DEFAULT)); + PADDLE_ENFORCE( + platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); + curandGenerateNormal(g, data, framework::product(tensor->dims()), mean, + std); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); diff --git a/paddle/platform/place.h b/paddle/platform/place.h index a82e8c942fa28297d91056a66b61f085f2bdb946..1117476bb37f1b0f3876c55e610803d5ee2558ce 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include #include +#include "paddle/platform/variant.h" namespace paddle { namespace platform { diff --git a/paddle/platform/variant.h b/paddle/platform/variant.h new file mode 100644 index 0000000000000000000000000000000000000000..c2257af1b5dd1a1e284979bf17e1a947072baa85 --- /dev/null +++ b/paddle/platform/variant.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include + +#ifndef PADDLE_ONLY_CPU + +// Because boost's variadic templates has bug on nvcc, boost will disable +// variadic template support when GPU enabled on nvcc. +// Define BOOST_NO_CXX11_VARIADIC_TEMPLATES on gcc/clang to generate same +// function symbols. +// +// https://github.com/PaddlePaddle/Paddle/issues/3386 +#ifndef BOOST_NO_CXX11_VARIADIC_TEMPLATES +#define BOOST_NO_CXX11_VARIADIC_TEMPLATES +#endif +#endif + +#include diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 10659caa882fd3d4060f9947413a392c3b681ee8..55ed724e8fc708b270fe931a3305b396f525eacd 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -21,5 +21,7 @@ py_test(gradient_checker SRCS gradient_checker.py) py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) + py_test(test_operator SRCS test_operator.py) +# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f95ed70b58d611b3233a21d3f2a34c864ae4d1b3 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -0,0 +1,36 @@ +import unittest +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator +import numpy + + +class GaussianRandomTest(unittest.TestCase): + def test_cpu(self): + self.gaussian_random_test(place=core.CPUPlace()) + + def test_gpu(self): + if core.is_compile_gpu(): + self.gaussian_random_test(place=core.GPUPlace(0)) + + def gaussian_random_test(self, place): + scope = core.Scope() + scope.new_var("Out").get_tensor() + + op = Operator( + "gaussian_random", + Out="Out", + dims=[1000, 784], + mean=.0, + std=1., + seed=10) + + op.infer_shape(scope) + context = core.DeviceContext.create(place) + op.run(scope, context) + tensor = numpy.array(scope.find_var("Out").get_tensor()) + self.assertAlmostEqual(numpy.mean(tensor), .0, delta=0.1) + self.assertAlmostEqual(numpy.std(tensor), 1., delta=0.1) + + +if __name__ == '__main__': + unittest.main()