提交 57213340 编写于 作者: D dongzhihong

"update the compute kernel"

上级 a22567eb
......@@ -88,7 +88,7 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
platform::DeviceContext& dev_ctx) const = 0;
const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
......@@ -113,7 +113,7 @@ class OperatorBase {
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
platform::DeviceContext& device_context)
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(&device_context) {}
const Variable* Input(int index) const {
......@@ -159,7 +159,7 @@ class KernelContext {
const OperatorBase& op_;
const std::shared_ptr<Scope> scope_;
platform::DeviceContext* device_context_;
const platform::DeviceContext* device_context_;
};
class OpKernel {
......@@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase {
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope,
platform::DeviceContext& dev_ctx) const final {
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx));
}
......
......@@ -19,7 +19,28 @@
namespace paddle {
namespace operators {
class RandomOp : public framework::OperatorWithKernel {
template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T>
: public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
// auto seed = context.op_.GetAttr<T>("seed");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace());
auto ctx =
static_cast<const platform::CPUDeviceContext*>(context.device_context_);
// generator need to modify context
auto g = const_cast<platform::CPUDeviceContext*>(ctx)->RandGenerator();
std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < framework::product(output->dims()); ++i) {
r[i] = distribution(g);
}
}
};
class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor*>& inputs,
......@@ -33,20 +54,21 @@ protected:
}
};
class RandomOpMaker : public framework::OpProtoAndCheckerMaker {
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<std::vector<int>>("shape", "The shape of matrix to be randomized");
AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
// AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
AddAttr<float>("mean", "mean value of random.").SetDefault(.0);
AddAttr<float>("std", "minimum value of random value")
.SetDefault(1.0)
.LargerThan(.0);
AddOutput("Out", "output matrix of random op");
AddComment(R"DOC(
Random Operator fill a matrix in normal distribution.
The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
GaussianRandom Operator fill a matrix in normal distribution.
The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std)
)DOC");
}
};
......@@ -54,10 +76,11 @@ The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
} // namespace operators
} // namespace paddle
REGISTER_OP(random,
paddle::operators::RandomOp,
paddle::operators::RandomOpMaker);
REGISTER_OP(gaussian_random,
paddle::operators::GaussianRandomOp,
paddle::operators::GaussianRandomOpMaker);
typedef paddle::operators::RandomOpKernel<paddle::platform::CPUPlace, float>
RandomOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(random, RandomOpKernel_CPU_float);
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace,
float>
GaussianRandomOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(gaussian_random, GaussianRandomOpKernel_CPU_float);
#include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
typedef paddle::operators::RandomOpKernel<paddle::platform::GPUPlace, float>
template<typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T> : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace());
auto ctx = static_cast<const platform::GPUDeviceContext*>
(context.device_context_);
// generator need to modify context
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std);
}
};
} // namespace operators
} // namespace paddle
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace, float>
RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(random, RandomOpKernel_GPU_float);
\ No newline at end of file
......@@ -7,63 +7,10 @@
namespace paddle {
namespace operators {
template <typename T>
bool Gaussian(platform::CPUDeviceContext* ctx,
T* output,
const int size,
const T& mean,
const T& std,
const T& seed) {
auto g = ctx->RandGenerator(seed);
std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < size; ++i) {
output[i] = distribution(g);
}
return true;
}
#ifndef PADDLE_ONLY_CPU
template <typename T>
bool Gaussian(platform::CUDADeviceContext* ctx,
T* output,
const int size,
const T& mean,
const T& std,
const T& seed) {
auto g = ctx->RandGenerator(seed);
return curandGenerateNormal(g, output, size, mean, std);
}
#endif
template <typename Place, typename T>
class RandomOpKernel : public framework::OpKernel {
class GaussianRandomOpKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto seed = context.op_.GetAttr<T>("seed");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto place = context.GetPlace();
if (platform::is_cpu_place(place)) {
Gaussian(
dynamic_cast<platform::CPUDeviceContext*>(context.device_context_),
output->mutable_data<T>(context.GetPlace()),
framework::product(output->dims()),
mean,
std,
seed);
} else {
#ifndef PADDLE_ONLY_CPU
Gaussian(
dynamic_cast<platform::CUDADeviceContext*>(context.device_context_),
output->mutable_data<T>(context.GetPlace()),
framework::product(output->dims()),
mean,
std,
seed);
#endif
}
}
void Compute(const framework::KernelContext& context) const override {}
};
} // namespace operators
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <chrono>
#include <memory>
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -40,7 +41,10 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {
public:
typedef std::mt19937 random_generator_type;
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
CPUDeviceContext() {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
......@@ -49,16 +53,15 @@ class CPUDeviceContext : public DeviceContext {
return retv;
}
random_generator_type& RandGenerator(const int seed) {
random_generator_type& RandGenerator() {
if (!rand_generator_) {
random_seed_ = seed;
rand_generator_.reset(new random_generator_type(random_seed_));
}
return *rand_generator_.get();
}
private:
int random_seed_;
unsigned random_seed_;
std::unique_ptr<random_generator_type> rand_generator_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
......@@ -81,6 +84,9 @@ class GPUPlaceGuard {
class CUDADeviceContext : public DeviceContext {
public:
CUDADeviceContext() {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
}
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
......@@ -98,9 +104,8 @@ class CUDADeviceContext : public DeviceContext {
"cudaStreamSynchronize failed");
}
curandGenerator_t RandGenerator(const int seed) {
curandGenerator_t RandGenerator() {
if (!rand_generator_) {
random_seed_ = seed;
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
......@@ -177,7 +182,7 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t dnn_handle_{nullptr};
int random_seed_;
unsigned random_seed_;
curandGenerator_t rand_generator_{nullptr};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册