Commit e376bda4 authored by Yu Yang

Add uniform random operator

It can run on both CPU and GPU. The configurable attributes are:

* min: the minimum value of the uniform distribution
* max: the maximum value of the uniform distribution
* dims: the dimensions of the output tensor
* seed: the random seed. 0 means a new seed is generated on each run.
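
For reference, a minimal usage sketch showing how these attributes are passed (it mirrors the unit test added by this commit and assumes a build with the paddle.v2.framework Python bindings):

    import paddle.v2.framework.core as core
    from paddle.v2.framework.op import Operator

    scope = core.Scope()
    scope.new_var("X").get_tensor()
    # Fill X with 1000x784 samples drawn uniformly from [-5.0, 10.0).
    op = Operator("uniform_random", Out="X", dims=[1000, 784],
                  min=-5.0, max=10.0, seed=10)
    op.infer_shape(scope)
    op.run(scope, core.DeviceContext.create(core.CPUPlace()))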
Parent fd64369f
@@ -38,9 +38,10 @@ cc_test(backward_test SRCS backward_test.cc DEPS backward)
 cc_library(paddle_pybind SHARED
     SRCS pybind.cc
     DEPS pybind python backward
-        fc_op
-        sgd_op
-        add_op
-        mean_op
-        cross_entropy_op
-        recurrent_op)
+        fc_op
+        sgd_op
+        add_op
+        mean_op
+        cross_entropy_op
+        recurrent_op
+        uniform_random_op)
@@ -41,6 +41,7 @@ USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
 USE_OP_WITHOUT_KERNEL(recurrent_op);
+USE_OP(uniform_random);
 namespace paddle {
 namespace framework {
 template <typename ClassType>
@@ -66,3 +66,5 @@ op_library(fc_op
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
     DEPS op_desc tensor op_registry operator net_op)
 cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
+op_library(uniform_random_op
+    SRCS uniform_random_op.cc uniform_random_op.cu)
paddle/operators/uniform_random_op.cc

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/uniform_random_op.h"
namespace paddle {
namespace operators {
class RandomOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
"uniform_random's min must less then max");
auto tensor = ctx.Output<Tensor>(0);
auto dims = GetAttr<std::vector<int>>("dims");
tensor->Resize(framework::make_ddim(dims));
}
};
class RandomOpMaker : public OpProtoAndCheckerMaker {
public:
RandomOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of uniform random. "
"0 means generate a seed by system")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(uniform_random, ops::RandomOp, ops::RandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random,
ops::UniformRandomKernel<ops::CPUPlace, float>);
paddle/operators/uniform_random_op.cu

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/uniform_random_op.h"
REGISTER_OP_GPU_KERNEL(uniform_random,
ops::UniformRandomKernel<ops::GPUPlace, float>);
paddle/operators/uniform_random_op.h

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class UniformRandomKernel : public OpKernel {
 public:
  void Compute(const ExecutionContext &context) const override {
    auto tensor = context.Output<Tensor>(0);
    tensor->mutable_data<T>(context.GetPlace());

    auto eigenTensor = EigenVector<T>::Flatten(*tensor);
    auto dev = context.GetEigenDevice<Place>();
    auto min = context.op_.GetAttr<float>("min");
    auto max = context.op_.GetAttr<float>("max");
    auto seed = static_cast<uint64_t>(context.op_.GetAttr<int>("seed"));
    auto diff = max - min;
    Eigen::internal::UniformRandomGenerator<T> gen(seed);
    // Scale Eigen's uniform samples from the unit interval onto [min, max).
    eigenTensor.device(dev) = eigenTensor.random(gen) * diff + min;
  }
};

}  // namespace operators
}  // namespace paddle
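
The kernel draws uniform samples from Eigen's generator and maps them affinely onto [min, max). A standalone numpy sketch of the same transform, for illustration only (names here are hypothetical, not part of the commit):

    import numpy

    min_v, max_v = -5.0, 10.0
    u = numpy.random.uniform(0.0, 1.0, size=(1000, 784))  # unit-interval samples
    x = u * (max_v - min_v) + min_v  # the same affine map the kernel applies
    # The mean of U(min, max) is the midpoint (min + max) / 2.
    assert abs(x.mean() - (min_v + max_v) / 2.0) < 0.1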
@@ -14,4 +14,5 @@ add_python_test(test_framework
     test_softmax_op.py
     test_rowwise_add_op.py
     test_network.py
-    gradient_checker.py)
+    gradient_checker.py
+    test_uniform_random_op.py)
test_uniform_random_op.py

import unittest
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
import numpy


class UniformRandomTest(unittest.TestCase):
    def test_uniform_random_cpu(self):
        self.uniform_random_test(place=core.CPUPlace())

    def test_uniform_random_gpu(self):
        if core.is_compile_gpu():
            self.uniform_random_test(place=core.GPUPlace(0))

    def uniform_random_test(self, place):
        scope = core.Scope()
        scope.new_var("X").get_tensor()
        op = Operator(
            "uniform_random",
            Out="X",
            dims=[1000, 784],
            min=-5.0,
            max=10.0,
            seed=10)
        op.infer_shape(scope)
        ctx = core.DeviceContext.create(place)
        op.run(scope, ctx)
        tensor = numpy.array(scope.find_var("X").get_tensor())
        # The mean of U(-5, 10) is (-5 + 10) / 2 = 2.5.
        self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1)


if __name__ == '__main__':
    unittest.main()