提交 5ad9474b 编写于 作者: D dongzhihong

add random op

上级 a6043c1d
......@@ -48,6 +48,7 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(random_op SRCS random_op.cc random_op.cu)
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net)
......
#include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class RandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor*>& inputs,
const std::vector<framework::Tensor*>& outputs) const override {
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one.");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Inputs/Outputs of RandomOp must all be set.");
outputs[0]->set_dims(inputs[0]->dims());
}
};
class RandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<std::vector<int>>("Shape", "The shape of matrix to be randomized");
AddAttr<float>("seed", "random seed generator.").SetDefault(1337);
AddAttr<float>("mean", "mean value of random.").SetDefault(.0);
AddAttr<float>("std", "minimum value of random value")
.SetDefault(1.0)
.LargerThan(.0);
AddOutput("Out", "output matrix of random op");
AddComment(R"DOC(
Random Operator fill a matrix in normal distribution.
The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std)
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(random_op,
paddle::operators::RandomOp,
paddle::operators::RandomOpMaker);
typedef paddle::operators::RandomOpKernel<paddle::platform::CPUPlace, float>
RandomOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(random_op, RandomOpKernel_CPU_float);
#include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h"
typedef paddle::operators::RandomOpKernel<paddle::platform::GPUPlace, float>
RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(random_op, RandomOpKernel_GPU_float);
\ No newline at end of file
#pragma once
#include <random>
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class RandomOpKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
auto shape = context.op_.attrs_.at("Shape");
auto mean = context.op_.attrs_.at("mean");
auto std = context.op_.attrs_.at("std");
auto seed = context.op_.attrs_.at("seed");
// std::default_random_engine generator(seed);
// std::normal_distribution<double> distribution(mean, std);
framework::EigenMatrix<T>::From(*output).device(*(
context.GetEigenDevice<Place>())) = framework::EigenMatrix<T>::Random();
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册