提交 e2c08d28 编写于 作者: D dongzhihong

"keep style same with uniform operators"

上级 fcd6f64b
...@@ -12,42 +12,42 @@ ...@@ -12,42 +12,42 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/gaussian_random_op.h" #include <random>
#include "glog/logging.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T> class GaussianRandomKernel : public framework::OpKernel {
: public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean"); T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
auto std = context.op_.GetAttr<T>("std"); T std = static_cast<T>(context.op_.GetAttr<T>("std"));
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* tensor = context.Output<framework::Tensor>(0);
T* r = output->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
auto ctx =
static_cast<const platform::CPUDeviceContext*>(context.device_context_); // TODO(dzh): attribute does not support unsigned int.
// generator need to modify context // And we need a global random seed configuration.
auto g = const_cast<platform::CPUDeviceContext*>(ctx)->RandGenerator(); int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std); std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < framework::product(output->dims()); ++i) { for (int i = 0; i < framework::product(tensor->dims()); ++i) {
r[i] = distribution(g); data[i] = distribution(g);
} }
} }
}; };
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& context) const override {
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); auto* tensor = context.Output<framework::Tensor>(0);
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(outputs[0] != nullptr, PADDLE_ENFORCE(dims.size() > 0UL,
"Outputs of RandomOp must all be set."); "dims can be one int or array. dims must be set.");
auto* tensor = ctx.Output<Tensor>(0);
auto dims = GetAttr(std::vector<int>("shape"));
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
} }
}; };
...@@ -57,26 +57,25 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -57,26 +57,25 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
GaussianRandomOpMaker(framework::OpProto* proto, GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<std::vector<int>>("shape", "The shape of matrix to be randomized");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0);
AddAttr<float>("std", "minimum value of random value")
.SetDefault(1.0)
.LargerThan(.0);
AddOutput("Out", "output matrix of random op"); AddOutput("Out", "output matrix of random op");
AddComment(R"DOC( AddComment(R"DOC(
GaussianRandom Operator fill a matrix in normal distribution. GaussianRandom operator.
The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std) Use to initialize tensor with gaussian random generator.
)DOC"); )DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of generator."
"0 means use system wide seed")
.SetDefault(0);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(gaussian_random, paddle::operators::GaussianRandomOp, namespace ops = paddle::operators;
paddle::operators::GaussianRandomOpMaker); REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace,
float>
GaussianRandomOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(gaussian_random, GaussianRandomOpKernel_CPU_float);
#include <memory>
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/guassian_random_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T> class GaussianRandomKernel : public framework::OpKernel {
: public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean"); T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
auto std = context.op_.GetAttr<T>("std"); T std = static_cast<T>(context.op_.GetAttr<T>("std"));
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* tensor = context.Output<framework::Tensor>(0);
T* r = output->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
auto ctx =
static_cast<const platform::GPUDeviceContext*>(context.device_context_); int seed = context.op_.GetAttr<int>("seed");
// generator need to modify context if (seed == 0) {
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator(); seed = std::random_device()();
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std); }
curandGenerator_t g;
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
std);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace, namespace ops = paddle::operators;
float> REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
RandomOpKernel_GPU_float; \ No newline at end of file
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);
\ No newline at end of file
#pragma once
#include <random>
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class GaussianRandomOpKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册