From 5ad9474bf7d2ad94578bd509957ae331cde36ab0 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 24 Jul 2017 10:36:10 +0800 Subject: [PATCH] add random op --- paddle/operators/CMakeLists.txt | 1 + paddle/operators/random_op.cc | 46 +++++++++++++++++++++++++++++++++ paddle/operators/random_op.cu | 6 +++++ paddle/operators/random_op.h | 29 +++++++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 paddle/operators/random_op.cc create mode 100644 paddle/operators/random_op.cu create mode 100644 paddle/operators/random_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a37720e509..14f8303c40 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -48,6 +48,7 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) +op_library(random_op SRCS random_op.cc random_op.cu) op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) diff --git a/paddle/operators/random_op.cc b/paddle/operators/random_op.cc new file mode 100644 index 0000000000..c219a0b67d --- /dev/null +++ b/paddle/operators/random_op.cc @@ -0,0 +1,46 @@ +#include "paddle/operators/random_op.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { +class RandomOp : public framework::OperatorWithKernel { +protected: + void InferShape( + const std::vector& inputs, + const std::vector& outputs) const override { + PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero."); + PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); + PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, + "Inputs/Outputs of RandomOp must all be set."); + outputs[0]->set_dims(inputs[0]->dims()); + } +}; + +class RandomOpMaker : public framework::OpProtoAndCheckerMaker { +public: + RandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr>("Shape", "The shape of matrix to be randomized"); + AddAttr("seed", "random seed generator.").SetDefault(1337); + AddAttr("mean", "mean value of random.").SetDefault(.0); + AddAttr("std", "minimum value of random value") + .SetDefault(1.0) + .LargerThan(.0); + AddOutput("Out", "output matrix of random op"); + AddComment(R"DOC( +Random Operator fill a matrix in normal distribution. +The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std) +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP(random_op, + paddle::operators::RandomOp, + paddle::operators::RandomOpMaker); + +typedef paddle::operators::RandomOpKernel + RandomOpKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(random_op, RandomOpKernel_CPU_float); diff --git a/paddle/operators/random_op.cu b/paddle/operators/random_op.cu new file mode 100644 index 0000000000..50985f6699 --- /dev/null +++ b/paddle/operators/random_op.cu @@ -0,0 +1,6 @@ +#include "paddle/operators/random_op.h" +#include "paddle/framework/op_registry.h" + +typedef paddle::operators::RandomOpKernel + RandomOpKernel_GPU_float; +REGISTER_OP_GPU_KERNEL(random_op, RandomOpKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/random_op.h b/paddle/operators/random_op.h new file mode 100644 index 0000000000..123d9d6ffa --- /dev/null +++ b/paddle/operators/random_op.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { +template +class RandomOpKernel : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& context) const override { + auto* output = context.Output(0)->GetMutable(); + output->mutable_data(context.GetPlace()); + + auto shape = context.op_.attrs_.at("Shape"); + auto mean = context.op_.attrs_.at("mean"); + auto std = context.op_.attrs_.at("std"); + auto seed = context.op_.attrs_.at("seed"); + // std::default_random_engine generator(seed); + // std::normal_distribution distribution(mean, std); + + framework::EigenMatrix::From(*output).device(*( + context.GetEigenDevice())) = framework::EigenMatrix::Random(); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab