Commit 58730ba1 authored by yangyaming

Enhance unit test.

Parent: bf3f56e8
@@ -33,10 +33,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
                    "Output(Out) of SequenceExpandOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
+    auto out_dims = x_dims;
     int ref_level = ctx->Attrs().Get<int>("ref_level");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2U,
-                      "Dimension number of Input(X) should be 2.");
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
+                      "Dimension number of Input(X) should be at least 2.");
     if (ctx->IsRuntime()) {
       framework::Variable* x_var =
@@ -50,15 +51,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
       PADDLE_ENFORCE_LE(x_lod.size(), 1,
                         "Number of lod level of Input(X) should not be "
                         "greater than 1.");
-      PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0,
-                     "Level number of Input(X)'s lod should be either equal "
-                     "to 0 or equal to that of Input(Y).");
       PADDLE_ENFORCE_GT(y_lod.size(), 0,
                         "Level number of Input(Y)'s lod should be "
                         "greater than 0.");
       PADDLE_ENFORCE(
           ref_level == -1 ||
               (ref_level >= 0 && ref_level < static_cast<int>(y_lod.size())),
@@ -68,6 +63,14 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
       if (ref_level == -1) ref_level = y_lod.size() - 1;

+      if (x_lod.size() > 0) {
+        PADDLE_ENFORCE(
+            x_lod.size() == 0 || x_lod[0].size() == y_lod[ref_level].size(),
+            "Level number of Input(X)'s lod should be 0. Otherwise "
+            "size of Input(X)'s first level lod should be equal to "
+            "size of Input(Y)'s lod of referred level.");
+      }
+
       int64_t out_first_dim = 0;
       if (y_lod[ref_level].size() <= 1) {
         out_first_dim = x_dims[0];
@@ -81,9 +84,12 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
               (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
         }
       }
-      ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]});
+      out_dims[0] = out_first_dim;
+      ctx->SetOutputDim("Out", out_dims);
     } else {
-      ctx->SetOutputDim("Out", {-1, x_dims[1]});
+      out_dims[0] = -1;
+      ctx->SetOutputDim("Out", out_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
     }
   }
 };
@@ -105,69 +111,69 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Sequence Expand Operator.

-This operator expands input(X) according to LOD of input(Y).
+This operator expands `X` according to the specified level lod of `Y`. The
+current implementation constrains the lod level of `X` to be at most 1. The
+attribute `ref_level` specifies which level lod of `Y` is referred to when
+expanding `X`. If `ref_level` is set to -1, the last level lod of `Y` is
+referred. Please note that the rank of `X` should be at least 2; when the
+rank exceeds 2, `X` is viewed as a 2-D tensor.
+
 Following are cases to better explain how this works:
+
 Case 1:

-Given a 2-level LoDTensor input(X)
-    X.lod = [[0, 2, 3],
-             [0, 1, 3, 4]]
-    X.data = [a, b, c, d]
+Given a 1-level LoDTensor input(X)
+    X.lod = [[0, 2, 4]]
+    X.data = [[a], [b], [c], [d]]
     X.dims = [4, 1]
 and input(Y)
     Y.lod = [[0, 2, 4],
              [0, 3, 6, 7, 8]]
-with condition len(Y.lod[-1]) - 1 == X.dims[0]
-then we get 2-level LoDTensor
-    Out.lod = [[0, 2, 4],
-               [0, 3, 6, 7, 8]]
-    Out.data = [a, a, a, b, b, b, c, d]
+ref_level: 0
+then we get 1-level LoDTensor
+    Out.lod = [[0, 2, 4, 6, 8]]
+    Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
     Out.dims = [8, 1]

 Case 2:

+Given a 1-level LoDTensor input(X)
+    X.lod = [[0, 1, 4]]
+    X.data = [[a], [b], [c], [d]]
+    X.dims = [4, 1]
+and input(Y)
+    Y.lod = [[0, 2, 4],
+             [0, 3, 6, 6, 8]]
+ref_level: 0
+then we get 1-level LoDTensor
+    Out.lod = [[0, 2, 5, 8]]
+    Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]]
+    Out.dims = [8, 1]
+
+Case 3:
+
 Given a common Tensor input(X)
-    X.data = [a, b, c]
+    X.data = [[a], [b], [c]]
     X.dims = [3, 1]
 and input(Y)
     Y.lod = [[0, 2, 3, 6]]
-with condition len(Y.lod[-1]) - 1 == X.dims[0]
-then we get 1-level LoDTensor
-    Out.lod = [[0, 2, 3, 6]]
-    Out.data = [a, a, b, c, c, c]
+ref_level: -1
+then we get a common Tensor
+    Out.data = [[a], [a], [b], [c], [c], [c]]
     Out.dims = [6, 1]

-Case 3:
+Case 4:

 Given a common Tensor input(X)
     X.data = [[a, b], [c, d], [e, f]]
     X.dims = [3, 2]
 and input(Y)
     Y.lod = [[0, 2, 3, 6]]
-with condition len(Y.lod[-1]) - 1 == X.dims[0]
-then we get 1-level LoDTensor
-    Out.lod = [[0, 2, 3, 6]]
-    Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
+ref_level: 0
+then we get a common Tensor
+    Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
     Out.dims = [6, 2]

-Case 4:
-
-Given a 2-level LoDTensor input(X)
-    X.lod = [[0, 2, 3],
-             [0, 1, 3, 4]]
-    X.data = [a, b, c, d]
-    X.dims = [4, 1]
-and input(Y)
-    Y.lod = [[0, 2, 4],
-             [0, 3, 6, 6, 8]]
-with condition len(Y.lod[-1]) - 1 == X.dims[0]
-then we get 2-level LoDTensor
-    Out.lod = [[0, 2, 4],
-               [0, 3, 6, 6, 8]]
-    Out.data = [a, a, a, b, b, b, d, d]
-    Out.dims = [8, 1]
-
 )DOC");
   }
 };
......
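The expansion rule documented above is easy to check outside of Paddle. Below is a minimal NumPy sketch (an illustration, not part of this commit) that reproduces Case 1 of the DOC comment, assuming offset-style lod vectors as in the examples; the helper name `sequence_expand_ref` is hypothetical.

```python
import numpy as np

def sequence_expand_ref(x_data, x_lod, y_lod, ref_level=-1):
    # The i-th sub-sequence of X is repeated as many times as the length of
    # the i-th segment of Y's lod at ref_level.
    if ref_level == -1:
        ref_level = len(y_lod) - 1
    ref = y_lod[ref_level]
    # Without lod on X, every row counts as its own sub-sequence.
    x_idx = x_lod[0] if x_lod else list(range(x_data.shape[0] + 1))
    pieces, out_lod = [], [0]
    for i in range(1, len(ref)):
        repeat_num = ref[i] - ref[i - 1]
        x_sub = x_data[x_idx[i - 1]:x_idx[i]]
        for _ in range(repeat_num):
            pieces.append(x_sub)
            out_lod.append(out_lod[-1] + len(x_sub))
    out = (np.concatenate(pieces, axis=0) if pieces
           else np.empty((0,) + x_data.shape[1:], dtype=x_data.dtype))
    return out, [out_lod]

# Case 1: X.lod = [[0, 2, 4]], ref_level = 0; 1, 2, 3, 4 stand for a, b, c, d.
x = np.array([[1.], [2.], [3.], [4.]])
out, out_lod = sequence_expand_ref(x, [[0, 2, 4]],
                                   [[0, 2, 4], [0, 3, 6, 7, 8]], ref_level=0)
print(out.ravel())  # [1. 2. 1. 2. 3. 4. 3. 4.] ~ [a, b, a, b, c, d, c, d]
print(out_lod)      # [[0, 2, 4, 6, 8]]
```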
@@ -22,6 +22,9 @@ namespace paddle {
 namespace operators {

 using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

 template <typename DeviceContext, typename T>
 class SequenceExpandKernel : public framework::OpKernel<T> {
@@ -30,15 +33,12 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
     auto* x = context.Input<LoDTensor>("X");
     auto* y = context.Input<LoDTensor>("Y");
     auto* out = context.Output<LoDTensor>("Out");
-    int ref_level = context.Attr<int>("ref_level");

-    out->mutable_data<T>(context.GetPlace());
-
+    int ref_level = context.Attr<int>("ref_level");
     auto& x_lod = x->lod();
     auto& y_lod = y->lod();

     PADDLE_ENFORCE_GT(y_lod.size(), 0,
                       "Level number of `Y`'s lod should be greater than 0.");

     PADDLE_ENFORCE(
         ref_level == -1 || (ref_level >= 0 && ref_level < y_lod.size()),
         "Invalid `ref_level`, which should be either equal to -1 "
@@ -47,6 +47,8 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
     if (ref_level == -1) ref_level = y_lod.size() - 1;

+    out->mutable_data<T>(context.GetPlace());
+
     if (y_lod[ref_level].size() <= 1) {
       framework::TensorCopy(*x, context.GetPlace(), out);
       return;
@@ -59,6 +61,8 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
     }

     int out_offset = 0;
+    auto& eigen_place =
+        *context.template device_context<DeviceContext>().eigen_device();
     for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
       int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
       int x_start = i - 1;
@@ -68,16 +72,24 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
         x_end = x_lod[0][i];
       }
       int x_seq_len = x_end - x_start;
-      auto x_sub_tensor = x->Slice(x_start, x_end);
-      for (size_t j = 0; j < repeat_num; ++j) {
+      if (repeat_num > 0) {
+        auto x_sub_tensor = x->Slice(x_start, x_end);
+        x_sub_tensor.Resize({1, x_sub_tensor.numel()});
         int out_start = out_offset;
         if (x_lod.size() == 1) {
           out_start = out_lod[0][out_offset];
-          out_lod[0].push_back(x_seq_len);
         }
-        auto out_sub_tensor = out->Slice(out_start, out_start + x_seq_len);
-        framework::TensorCopy(x_sub_tensor, context.GetPlace(),
-                              &out_sub_tensor);
+        auto out_sub_tensor =
+            out->Slice(out_start, out_start + x_seq_len * repeat_num);
+        out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
+        EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
+            EigenMatrix<T>::From(x_sub_tensor)
+                .broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
+      }
+      for (int j = 0; j < repeat_num; ++j) {
+        if (x_lod.size() == 1) {
+          out_lod[0].push_back(out_lod[0].back() + x_seq_len);
+        }
         out_offset++;
       }
     }
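What the EigenMatrix broadcast above computes, restated in NumPy terms (an illustration, not part of the commit): the sub-tensor is flattened into a single row, tiled `repeat_num` times, and the result, viewed with the original width, is exactly the sub-sequence repeated back to back.

```python
import numpy as np

x_sub = np.array([[1., 2.], [3., 4.]])  # x_seq_len = 2, width = 2
repeat_num = 3

row = x_sub.reshape(1, -1)              # x_sub_tensor.Resize({1, numel})
tiled = np.tile(row, (repeat_num, 1))   # .broadcast({repeat_num, 1})
out_sub = tiled.reshape(repeat_num * x_sub.shape[0], x_sub.shape[1])

print(out_sub)  # the two rows of x_sub repeated 3 times, back to back
```

The design point of the change: one device-side broadcast assignment replaces `repeat_num` separate `TensorCopy` calls in the old loop.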
@@ -122,6 +134,9 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
     auto& dev_ctx = context.template device_context<DeviceContext>();

+    math::SetConstant<DeviceContext, T> set_zero;
+    set_zero(dev_ctx, g_x, static_cast<T>(0));
+
     int g_out_offset = 0;
     for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
       int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
@@ -133,12 +148,11 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
         x_end = x_lod[0][i];
       }
       int x_seq_len = x_end - x_start;
-      auto column = x_seq_len * x->dims()[1];
       auto g_x_sub = g_x->Slice(x_start, x_end);
-      g_x_sub = framework::ReshapeToMatrix(g_x_sub, column);
+      g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
       int g_out_end = g_out_offset + repeat_num * x_seq_len;
       auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
-      g_out_sub = framework::ReshapeToMatrix(g_out_sub, column);
+      g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
       math::ColwiseSum<DeviceContext, T> col_sum;
       col_sum(dev_ctx, g_out_sub, &g_x_sub);
       g_out_offset += repeat_num * x_seq_len;
......
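The gradient path above reduces each group of `repeat_num` copies back onto its source rows with a column-wise sum over a `[repeat_num, x_seq_len * width]` view. A NumPy sketch of that reduction (my own illustration, not code from this commit; the helper name is hypothetical):

```python
import numpy as np

def sequence_expand_grad_ref(g_out, x_rows, repeats):
    # Each source sub-sequence was broadcast repeat_num times in the forward
    # pass, so its gradient is the sum over the corresponding copies -- the
    # ColwiseSum in the kernel above.
    g_x = np.zeros((x_rows,) + g_out.shape[1:], dtype=g_out.dtype)
    offset = 0
    row = 0
    for x_len, repeat_num in repeats:  # (sub-sequence length, repeat count)
        if repeat_num > 0:
            block = g_out[offset:offset + repeat_num * x_len]
            g_x[row:row + x_len] = (block
                                    .reshape(repeat_num, x_len, -1)
                                    .sum(axis=0)
                                    .reshape((x_len,) + g_out.shape[1:]))
        offset += repeat_num * x_len
        row += x_len
    return g_x

# Forward (Case 1 of the DOC comment) repeated two length-2 sub-sequences
# twice each, so a gradient of all ones on Out sums to 2 everywhere on X.
g_out = np.ones((8, 1))
print(sequence_expand_grad_ref(g_out, 4, [(2, 2), (2, 2)]).ravel())
# -> [2. 2. 2. 2.]
```

The new `set_zero` call matters here: sub-sequences of `X` whose `repeat_num` is 0 never appear in `Out`, so their gradient rows must be explicitly zeroed rather than left uninitialized.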
@@ -1781,52 +1781,52 @@ def conv2d_transpose(input,
     return out


-def sequence_expand(x, y, name=None):
+def sequence_expand(x, y, ref_level=-1, name=None):
     """Sequence Expand Layer. This layer will expand the input variable **x**
-    according to LoD information of **y**. And the following examples will
-    explain how sequence_expand works:
+    according to the specified level lod of **y**. Please note that the lod
+    level of **x** is at most 1 and the rank of **x** is at least 2. When the
+    rank of **x** is greater than 2, it is viewed as a 2-D tensor.
+    The following examples explain how sequence_expand works:

     .. code-block:: text

         * Case 1
             x is a LoDTensor:
-                x.lod = [[0, 2, 3],
-                         [0, 1, 3, 4]]
-                x.data = [a, b, c, d]
+                x.lod = [[0, 2, 4]]
+                x.data = [[a], [b], [c], [d]]
                 x.dims = [4, 1]

             y is a LoDTensor:
                 y.lod = [[0, 2, 4],
                          [0, 3, 6, 7, 8]]

-            with condition len(y.lod[-1]) - 1 == x.dims[0]
+            ref_level: 0

-            then output is a 2-level LoDTensor:
-                out.lod = [[0, 2, 4],
-                           [0, 3, 6, 7, 8]]
-                out.data = [a, a, a, b, b, b, c, d]
+            then output is a 1-level LoDTensor:
+                out.lod = [[0, 2, 4, 6, 8]]
+                out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
                 out.dims = [8, 1]

         * Case 2
             x is a Tensor:
-                x.data = [a, b, c]
+                x.data = [[a], [b], [c]]
                 x.dims = [3, 1]

             y is a LoDTensor:
-                y.lod = [[0, 2, 3, 6]]
+                y.lod = [[0, 2, 2, 5]]

-            with condition len(y.lod[-1]) - 1 == x.dims[0]
+            ref_level: -1

-            then output is a 1-level LoDTensor:
-                out.lod = [[0, 2, 3, 6]]
-                out.data = [a, a, b, c, c, c]
-                out.dims = [6, 1]
+            then output is a Tensor:
+                out.data = [[a], [a], [c], [c], [c]]
+                out.dims = [5, 1]

     Args:
         x (Variable): The input variable which is a Tensor or LoDTensor.
         y (Variable): The input variable which is a LoDTensor.
+        ref_level (int): Lod level of `y` to be referred by `x`. If set to -1,
+                         refer to the last level of lod.
         name(str|None): A name for this layer(optional). If set None, the layer
                         will be named automatically.

     Returns:
         Variable: The expanded variable which is a LoDTensor.
@@ -1837,14 +1837,17 @@ def sequence_expand(x, y, name=None):
            x = fluid.layers.data(name='x', shape=[10], dtype='float32')
            y = fluid.layers.data(name='y', shape=[10, 20],
                                  dtype='float32', lod_level=1)
-           out = layers.sequence_expand(x=x, y=y)
+           out = layers.sequence_expand(x=x, y=y, ref_level=0)
     """
     helper = LayerHelper('sequence_expand', input=x, **locals())
     dtype = helper.input_dtype()
     tmp = helper.create_tmp_variable(dtype)
     helper.append_op(
-        type='sequence_expand', inputs={'X': x,
-                                        'Y': y}, outputs={'Out': tmp})
+        type='sequence_expand',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': tmp},
+        attrs={'ref_level': ref_level})
     return tmp
......
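One subtlety in the updated docstring's Case 2: a zero-length segment in y's lod (the adjacent offsets `2, 2`) drops the corresponding row of x entirely. A standalone NumPy check of that case (an illustration, not fluid API):

```python
import numpy as np

x = np.array([[1.], [2.], [3.]])  # stands for [[a], [b], [c]]
y_lod_ref = [0, 2, 2, 5]          # last-level lod of y (ref_level = -1)

out = np.concatenate([
    np.repeat(x[i:i + 1], y_lod_ref[i + 1] - y_lod_ref[i], axis=0)
    for i in range(len(y_lod_ref) - 1)
])
print(out.ravel())  # [1. 1. 3. 3. 3.] ~ [a, a, c, c, c]; row b is dropped
print(out.shape)    # (5, 1), matching out.dims = [5, 1]
```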
@@ -181,8 +181,8 @@ class TestBook(unittest.TestCase):
         with program_guard(program):
             x = layers.data(name='x', shape=[10], dtype='float32')
             y = layers.data(
-                name='y', shape=[10, 20], dtype='float32', lod_level=1)
-            self.assertIsNotNone(layers.sequence_expand(x=x, y=y))
+                name='y', shape=[10, 20], dtype='float32', lod_level=2)
+            self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1))
         print(str(program))

     def test_lstm_unit(self):
......
@@ -27,12 +27,36 @@ class TestSequenceExpand(OpTest):
     def compute(self):
         x = self.inputs['X']
         x_data, x_lod = x if type(x) == tuple else (x, None)
-        n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0])
         y_data, y_lod = self.inputs['Y']
-        repeats = [((y_lod[-1][i + 1] - y_lod[-1][i]))
-                   for i in range(len(y_lod[-1]) - 1)]
-        out = x_data.repeat(repeats, axis=0)
-        self.outputs = {'Out': out}
+
+        if hasattr(self, 'attrs'):
+            ref_level = self.attrs['ref_level']
+        else:
+            ref_level = len(y_lod) - 1
+
+        out = np.zeros(shape=((0, ) + x_data.shape[1:]), dtype=x_data.dtype)
+
+        if x_lod is None:
+            x_idx = [i for i in xrange(x_data.shape[0] + 1)]
+        else:
+            x_idx = x_lod[0]
+
+        out_lod = [[0]]
+        for i in xrange(1, len(y_lod[ref_level])):
+            repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]
+            x_len = x_idx[i] - x_idx[i - 1]
+            if repeat_num > 0:
+                x_sub = x_data[x_idx[i - 1]:x_idx[i], :]
+                x_sub = np.repeat(x_sub, repeat_num, axis=0)
+                out = np.vstack((out, x_sub))
+                if x_lod is not None:
+                    for j in xrange(repeat_num):
+                        out_lod[0].append(out_lod[0][-1] + x_len)
+
+        if x_lod is None:
+            self.outputs = {'Out': out}
+        else:
+            self.outputs = {'Out': (out, out_lod)}

     def setUp(self):
         self.op_type = 'sequence_expand'
@@ -52,7 +76,8 @@ class TestSequenceExpandCase1(TestSequenceExpand):
         x_lod = [[0, 2, 5]]
         y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
         y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]]
-        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
+        self.attrs = {'ref_level': 0}


 class TestSequenceExpandCase2(TestSequenceExpand):
@@ -60,8 +85,9 @@ class TestSequenceExpandCase2(TestSequenceExpand):
         x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
         x_lod = [[0, 1]]
         y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
-        y_lod = [[0, 2]]
+        y_lod = [[0, 2], [0, 2]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+        self.attrs = {'ref_level': 0}


 class TestSequenceExpandCase3(TestSequenceExpand):
@@ -75,14 +101,9 @@ class TestSequenceExpandCase3(TestSequenceExpand):

 class TestSequenceExpandCase4(TestSequenceExpand):
     def set_data(self):
-        x_data = np.array(
-            [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape(
-                [2, 5]).astype('float32')
-        x_lod = [[
-            0,
-            1,
-            2,
-        ]]
+        data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]
+        x_data = np.array(data).reshape([5, 2]).astype('float32')
+        x_lod = [[0, 2, 5]]
         y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
         y_lod = [[0, 1, 2], [0, 1, 2]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
......
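The lod bookkeeping that the rewritten `compute` produces can be stated compactly: each of the `repeat_num` copies of sub-sequence `i` appends one new offset, advanced by that sub-sequence's length. A small restatement (an illustration, not part of the test file), applied to the `TestSequenceExpandCase4` data:

```python
def expected_out_lod(x_lod0, y_lod_ref):
    # Mirrors out_lod[0].append(out_lod[0][-1] + x_len) in the reference
    # compute(): one new offset per copy of each sub-sequence.
    out_lod = [0]
    for i in range(1, len(y_lod_ref)):
        repeat_num = y_lod_ref[i] - y_lod_ref[i - 1]
        x_len = x_lod0[i] - x_lod0[i - 1]
        for _ in range(repeat_num):
            out_lod.append(out_lod[-1] + x_len)
    return [out_lod]

# TestSequenceExpandCase4: x_lod = [[0, 2, 5]]; ref_level defaults to the last
# level of y_lod, i.e. [0, 1, 2], so each sub-sequence is copied once.
print(expected_out_lod([0, 2, 5], [0, 1, 2]))  # [[0, 2, 5]]
```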