Commit 7443b2e4, authored by QI JUN, committed by GitHub

Merge pull request #3596 from QiJune/implement_random_function

Refine the random-related operators: generate normal and uniform samples directly inside the op kernels (std::minstd_rand on CPU, a seeded thrust functor on GPU) and drop the shared cuRAND generator from CUDADeviceContext.
paddle/operators/gaussian_random_op.cc:

 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -19,25 +16,25 @@ namespace paddle {
 namespace operators {

 template <typename T>
-class GaussianRandomKernel : public framework::OpKernel {
+class CPUGaussianRandomKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     float mean = context.op_.GetAttr<float>("mean");
     float std = context.op_.GetAttr<float>("std");
-    auto* tensor = context.Output<framework::Tensor>(0);
+    auto* tensor = context.Output<framework::Tensor>("Out");
     T* data = tensor->mutable_data<T>(context.GetPlace());
-    // TODO(dzh): attribute does not support unsigned int.
-    // And we need a global random seed configuration.
-    int seed = context.op_.GetAttr<int>("seed");
+    unsigned int seed =
+        static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
+    std::minstd_rand engine;
     if (seed == 0) {
       seed = std::random_device()();
     }
-    std::mt19937 g(seed);
-    std::normal_distribution<T> distribution(mean, std);
+    engine.seed(seed);
+    std::normal_distribution<T> dist(mean, std);
     ssize_t size = framework::product(tensor->dims());
-    for (int i = 0; i < size; ++i) {
-      data[i] = distribution(g);
+    for (ssize_t i = 0; i < size; ++i) {
+      data[i] = dist(engine);
     }
   }
 };

@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext& context) const override {
-    auto* tensor = context.Output<framework::Tensor>(0);
+    auto* tensor = context.Output<framework::Tensor>("Out");
     auto dims = GetAttr<std::vector<int>>("dims");
     PADDLE_ENFORCE(dims.size() > 0UL,
                    "dims can be one int or array. dims must be set.");

@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
 )DOC");
     AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
-    AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
-    AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
+    AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f);
+    AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f);
     AddAttr<int>("seed",
                  "Random seed of generator."
                  "0 means use system wide seed")

@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
                              ops::GaussianRandomOpMaker);
-REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
+REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>);
\ No newline at end of file
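For reference, the new CPU path reduces to a plain standard-library pattern: seed a std::minstd_rand from the "seed" attribute (0 falls back to std::random_device), then draw from std::normal_distribution. A minimal standalone sketch of that pattern, with illustrative values rather than the real attribute plumbing:

// sketch of the CPU kernel's sampling pattern; the hard-coded seed and
// parameters stand in for the operator's "seed", "mean", "std" attributes
#include <iostream>
#include <random>

int main() {
  unsigned int seed = 10;           // 0 would mean "nondeterministic"
  if (seed == 0) {
    seed = std::random_device()();  // fallback, as in the kernel
  }
  std::minstd_rand engine;
  engine.seed(seed);
  std::normal_distribution<float> dist(0.0f, 1.0f);  // mean, std

  double sum = 0.0;
  const int n = 100000;
  for (int i = 0; i < n; ++i) sum += dist(engine);
  std::cout << "sample mean: " << sum / n << "\n";   // close to 0
  return 0;
}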
paddle/operators/gaussian_random_op.cu:

 [Apache-2.0 license header unchanged]

-#include <memory>
-#include <random>
-#include "paddle/platform/dynload/curand.h"
-#include "paddle/platform/gpu_info.h"
+#include <thrust/device_ptr.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
 #include "paddle/framework/op_registry.h"
-#include "paddle/framework/operator.h"

 namespace paddle {
 namespace operators {

 template <typename T>
-class GaussianRandomKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    float mean = context.op_.GetAttr<float>("mean");
-    float std = context.op_.GetAttr<float>("std");
-    auto* tensor = context.Output<framework::Tensor>(0);
-    T* data = tensor->mutable_data<T>(context.GetPlace());
-
-    int seed = context.op_.GetAttr<int>("seed");
-    if (seed == 0) {
-      std::random_device rd;
-      seed = rd();
-    }
-    curandGenerator_t g;
-    PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
-        &g, CURAND_RNG_PSEUDO_DEFAULT));
-    PADDLE_ENFORCE(
-        platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
-    platform::dynload::curandGenerateNormal(
-        g, data, framework::product(tensor->dims()), mean, std);
-  }
-};
+struct GaussianGenerator {
+  T mean_, std_;
+  unsigned int seed_;
+
+  __host__ __device__ GaussianGenerator(T mean, T std, int seed)
+      : mean_(mean), std_(std), seed_(seed) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::normal_distribution<T> dist(mean_, std_);
+    rng.discard(n);
+    return dist(rng);
+  }
+};
+
+template <typename T>
+class GPUGaussianRandomKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* tensor = context.Output<framework::Tensor>("Out");
+    T* data = tensor->mutable_data<T>(context.GetPlace());
+    unsigned int seed =
+        static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
+    if (seed == 0) {
+      std::random_device rd;
+      seed = rd();
+    }
+    T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
+    T std = static_cast<T>(context.op_.GetAttr<float>("std"));
+    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+    ssize_t N = framework::product(tensor->dims());
+    thrust::transform(index_sequence_begin, index_sequence_begin + N,
+                      thrust::device_ptr<T>(data),
+                      GaussianGenerator<T>(mean, std, seed));
+  }
+};

 }  // namespace operators
 }  // namespace paddle

-namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
+REGISTER_OP_GPU_KERNEL(gaussian_random,
+                       paddle::operators::GPUGaussianRandomKernel<float>);
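The interesting part of the new GPU path is GaussianGenerator's counter-based indexing: each flat index n seeds an identical thrust::minstd_rand and then discards n states, so every output element gets a deterministic value no matter how thrust schedules the transform. At the raw-engine level, skip-ahead reproduces sequential generation exactly; a small host-side check of that property (illustrative, not part of the patch):

// discard(n) on a freshly seeded engine reaches the same state a
// sequential engine reaches after n draws -- the property the GPU
// functor relies on to generate elements independently in parallel
#include <cassert>
#include <random>

int main() {
  const unsigned int seed = 42;
  std::minstd_rand sequential(seed);
  for (unsigned int n = 0; n < 1000; ++n) {
    std::minstd_rand skipping(seed);
    skipping.discard(n);  // jump straight to draw n
    assert(sequential() == skipping());
  }
  return 0;
}

One caveat: because the functor constructs a fresh normal_distribution per element, the outputs are not literally a prefix of one sequential normal stream (a distribution may consume a variable number of engine draws), but each index still maps to a fixed, reproducible value for a given seed.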
paddle/operators/mul_op.cc:

@@ -13,7 +13,6 @@
 limitations under the License. */

 #include "paddle/operators/mul_op.h"
-#include "paddle/operators/math/math_function.h"

 namespace paddle {
 namespace operators {
paddle/operators/uniform_random_op.cc:

 [Apache-2.0 license header unchanged]

@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
     std::uniform_real_distribution<T> dist(
         static_cast<T>(context.op_.GetAttr<float>("min")),
         static_cast<T>(context.op_.GetAttr<float>("max")));
-    for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
+    ssize_t size = framework::product(tensor->dims());
+    for (ssize_t i = 0; i < size; ++i) {
       data[i] = dist(engine);
     }
   }

@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "The output tensor of uniform random op");
     AddComment(R"DOC(Uniform random operator.
-
 Used to initialize tensor with uniform random generator.
 )DOC");
     AddAttr<std::vector<int>>("dims", "the dimension of random tensor");

@@ -84,4 +81,4 @@ Used to initialize tensor with uniform random generator.
 REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
                              paddle::operators::UniformRandomOpMaker);
 REGISTER_OP_CPU_KERNEL(uniform_random,
                        paddle::operators::CPUUniformRandomKernel<float>);
\ No newline at end of file
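The uniform CPU kernel follows the same engine-plus-distribution shape as the gaussian one, with "min"/"max" attributes in place of "mean"/"std". A standalone sketch with illustrative bounds, not the real attribute plumbing:

// sketch of the uniform sampling pattern: every value lands in
// [kMin, kMax) and the sample mean approaches (kMin + kMax) / 2
#include <cassert>
#include <iostream>
#include <random>

int main() {
  std::minstd_rand engine(10);
  const float kMin = -1.0f, kMax = 1.0f;  // stand-ins for "min"/"max"
  std::uniform_real_distribution<float> dist(kMin, kMax);

  double sum = 0.0;
  const int n = 100000;
  for (int i = 0; i < n; ++i) {
    float v = dist(engine);
    assert(v >= kMin && v < kMax);
    sum += v;
  }
  std::cout << "sample mean: " << sum / n << "\n";  // near 0
  return 0;
}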
paddle/operators/uniform_random_op.cu:

 [Apache-2.0 license header unchanged]

@@ -68,4 +65,4 @@ class GPUUniformRandomKernel : public framework::OpKernel {
 }  // namespace paddle

 REGISTER_OP_GPU_KERNEL(uniform_random,
                        paddle::operators::GPUUniformRandomKernel<float>);
\ No newline at end of file
paddle/platform/device_context.cc:

@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
     PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
   }
-  if (curand_generator_) {
-    PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
-  }
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));

@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
 cudaStream_t CUDADeviceContext::stream() { return stream_; }

-curandGenerator_t CUDADeviceContext::curand_generator() {
-  if (!curand_generator_) {
-    SetDeviceId(place_.device);
-    PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
-                                                  CURAND_RNG_PSEUDO_DEFAULT));
-    PADDLE_ENFORCE(
-        dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
-    PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
-  }
-  return curand_generator_;
-}
-
 #endif  // PADDLE_ONLY_CPU
 }  // namespace platform
paddle/platform/device_context.h:

@@ -17,7 +17,6 @@ limitations under the License. */
 #ifndef PADDLE_ONLY_CPU
 #include "paddle/platform/dynload/cublas.h"
 #include "paddle/platform/dynload/cudnn.h"
-#include "paddle/platform/dynload/curand.h"
 #include "paddle/platform/gpu_info.h"
 #define EIGEN_USE_GPU
 #endif

@@ -40,7 +39,7 @@ class DeviceContext {
 class CPUDeviceContext : public DeviceContext {
  public:
   CPUDeviceContext();
-  explicit CPUDeviceContext(CPUPlace);
+  explicit CPUDeviceContext(CPUPlace place);
   virtual ~CPUDeviceContext() {}

   Eigen::DefaultDevice* eigen_device() const;

@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
 class CUDADeviceContext : public DeviceContext {
  public:
-  explicit CUDADeviceContext(GPUPlace);
+  explicit CUDADeviceContext(GPUPlace place);
   virtual ~CUDADeviceContext();

   /*! \brief Wait for all operations completion in the stream. */

@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return cudnn handle in the device context. */
   cudnnHandle_t cudnn_handle();

-  /*! \brief Return curand handle in the device context. */
-  curandGenerator_t curand_generator();
-
   /*! \brief Return cuda stream in the device context. */
   cudaStream_t stream();
   // clang-format on

@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
  private:
   GPUPlace place_;

- private:
   std::unique_ptr<Eigen::GpuDevice> eigen_device_;
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;

- private:
-  uint64_t seed_;
-
   // clang-format off
   cudaStream_t stream_{nullptr};
   cudnnHandle_t cudnn_handle_{nullptr};
   cublasHandle_t cublas_handle_{nullptr};
-  curandGenerator_t curand_generator_{nullptr};
   // clang-format on
 };
paddle/platform/device_context_test.cc:

@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
   ASSERT_NE(nullptr, cudnn_handle);
   cublasHandle_t cublas_handle = device_context->cublas_handle();
   ASSERT_NE(nullptr, cublas_handle);
-  curandGenerator_t curand_handle = device_context->curand_generator();
-  ASSERT_NE(nullptr, curand_handle);
   ASSERT_NE(nullptr, device_context->stream());
   delete device_context;
 }
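With the generator gone from the device context, reproducibility now rests entirely on the kernels' own engines: identically seeded engines yield identical streams. A quick standalone check of that guarantee (a sketch, not the actual Paddle test):

// two engines with the same seed produce bit-identical draws, which is
// what makes a nonzero "seed" attribute reproducible across runs
#include <cassert>
#include <random>

int main() {
  std::minstd_rand a(123), b(123);
  std::normal_distribution<float> da(0.0f, 1.0f), db(0.0f, 1.0f);
  for (int i = 0; i < 100; ++i) {
    assert(da(a) == db(b));  // deterministic, draw by draw
  }
  return 0;
}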
python/paddle/v2/framework/tests/CMakeLists.txt:

@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
 py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
 py_test(test_operator SRCS test_operator.py)
-# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
+py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
 py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
 py_test(test_recurrent_op SRCS test_recurrent_op.py)
 py_test(test_sgd_op SRCS test_sgd_op.py)