未验证 提交 0d29cf18 编写于 作者: A Aurelius84 提交者: GitHub

Supports diagonal initialization in uniform_random op (#19299)

* add diag init in Uniform_random op test=develop

* modify api.spec test=develop

* fix unform_batch_size_like maker test=develop

* add diag_num and diag_step assert check test=develop
上级 5a579df9
......@@ -92,8 +92,8 @@ paddle.fluid.io.Fake ('paddle.reader.decorator.Fake', ('document', '0d8f4847b99b
paddle.fluid.io.Fake.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.initializer.ConstantInitializer ('paddle.fluid.initializer.ConstantInitializer', ('document', '798f1fd87cbe9798d001ffb6e616415d'))
paddle.fluid.initializer.ConstantInitializer.__init__ (ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.initializer.UniformInitializer ('paddle.fluid.initializer.UniformInitializer', ('document', 'a8f1177e4ce29766853e801d5b0a3635'))
paddle.fluid.initializer.UniformInitializer.__init__ (ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.initializer.UniformInitializer ('paddle.fluid.initializer.UniformInitializer', ('document', '587b7035cd1d56f76f2ded617b92521d'))
paddle.fluid.initializer.UniformInitializer.__init__ (ArgSpec(args=['self', 'low', 'high', 'seed', 'diag_num', 'diag_step', 'diag_val'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0, 0, 0, 1.0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.initializer.NormalInitializer ('paddle.fluid.initializer.NormalInitializer', ('document', '279a0d89bf01138fbf4c4ba14f22099b'))
paddle.fluid.initializer.NormalInitializer.__init__ (ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.initializer.TruncatedNormalInitializer ('paddle.fluid.initializer.TruncatedNormalInitializer', ('document', 'b8e90aad6ee5687cb5f2b6fd404370d1'))
......
......@@ -56,6 +56,14 @@ with random values sampled from a uniform distribution.
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("diag_num",
"The number of diag elements. Note that if "
"diag_num is 0, it means without diag init.[default 0].")
.SetDefault(0);
AddAttr<int>("diag_step", "The step between two diag element.[default 0].")
.SetDefault(0);
AddAttr<float>("diag_val", "The value of diag element. [default 1.0].")
.SetDefault(1.0f);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::VarType::FP32);
}
......
......@@ -53,6 +53,19 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
if (diag_num > 0) {
PADDLE_ENFORCE_GT(size, (diag_num - 1) * (diag_step + 1),
"The index of diagonal elements is out of bounds");
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data[pos] = diag_val;
}
}
}
};
......@@ -61,13 +74,17 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniformRandomOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of UniformRandomOp should not be null.");
PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
"diag_num must greater than or equal 0");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
"diag_step must greater than or equal 0");
std::vector<int64_t> temp;
temp.reserve(shape.size());
for (auto dim : shape) {
......@@ -105,6 +122,14 @@ uniform distribution. The random result is in set [min, max].
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0].")
.SetDefault(0);
AddAttr<int>("diag_num",
"The number of diag elements. Note that if "
"diag_num is 0, it means without diag init.[default 0].")
.SetDefault(0);
AddAttr<int>("diag_step", "The step between two diag element.[default 0].")
.SetDefault(0);
AddAttr<float>("diag_val", "The value of diag element. [default 1.0].")
.SetDefault(1.0f);
AddAttr<int>("dtype", "Output tensor data type. [default 5(FP32)].")
.SetDefault(framework::proto::VarType::FP32);
}
......
......@@ -23,16 +23,29 @@ template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
__host__ __device__ UniformGenerator(T min, T max, int seed)
: min_(min), max_(max), seed_(seed) {}
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
return dist(rng);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
......@@ -64,11 +77,17 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(context.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(context.Attr<int>("diag_step"));
T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed));
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
};
......
......@@ -208,6 +208,12 @@ class UniformInitializer(Initializer):
low (float): lower boundary of the uniform distribution
high (float): upper boundary of the uniform distribution
seed (int): random seed
diag_num (int): the number of diagonal elements to initialize.
If set to 0, diagonal initialization will be not performed.
diag_step (int): Step size between two diagonal elements,
which is generally the width of the square matrix.
diag_val (float): the value of the diagonal element to be initialized,
default 1.0. It takes effect only if the diag_num is greater than 0.
Examples:
.. code-block:: python
......@@ -218,15 +224,29 @@ class UniformInitializer(Initializer):
param_attr=fluid.initializer.Uniform(low=-0.5, high=0.5))
"""
def __init__(self, low=-1.0, high=1.0, seed=0):
def __init__(self,
low=-1.0,
high=1.0,
seed=0,
diag_num=0,
diag_step=0,
diag_val=1.0):
assert low is not None
assert high is not None
assert high >= low
assert seed is not None
assert diag_num is not None
assert diag_step is not None
assert diag_val is not None
if diag_num > 0 or diag_step > 0:
assert (diag_num > 0 and diag_step > 0)
super(UniformInitializer, self).__init__()
self._low = low
self._high = high
self._seed = seed
self._diag_num = diag_num
self._diag_step = diag_step
self._diag_val = diag_val
def __call__(self, var, block):
"""Add uniform distribution initialization ops for a variable
......@@ -267,7 +287,10 @@ class UniformInitializer(Initializer):
"dtype": out_dtype,
"min": self._low,
"max": self._high,
"seed": self._seed
"seed": self._seed,
"diag_num": self._diag_num,
"diag_step": self._diag_step,
"diag_val": self._diag_val
},
stop_gradient=True)
......
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
def output_hist(out):
......@@ -29,28 +30,59 @@ def output_hist(out):
return hist, prob
def output_hist_diag(out):
diag_num = min(out.shape)
for i in range(diag_num):
assert abs(out[i][i] - 1.0) < 1e-9
# ignore diagonal elements
out[i][i] = 100
hist, _ = np.histogram(out, range=(-5, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class TestUniformRandomOp(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.inputs = {}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {
"shape": [1000, 784],
"min": -5.0,
"max": 10.0,
"seed": 10
}
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = output_hist(np.array(outs[0]))
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpWithDiagInit(TestUniformRandomOp):
def init_attrs(self):
self.attrs = {
"shape": [1000, 784],
"min": -5.0,
"max": 10.0,
"seed": 10,
"diag_num": 784,
"diag_step": 784,
"diag_val": 1.0
}
self.output_hist = output_hist_diag
class TestUniformRandomOpSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
......@@ -81,5 +113,50 @@ class TestUniformRandomOpSelectedRows(unittest.TestCase):
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpSelectedRowsWithDiagInit(
TestUniformRandomOpSelectedRows):
def check_with_place(self, place):
scope = core.Scope()
out = scope.var("X").get_selected_rows()
op = Operator(
"uniform_random",
Out="X",
shape=[4, 784],
min=-5.0,
max=10.0,
seed=10,
diag_num=4,
diag_step=784,
diag_val=1.0)
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [4, 784])
hist, prob = output_hist_diag(np.array(out.get_tensor()))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpApi(unittest.TestCase):
def test_api(self):
x = fluid.layers.data('x', shape=[16], dtype='float32', lod_level=1)
y = fluid.layers.fc(x,
size=16,
param_attr=fluid.initializer.Uniform(
low=-0.5,
high=0.5,
seed=10,
diag_num=16,
diag_step=16,
diag_val=1.0))
place = fluid.CPUPlace()
x_tensor = fluid.create_lod_tensor(
np.random.rand(3, 16).astype("float32"), [[1, 2]], place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(feed={'x': x_tensor}, fetch_list=[y], return_numpy=False)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册