提交 0d554f1d 编写于 作者: D dongzhihong

"add template fill function"

上级 c110f565
...@@ -3,6 +3,18 @@ ...@@ -3,6 +3,18 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using paddle::platform::GPUPlace;
template <GPUPlace, typename T, typename Generator>
bool Gaussian(
Generator g, T* output, const int size, const T& mean, const T& std) {
std::normal_distribution<double> distribution(mean, std);
for (int i = 0; i < size; ++i) {
output[i] = distribution(g());
}
return true;
}
class RandomOp : public framework::OperatorWithKernel { class RandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(
...@@ -12,7 +24,7 @@ protected: ...@@ -12,7 +24,7 @@ protected:
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one.");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Inputs/Outputs of RandomOp must all be set."); "Inputs/Outputs of RandomOp must all be set.");
outputs[0]->set_dims(inputs[0]->dims()); outputs[0]->set_dims(context.op_.attrs_.at("shape"));
} }
}; };
......
#include "paddle/operators/random_op.h" #include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using paddle::platform::GPUPlace;
template<GPUPlace, typename T, typename Generator>
bool Gaussian(Generator g, T* output, const int size, const T& mean, const T& std) {
return curandGenerateNormal(g, output, size, mean, std);
}
} // operators
} // paddle
typedef paddle::operators::RandomOpKernel<paddle::platform::GPUPlace, float> typedef paddle::operators::RandomOpKernel<paddle::platform::GPUPlace, float>
RandomOpKernel_GPU_float; RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(random_op, RandomOpKernel_GPU_float); REGISTER_OP_GPU_KERNEL(random_op, RandomOpKernel_GPU_float);
\ No newline at end of file
...@@ -6,24 +6,52 @@ ...@@ -6,24 +6,52 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T, typename Generator>
bool Gaussian(
Generator g, T* output, const int size, const T& mean, const T& std);
template <typename Place, typename T> template <typename Place, typename T>
class RandomOpKernel : public framework::OpKernel { class RandomOpKernel : public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
auto shape = context.op_.attrs_.at("Shape");
auto mean = context.op_.attrs_.at("mean"); auto mean = context.op_.attrs_.at("mean");
auto std = context.op_.attrs_.at("std"); auto std = context.op_.attrs_.at("std");
auto seed = context.op_.attrs_.at("seed"); auto seed = context.op_.attrs_.at("seed");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
Gaussian<Place, T, >(, output, output->size(), mean, std) :
// std::default_random_engine generator(seed); // std::default_random_engine generator(seed);
// std::normal_distribution<double> distribution(mean, std); // std::normal_distribution<double> distribution(mean, std);
framework::EigenMatrix<T>::From(*output).device(*( // framework::EigenMatrix<T>::From(*output).device(*(
context.GetEigenDevice<Place>())) = framework::EigenMatrix<T>::Random(); // context.GetEigenDevice<Place>())) =
// framework::EigenMatrix<T>::Random();
} }
}; };
// using paddle::platform::CPUPlace;
// template<CPUPlace, typename T>
// class RandomOpKernel : public framework::OpKernel {
// public:
// void Compute(const framework::KernelContext& context) const override {
// std::unique_ptr<default_random_engine> generator(seed);
// for(size_t i=0; i < output->size(); ++i) {
// output[i] = distribution(generator());
// }
// }
// };
// using paddle::platform::GPUPlace;
// template<GPUPlace, typename T>
// class RandomOpKernel : public framework::OpKernel {
// public:
// void Compute(const framework::KernelContext& context) const override {
// }
// }
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册