Commit 57213340 authored by dongzhihong

"update the compute kernel"

Parent a22567eb
@@ -88,7 +88,7 @@ class OperatorBase {
   /// Net will call this function to Run an op.
   virtual void Run(const std::shared_ptr<Scope>& scope,
-                   platform::DeviceContext& dev_ctx) const = 0;
+                   const platform::DeviceContext& dev_ctx) const = 0;
   // Get a input with argument's name described in `op_proto`
   const std::string& Input(const std::string& name) const;
@@ -113,7 +113,7 @@ class OperatorBase {
 class KernelContext {
  public:
   KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
-                platform::DeviceContext& device_context)
+                const platform::DeviceContext& device_context)
       : op_(*op), scope_(scope), device_context_(&device_context) {}
   const Variable* Input(int index) const {
@@ -159,7 +159,7 @@ class KernelContext {
   const OperatorBase& op_;
   const std::shared_ptr<Scope> scope_;
-  platform::DeviceContext* device_context_;
+  const platform::DeviceContext* device_context_;
 };
 class OpKernel {
@@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase {
       std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
   void Run(const std::shared_ptr<Scope>& scope,
-           platform::DeviceContext& dev_ctx) const final {
+           const platform::DeviceContext& dev_ctx) const final {
     auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
     opKernel->Compute(KernelContext(this, scope, dev_ctx));
   }
...
@@ -19,7 +19,28 @@
 namespace paddle {
 namespace operators {
-class RandomOp : public framework::OperatorWithKernel {
+template <typename T>
+class GaussianRandomOpKernel<platform::CPUPlace, T>
+    : public framework::OpKernel {
+ public:
+  void Compute(const framework::KernelContext& context) const override {
+    auto mean = context.op_.GetAttr<T>("mean");
+    auto std = context.op_.GetAttr<T>("std");
+    // auto seed = context.op_.GetAttr<T>("seed");
+    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+    T* r = output->mutable_data<T>(context.GetPlace());
+    auto ctx =
+        static_cast<const platform::CPUDeviceContext*>(context.device_context_);
+    // generator need to modify context
+    auto g = const_cast<platform::CPUDeviceContext*>(ctx)->RandGenerator();
+    std::normal_distribution<T> distribution(mean, std);
+    for (int i = 0; i < framework::product(output->dims()); ++i) {
+      r[i] = distribution(g);
+    }
+  }
+};
+
+class GaussianRandomOp : public framework::OperatorWithKernel {
 protected:
   void InferShape(
       const std::vector<const framework::Tensor*>& inputs,
@@ -33,20 +54,21 @@ protected:
   }
 };
-class RandomOpMaker : public framework::OpProtoAndCheckerMaker {
+class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  GaussianRandomOpMaker(framework::OpProto* proto,
+                        framework::OpAttrChecker* op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<std::vector<int>>("shape", "The shape of matrix to be randomized");
-    AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
+    // AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
    AddAttr<float>("mean", "mean value of random.").SetDefault(.0);
    AddAttr<float>("std", "minimum value of random value")
        .SetDefault(1.0)
        .LargerThan(.0);
    AddOutput("Out", "output matrix of random op");
    AddComment(R"DOC(
-Random Operator fill a matrix in normal distribution.
-The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
+GaussianRandom Operator fill a matrix in normal distribution.
+The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std)
 )DOC");
   }
 };
@@ -54,10 +76,11 @@ The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OP(random,
-            paddle::operators::RandomOp,
-            paddle::operators::RandomOpMaker);
-typedef paddle::operators::RandomOpKernel<paddle::platform::CPUPlace, float>
-    RandomOpKernel_CPU_float;
-REGISTER_OP_CPU_KERNEL(random, RandomOpKernel_CPU_float);
+REGISTER_OP(gaussian_random,
+            paddle::operators::GaussianRandomOp,
+            paddle::operators::GaussianRandomOpMaker);
+typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace,
+                                                  float>
+    GaussianRandomOpKernel_CPU_float;
+REGISTER_OP_CPU_KERNEL(gaussian_random, GaussianRandomOpKernel_CPU_float);
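For readers skimming the hunk above, here is a minimal standalone C++ sketch of the pattern the new CPU kernel follows: seed a std::mt19937 from the system clock (as the updated CPUDeviceContext does) and fill a buffer from std::normal_distribution with the given mean and std. Names and the buffer size are illustrative, not Paddle code.

```cpp
// Standalone sketch of the CPU kernel's sampling loop (assumption: plain C++,
// no Paddle types; mirrors GaussianRandomOpKernel<CPUPlace, T>::Compute).
#include <chrono>
#include <iostream>
#include <random>
#include <vector>

int main() {
  const float mean = 0.0f;
  const float stddev = 1.0f;
  const int size = 8;  // stands in for framework::product(output->dims())

  // Seed once from the system clock, as the updated CPUDeviceContext does.
  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
  std::mt19937 engine(seed);
  std::normal_distribution<float> distribution(mean, stddev);

  std::vector<float> out(size);
  for (int i = 0; i < size; ++i) {
    out[i] = distribution(engine);
  }

  for (float v : out) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```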
#include "paddle/operators/random_op.h" #include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template<typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T> : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace());
auto ctx = static_cast<const platform::GPUDeviceContext*>
(context.device_context_);
// generator need to modify context
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std);
typedef paddle::operators::RandomOpKernel<paddle::platform::GPUPlace, float> }
};
} // namespace operators
} // namespace paddle
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace, float>
RandomOpKernel_GPU_float; RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(random, RandomOpKernel_GPU_float); REGISTER_OP_GPU_KERNEL(random, RandomOpKernel_GPU_float);
\ No newline at end of file
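The GPU kernel hands the curandGenerator_t owned by the device context to curandGenerateNormal. Below is a minimal host-side cuRAND sketch of that call as a standalone program built with nvcc; it is not the Paddle kernel, and the seed and sample count are made up (cuRAND's normal generation expects an even count).

```cpp
// Minimal cuRAND sketch of the GPU path (assumption: standalone, not Paddle code).
// Build with: nvcc sketch.cu -lcurand
#include <cuda_runtime.h>
#include <curand.h>
#include <cstdio>
#include <vector>

int main() {
  const size_t n = 8;  // even count, as required for pseudo-random normal generation
  float* d_out = nullptr;
  cudaMalloc(&d_out, n * sizeof(float));

  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
  curandSetPseudoRandomGeneratorSeed(gen, 1234ULL);

  // Fill d_out with samples from N(mean = 0, std = 1), as the kernel does for Out.
  curandGenerateNormal(gen, d_out, n, 0.0f, 1.0f);

  std::vector<float> h_out(n);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_out) std::printf("%f ", v);
  std::printf("\n");

  curandDestroyGenerator(gen);
  cudaFree(d_out);
  return 0;
}
```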
@@ -7,63 +7,10 @@
 namespace paddle {
 namespace operators {
-template <typename T>
-bool Gaussian(platform::CPUDeviceContext* ctx,
-              T* output,
-              const int size,
-              const T& mean,
-              const T& std,
-              const T& seed) {
-  auto g = ctx->RandGenerator(seed);
-  std::normal_distribution<T> distribution(mean, std);
-  for (int i = 0; i < size; ++i) {
-    output[i] = distribution(g);
-  }
-  return true;
-}
-#ifndef PADDLE_ONLY_CPU
-template <typename T>
-bool Gaussian(platform::CUDADeviceContext* ctx,
-              T* output,
-              const int size,
-              const T& mean,
-              const T& std,
-              const T& seed) {
-  auto g = ctx->RandGenerator(seed);
-  return curandGenerateNormal(g, output, size, mean, std);
-}
-#endif
 template <typename Place, typename T>
-class RandomOpKernel : public framework::OpKernel {
+class GaussianRandomOpKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto mean = context.op_.GetAttr<T>("mean");
-    auto std = context.op_.GetAttr<T>("std");
-    auto seed = context.op_.GetAttr<T>("seed");
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
-    auto place = context.GetPlace();
-    if (platform::is_cpu_place(place)) {
-      Gaussian(
-          dynamic_cast<platform::CPUDeviceContext*>(context.device_context_),
-          output->mutable_data<T>(context.GetPlace()),
-          framework::product(output->dims()),
-          mean,
-          std,
-          seed);
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      Gaussian(
-          dynamic_cast<platform::CUDADeviceContext*>(context.device_context_),
-          output->mutable_data<T>(context.GetPlace()),
-          framework::product(output->dims()),
-          mean,
-          std,
-          seed);
-#endif
-    }
-  }
+  void Compute(const framework::KernelContext& context) const override {}
 };
 }  // namespace operators
...
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/platform/gpu_info.h"
 #define EIGEN_USE_GPU
 #endif
+#include <chrono>
 #include <memory>
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
@@ -40,7 +41,10 @@ class DeviceContext {
 class CPUDeviceContext : public DeviceContext {
  public:
   typedef std::mt19937 random_generator_type;
-  CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
+  CPUDeviceContext() {
+    random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
+    eigen_device_.reset(new Eigen::DefaultDevice());
+  }
   Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
@@ -49,16 +53,15 @@ class CPUDeviceContext : public DeviceContext {
     return retv;
   }
-  random_generator_type& RandGenerator(const int seed) {
+  random_generator_type& RandGenerator() {
     if (!rand_generator_) {
-      random_seed_ = seed;
       rand_generator_.reset(new random_generator_type(random_seed_));
     }
     return *rand_generator_.get();
   }
  private:
-  int random_seed_;
+  unsigned random_seed_;
   std::unique_ptr<random_generator_type> rand_generator_;
   std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
 };
@@ -81,6 +84,9 @@ class GPUPlaceGuard {
 class CUDADeviceContext : public DeviceContext {
  public:
+  CUDADeviceContext() {
+    random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
+  }
   explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
     GPUPlaceGuard guard(gpu_place_);
     PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
@@ -98,9 +104,8 @@ class CUDADeviceContext : public DeviceContext {
                    "cudaStreamSynchronize failed");
   }
-  curandGenerator_t RandGenerator(const int seed) {
+  curandGenerator_t RandGenerator() {
     if (!rand_generator_) {
-      random_seed_ = seed;
       GPUPlaceGuard guard(gpu_place_);
       PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
                          &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
@@ -177,7 +182,7 @@ class CUDADeviceContext : public DeviceContext {
   cudnnHandle_t dnn_handle_{nullptr};
-  int random_seed_;
+  unsigned random_seed_;
   curandGenerator_t rand_generator_{nullptr};
 };
...
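The device-context change above replaces RandGenerator(seed) with a parameterless RandGenerator() whose seed is captured from the clock at construction time. Here is a small standalone sketch of that pattern using a hypothetical RandomSource class in plain C++ (not the Paddle types): the seed is fixed when the object is created, and the engine is built lazily on first use.

```cpp
// Sketch of the clock-seeded, lazily constructed generator pattern introduced
// in CPUDeviceContext (hypothetical standalone class, not the Paddle type).
#include <chrono>
#include <iostream>
#include <memory>
#include <random>

class RandomSource {
 public:
  RandomSource()
      : seed_(static_cast<unsigned>(
            std::chrono::system_clock::now().time_since_epoch().count())) {}

  // Construct the engine on first use with the seed captured at construction.
  std::mt19937& Generator() {
    if (!engine_) {
      engine_.reset(new std::mt19937(seed_));
    }
    return *engine_;
  }

 private:
  unsigned seed_;
  std::unique_ptr<std::mt19937> engine_;
};

int main() {
  RandomSource source;
  std::normal_distribution<float> dist(0.0f, 1.0f);
  for (int i = 0; i < 4; ++i) {
    std::cout << dist(source.Generator()) << ' ';
  }
  std::cout << '\n';
  return 0;
}
```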