提交 0c37705d 编写于 作者: Y Yu Yang

Use thrust to implement uniform_random

上级 fd0bdb4f
...@@ -49,5 +49,4 @@ Used to initialize tensor with uniform random generator. ...@@ -49,5 +49,4 @@ Used to initialize tensor with uniform random generator.
} // namespace paddle } // namespace paddle
REGISTER_OP(uniform_random, ops::RandomOp, ops::RandomOpMaker); REGISTER_OP(uniform_random, ops::RandomOp, ops::RandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random, ops::CPUUniformRandomKernel<float>);
ops::UniformRandomKernel<ops::CPUPlace, float>);
...@@ -12,7 +12,54 @@ ...@@ -12,7 +12,54 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/uniform_random_op.h" #include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/type_alias.h"
REGISTER_OP_GPU_KERNEL(uniform_random, namespace paddle {
ops::UniformRandomKernel<ops::GPUPlace, float>); namespace operators {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
__host__ __device__ UniformGenerator(T min, T max, int seed)
: min_(min), max_(max), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
return dist(rng);
}
};
template <typename T>
class GPUUniformRandomKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto* tensor = context.Output<Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) {
seed = std::random_device()();
}
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(uniform_random, ops::GPUUniformRandomKernel<float>);
...@@ -13,25 +13,30 @@ ...@@ -13,25 +13,30 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <random>
#include <type_traits>
#include "paddle/operators/type_alias.h" #include "paddle/operators/type_alias.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename T>
class UniformRandomKernel : public OpKernel { class CPUUniformRandomKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext &context) const override { void Compute(const ExecutionContext& context) const override {
auto tensor = context.Output<Tensor>(0); auto* tensor = context.Output<Tensor>(0);
tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
auto eigenTensor = EigenVector<T>::Flatten(*tensor); static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
auto dev = context.GetEigenDevice<Place>(); std::minstd_rand engine;
auto min = context.op_.GetAttr<float>("min"); if (seed == 0) {
auto max = context.op_.GetAttr<float>("max"); seed = std::random_device()();
auto seed = static_cast<uint64_t>(context.op_.GetAttr<int>("seed")); }
auto diff = max - min; engine.seed(seed);
Eigen::internal::UniformRandomGenerator<T> gen(seed); std::uniform_real_distribution<T> dist(static_cast<T>(context.op_.GetAttr<float>("min")),
eigenTensor.device(dev) = eigenTensor.random(gen) * diff + min; static_cast<T>(context.op_.GetAttr<float>("max")));
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
data[i] = dist(engine);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册