提交 e2c08d28 编写于 作者: D dongzhihong

"keep style same with uniform operators"

上级 fcd6f64b
......@@ -12,42 +12,42 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gaussian_random_op.h"
#include "glog/logging.h"
#include <random>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T>
: public framework::OpKernel {
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace());
auto ctx =
static_cast<const platform::CPUDeviceContext*>(context.device_context_);
// generator need to modify context
auto g = const_cast<platform::CPUDeviceContext*>(ctx)->RandGenerator();
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
// TODO(dzh): attribute does not support unsigned int.
// And we need a global random seed configuration.
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < framework::product(output->dims()); ++i) {
r[i] = distribution(g);
for (int i = 0; i < framework::product(tensor->dims()); ++i) {
data[i] = distribution(g);
}
}
};
class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(inputs.size() == 0, "Input size of RandomOp must be zero.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one.");
PADDLE_ENFORCE(outputs[0] != nullptr,
"Outputs of RandomOp must all be set.");
auto* tensor = ctx.Output<Tensor>(0);
auto dims = GetAttr(std::vector<int>("shape"));
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set.");
tensor->Resize(framework::make_ddim(dims));
}
};
......@@ -57,26 +57,25 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<std::vector<int>>("shape", "The shape of matrix to be randomized");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0);
AddAttr<float>("std", "minimum value of random value")
.SetDefault(1.0)
.LargerThan(.0);
AddOutput("Out", "output matrix of random op");
AddComment(R"DOC(
GaussianRandom Operator fill a matrix in normal distribution.
The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std)
GaussianRandom operator.
Use to initialize tensor with gaussian random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of generator."
"0 means use system wide seed")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(gaussian_random, paddle::operators::GaussianRandomOp,
paddle::operators::GaussianRandomOpMaker);
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace,
float>
GaussianRandomOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(gaussian_random, GaussianRandomOpKernel_CPU_float);
namespace ops = paddle::operators;
REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
#include <memory>
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/guassian_random_op.h"
namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T>
: public framework::OpKernel {
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace());
auto ctx =
static_cast<const platform::GPUDeviceContext*>(context.device_context_);
// generator need to modify context
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std);
void Compute(const framework::ExecutionContext& context) const override {
T mean = static_cast<T>(context.op_.GetAttr<T>("mean"));
T std = static_cast<T>(context.op_.GetAttr<T>("std"));
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
curandGenerator_t g;
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
// auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
std);
}
};
} // namespace operators
} // namespace paddle
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
float>
RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);
\ No newline at end of file
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
\ No newline at end of file
#pragma once
#include <random>
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class GaussianRandomOpKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册