未验证 提交 fcf53e55 编写于 作者: A Aurelius84 提交者: GitHub

support 2-level lod of input in sequence_pool (#19839)

* support 2-level lod of input in sequence_pool test=develop

* fix lod level bug in .cu test=develop
上级 b25d1e75
......@@ -37,18 +37,23 @@ class MaxSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output, framework::Tensor* index) {
framework::LoDTensor* output, framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1);
PADDLE_ENFORCE_GT(out_dims.size(), 1);
PADDLE_ENFORCE_GT(in_dims.size(), 1,
"The rank of input shall be greater than 1.");
PADDLE_ENFORCE_GT(out_dims.size(), 1,
"The rank of output shall be greater than 1.");
for (int64_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i],
"The dimension of input and output shall be same.");
}
PADDLE_ENFORCE_EQ(idx_dims, out_dims);
PADDLE_ENFORCE_EQ(idx_dims, out_dims,
"The dimension of index and output shall be same.");
auto starts = input.lod()[0];
auto lod_level = input.lod().size();
auto starts = input.lod()[lod_level - 1];
const T* in_data = input.data<T>();
T* out_data = output->data<T>();
int* max_index = index->data<int>();
......@@ -85,16 +90,20 @@ class MaxSeqPoolFunctor<T, true> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output, framework::Tensor* index) {
framework::LoDTensor* output, framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1);
PADDLE_ENFORCE_GT(out_dims.size(), 1);
PADDLE_ENFORCE_GT(in_dims.size(), 1,
"The rank of input shall be greater than 1.");
PADDLE_ENFORCE_GT(out_dims.size(), 1,
"The rank of output shall be greater than 1.");
for (int64_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i],
"The dimension of input and output shall be same.");
}
auto starts = input.lod()[0];
auto lod_level = input.lod().size();
auto starts = input.lod()[lod_level - 1];
const T* in_data = input.data<T>();
T* out_data = output->data<T>();
......@@ -123,18 +132,23 @@ template <typename T>
class MaxSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
const framework::LoDTensor& out_grad,
const framework::Tensor& index,
framework::LoDTensor* in_grad) {
auto og_dims = out_grad.dims();
auto ig_dims = in_grad->dims();
auto idx_dims = index.dims();
PADDLE_ENFORCE_GT(og_dims.size(), 1);
PADDLE_ENFORCE_GT(ig_dims.size(), 1);
PADDLE_ENFORCE_GT(og_dims.size(), 1,
"The rank of output@Grad shall be greater than 1.");
PADDLE_ENFORCE_GT(ig_dims.size(), 1,
"The rank of input@Grad shall be greater than 1.");
for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
PADDLE_ENFORCE_EQ(
og_dims[i], ig_dims[i],
"The dimension of input@Grad and output@Grad shall be same.");
}
PADDLE_ENFORCE_EQ(idx_dims, og_dims);
PADDLE_ENFORCE_EQ(idx_dims, og_dims,
"The dimension of index and output@Grad shall be same.");
const T* og_data = out_grad.data<T>();
const int* max_index = index.data<int>();
......@@ -159,14 +173,15 @@ class LastSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output) {
framework::LoDTensor* output) {
// Create pointers to input and output data
auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
// Calculate the size of each item in sequence
int64_t item_size = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
auto lod_level = input.lod().size();
auto lod = input.lod()[lod_level - 1];
int seq_num = static_cast<int>(lod.size()) - 1;
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
......@@ -191,14 +206,15 @@ class FirstSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, T pad_value,
framework::Tensor* output) {
framework::LoDTensor* output) {
// Create pointers to input and output data
auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
// Calculate the size of each item in sequence
int64_t item_size = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
auto lod_level = input.lod().size();
auto lod = input.lod()[lod_level - 1];
int seq_num = static_cast<int>(lod.size()) - 1;
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
......@@ -222,12 +238,15 @@ template <typename T>
class SumSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
const framework::LoDTensor& out_grad,
framework::LoDTensor* in_grad) {
auto lod = in_grad->lod()[0];
auto lod_level = in_grad->lod().size();
auto lod = in_grad->lod()[lod_level - 1];
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE(in_w == out_w);
PADDLE_ENFORCE_EQ(
in_w, out_w,
"The feature size of input@Grad and output@Grad shall be same.");
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
......@@ -250,8 +269,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
/* max pool has index output */
void operator()(const platform::CPUDeviceContext& context,
const std::string pooltype, T pad_value,
const framework::LoDTensor& input, framework::Tensor* output,
bool is_test, framework::Tensor* index = nullptr) {
const framework::LoDTensor& input,
framework::LoDTensor* output, bool is_test,
framework::Tensor* index = nullptr) {
if (pooltype == "MAX") {
if (is_test) {
math::MaxSeqPoolFunctor<T, true> max_pool;
......@@ -272,11 +292,13 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
first_pool(context, input, pad_value, output);
return;
}
auto lod = input.lod()[0];
auto lod_level = input.lod().size();
auto lod = input.lod()[lod_level - 1];
if (pooltype == "SUM") {
auto place = context.GetPlace();
PADDLE_ENFORCE(platform::is_cpu_place(place));
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(place), true,
"Sequence_pool should run on CPU Device when pooltype is SUM");
const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr(
......@@ -330,7 +352,8 @@ template <typename T>
class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::string pooltype, const framework::Tensor& out_grad,
const std::string pooltype,
const framework::LoDTensor& out_grad,
framework::LoDTensor* in_grad,
/* max pool has index */
const framework::Tensor* index = nullptr) {
......@@ -352,7 +375,8 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
return;
}
auto lod = in_grad->lod()[0];
auto lod_level = in_grad->lod().size();
auto lod = in_grad->lod()[lod_level - 1];
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
if (lod[i] == lod[i + 1]) continue;
......
......@@ -159,9 +159,11 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::string pooltype, T pad_value,
const framework::LoDTensor& input, framework::Tensor* output,
bool is_test, framework::Tensor* index = nullptr) {
auto& lod = input.lod()[0];
const framework::LoDTensor& input,
framework::LoDTensor* output, bool is_test,
framework::Tensor* index = nullptr) {
auto lod_level = input.lod().size();
auto& lod = input.lod()[lod_level - 1];
const size_t item_dim = output->numel() / output->dims()[0];
dim3 threads(1024, 1);
dim3 grid(lod.size(), 1);
......@@ -319,11 +321,13 @@ template <typename T>
class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::string pooltype, const framework::Tensor& out_grad,
const std::string pooltype,
const framework::LoDTensor& out_grad,
framework::LoDTensor* in_grad,
/* max pool has index */
const framework::Tensor* index = nullptr) {
auto& lod = in_grad->lod()[0];
auto lod_level = in_grad->lod().size();
auto& lod = in_grad->lod()[lod_level - 1];
const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
dim3 threads(1024, 1);
dim3 grid(lod.size(), 1);
......
......@@ -28,7 +28,7 @@ class SequencePoolFunctor {
/* max pool has index output */
void operator()(const DeviceContext& context, const std::string pooltype,
T pad_value, const framework::LoDTensor& input,
framework::Tensor* output, bool is_test = false,
framework::LoDTensor* output, bool is_test = false,
framework::Tensor* index = nullptr);
};
......@@ -36,7 +36,7 @@ template <typename DeviceContext, typename T>
class SequencePoolGradFunctor {
public:
void operator()(const DeviceContext& context, const std::string pooltype,
const framework::Tensor& out_grad,
const framework::LoDTensor& out_grad,
framework::LoDTensor* in_grad,
/* max pool has index */
const framework::Tensor* index = nullptr);
......
......@@ -24,14 +24,15 @@ class SequencePoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of SequencePoolOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
PADDLE_ENFORCE(ctx->HasOutput("MaxIndex"),
"Output(MaxIndex) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MaxIndex"), true,
"Output(MaxIndex) of SequencePoolOp should not be null.");
ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
}
}
......@@ -102,9 +103,10 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Gradient of Out should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"The input X should not be null.");
auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
......
......@@ -30,19 +30,30 @@ class SequencePoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<Tensor>("Out");
auto* out = context.Output<LoDTensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
T pad_value = static_cast<T>(context.Attr<float>("pad_value"));
auto dims = in->dims();
auto lod = in->lod();
auto lod_level = lod.size();
// InferShape by lod
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_GE(lod_level, 1UL,
"The lod level of input shall be 1 at least.");
PADDLE_ENFORCE_LE(lod_level, 2UL,
"The lod level of input shall be no more than 2.");
PADDLE_ENFORCE_GE(
dims[0],
/*batch size = */ static_cast<int64_t>(lod[0].size() - 1),
/*batch size = */ static_cast<int64_t>(lod[lod_level - 1].size() - 1),
"The first dimension of Input(X) must be large than batch size.");
dims[0] = lod[0].size() - 1;
if (lod_level > 1UL) {
PADDLE_ENFORCE_EQ(lod[0][lod[0].size() - 1], lod[1].size() - 1,
"The input lod information is illegal.");
framework::LoD out_lod;
out_lod.push_back(lod[0]);
out->set_lod(out_lod);
}
dims[0] = lod[lod_level - 1].size() - 1;
out->Resize({dims});
out->mutable_data<T>(context.GetPlace());
Tensor* index = nullptr;
......@@ -68,7 +79,7 @@ template <typename DeviceContext, typename T>
class SequencePoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
std::string pooltype = context.Attr<std::string>("pooltype");
const Tensor* index = nullptr;
......
......@@ -21,30 +21,33 @@ from test_reorder_lod_tensor import convert_to_offset
def compute_seqpool_sum(x, offset, out, pad_value=0.0):
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = pad_value
else:
sub_x = x[offset[0][i]:offset[0][i + 1], :]
sub_x = x[offset[level][i]:offset[level][i + 1], :]
out[i] = sub_x.sum(axis=0)
def compute_seqpool_avg(x, offset, out, pad_value=0.0):
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = pad_value
else:
sub_x = x[offset[0][i]:offset[0][i + 1], :]
sub_x = x[offset[level][i]:offset[level][i + 1], :]
out[i] = sub_x.mean(axis=0)
def compute_seqpool_sqrt(x, offset, out, pad_value=0.0):
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = pad_value
else:
sub_x = x[offset[0][i]:offset[0][i + 1], :]
seq_len = offset[0][i + 1] - offset[0][i]
sub_x = x[offset[level][i]:offset[level][i + 1], :]
seq_len = offset[level][i + 1] - offset[level][i]
out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len)
......@@ -56,9 +59,10 @@ class TestSeqAvgPool(OpTest):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
lod = self.set_lod()
level = len(lod) - 1
self.inputs = {'X': (x, lod)}
offset = convert_to_offset(lod)
out = np.zeros((len(lod[0]), 23)).astype('float32')
out = np.zeros((len(lod[level]), 23)).astype('float32')
self.outputs = {'Out': out}
return x, offset, out
......@@ -69,14 +73,18 @@ class TestSeqAvgPool(OpTest):
def setUp(self):
x, offset, out = self.set_data()
self.compute(x, offset, out)
if len(offset) > 1:
self.outputs = {'Out': (out, [self.set_lod()[0]])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
# Remove MaxIndex after check_grad is refined.
out = self.outputs['Out']
if isinstance(out, tuple): out = out[0]
self.outputs['MaxIndex'] = \
np.zeros(self.outputs['Out'].shape).astype('int32')
np.zeros(out.shape).astype('int32')
self.check_grad(["X"], "Out")
......@@ -85,6 +93,11 @@ class TestSeqAvgPoolLen0(TestSeqAvgPool):
return [[0, 4, 0, 7, 0]]
class TestSeqAvgPoolLen0LoDLevel2(TestSeqAvgPool):
def set_lod(self):
return [[2, 0, 1, 2], [0, 4, 0, 7, 0]]
class TestSeqSumPool(TestSeqAvgPool):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.1, 'pooltype': "SUM"}
......@@ -96,6 +109,11 @@ class TestSeqSumPoolLen0(TestSeqSumPool):
return [[0, 4, 0, 7, 0]]
class TestSeqSumPoolLen0LoDLevel2(TestSeqSumPool):
def set_lod(self):
return [[2, 0, 1, 2], [0, 4, 0, 7, 0]]
class TestSeqMaxPool(TestSeqAvgPool):
def set_lod(self):
return [[13]]
......@@ -104,25 +122,27 @@ class TestSeqMaxPool(TestSeqAvgPool):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
lod = self.set_lod()
level = len(lod) - 1
offset = convert_to_offset(lod)
for i in range(len(offset[0]) - 1):
l = offset[0][i + 1] - offset[0][i]
for i in range(len(offset[level]) - 1):
l = offset[level][i + 1] - offset[level][i]
if l > 0:
x[offset[0][i] + np.random.randint(l), :] += 2.0
x[offset[level][i] + np.random.randint(l), :] += 2.0
self.inputs = {'X': (x, lod)}
out = np.zeros((len(lod[0]), 23)).astype('float32')
out = np.zeros((len(lod[level]), 23)).astype('float32')
self.outputs = {'Out': out}
return x, offset, out
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.5, 'pooltype': "MAX"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"]
else:
sub_x = x[offset[0][i]:offset[0][i + 1], :]
sub_x = x[offset[level][i]:offset[level][i + 1], :]
out[i] = np.amax(sub_x, axis=0)
......@@ -131,6 +151,11 @@ class TestSeqMaxPoolLen0(TestSeqMaxPool):
return [[0, 1, 1, 5, 6, 0]]
class TestSeqMaxPoolLen0LoDLevel2(TestSeqMaxPool):
def set_lod(self):
return [[2, 0, 3, 1], [0, 1, 1, 5, 6, 0]]
class TestSeqSqrtPool(TestSeqAvgPool):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"}
......@@ -142,14 +167,20 @@ class TestSeqSqrtPoolLen0(TestSeqSqrtPool):
return [[0, 7, 0, 2, 2, 0]]
class TestSeqSqrtPoolLen0LoDLevel2(TestSeqSqrtPool):
def set_lod(self):
return [[1, 2, 0, 3], [0, 7, 0, 2, 2, 0]]
class TestSeqLastPool(TestSeqAvgPool):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"]
else:
sub_x = x[offset[0][i]:offset[0][i + 1], :]
sub_x = x[offset[level][i]:offset[level][i + 1], :]
out[i] = sub_x[-1, :]
......@@ -158,14 +189,20 @@ class TestSeqLastPoolLen0(TestSeqLastPool):
return [[0, 3, 4, 0, 4, 0]]
class TestSeqLastPoolLen0LoDLevel2(TestSeqLastPool):
def set_lod(self):
return [[1, 0, 2, 3], [0, 3, 4, 0, 4, 0]]
class TestSeqFirstPool(TestSeqAvgPool):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.3, 'pooltype': "FIRST"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"]
else:
sub_x = x[offset[0][i]:offset[0][i + 1], :]
sub_x = x[offset[level][i]:offset[level][i + 1], :]
out[i] = sub_x[0, :]
......@@ -174,6 +211,11 @@ class TestSeqFirstPoolLen0(TestSeqFirstPool):
return [[0, 2, 0, 3, 6, 0]]
class TestSeqFirstPoolLen0LoDLevel2(TestSeqFirstPool):
def set_lod(self):
return [[1, 0, 2, 3], [0, 2, 0, 3, 6, 0]]
class TestSeqAvgPool2D(TestSeqAvgPool):
def set_lod(self):
return [[4, 1, 3, 5]]
......@@ -182,20 +224,22 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
lod = self.set_lod()
level = len(lod) - 1
self.inputs = {'X': (x, lod)}
offset = convert_to_offset(lod)
out = np.zeros((len(lod[0]), 3, 17)).astype('float32')
out = np.zeros((len(lod[level]), 3, 17)).astype('float32')
self.outputs = {'Out': out}
return x, offset, out
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "AVERAGE"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 17))
else:
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 17))
out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
......@@ -205,14 +249,20 @@ class TestSeqAvgPool2DLen0(TestSeqAvgPool2D):
return [[0, 5, 0, 8, 0]]
class TestSeqAvgPool2DLen0LoDLevel2(TestSeqAvgPool2D):
def set_lod(self):
return [[1, 0, 4], [0, 5, 0, 8, 0]]
class TestSeqSumPool2D(TestSeqAvgPool2D):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.2, 'pooltype': "SUM"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 17))
else:
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 17))
out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
......@@ -222,23 +272,32 @@ class TestSeqSumPool2DLen0(TestSeqSumPool2D):
return [[0, 8, 0, 5, 0]]
class TestSeqSumPool2DLen0LoDLevel2(TestSeqSumPool2D):
def set_lod(self):
return [[1, 0, 4], [0, 8, 0, 5, 0]]
class TestSeqSqrtPool2D(TestSeqAvgPool2D):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 17))
else:
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 17))
seq_len = offset[0][i + 1] - offset[0][i]
seq_len = offset[level][i + 1] - offset[level][i]
out[i] = np.reshape(
sub_x.sum(axis=0) / np.sqrt(seq_len), (3, 17))
def test_check_grad(self):
# Remove MaxIndex after check_grad is refined.
out = self.outputs['Out']
if isinstance(out, tuple):
out = out[0]
self.outputs['MaxIndex'] = \
np.zeros(self.outputs['Out'].shape).astype('int32')
np.zeros(out.shape).astype('int32')
self.check_grad(["X"], "Out", max_relative_error=0.06)
......@@ -247,6 +306,11 @@ class TestSeqSqrtPool2DLen0(TestSeqSqrtPool2D):
return [[0, 8, 0, 5, 0]]
class TestSeqSqrtPool2DLen0LoDLevel2(TestSeqSqrtPool2D):
def set_lod(self):
return [[1, 0, 2, 2], [0, 8, 0, 5, 0]]
class TestSeqMaxPool2D(TestSeqAvgPool2D):
def set_lod(self):
return [[4, 1, 3, 5]]
......@@ -255,25 +319,27 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32')
self.lod = self.set_lod()
level = len(self.lod) - 1
self.inputs = {'X': (x, self.lod)}
offset = convert_to_offset(self.lod)
for i in range(len(offset[0]) - 1):
l = offset[0][i + 1] - offset[0][i]
for i in range(len(offset[level]) - 1):
l = offset[level][i + 1] - offset[level][i]
if l == 0:
continue
x[offset[0][i] + np.random.randint(l), :] += 1.0
x[offset[level][i] + np.random.randint(l), :] += 1.0
out = np.zeros((len(self.lod[0]), 3, 11)).astype('float32')
out = np.zeros((len(self.lod[level]), 3, 11)).astype('float32')
self.outputs = {'Out': out}
return x, offset, out
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "MAX"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 11))
continue
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 11))
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
......@@ -283,14 +349,20 @@ class TestSeqMaxPool2DLen0(TestSeqMaxPool2D):
return [[0, 3, 0, 10, 0]]
class TestSeqMaxPool2DLen0LoDLevel2(TestSeqMaxPool2D):
def set_lod(self):
return [[1, 0, 2, 2], [0, 3, 0, 10, 0]]
class TestSeqMaxPool2DInference(TestSeqMaxPool2D):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 1.0, 'pooltype': "MAX", 'is_test': True}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 11))
else:
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 11))
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
......@@ -305,14 +377,20 @@ class TestSeqMaxPool2DInferenceLen0(TestSeqMaxPool2DInference):
return [[0, 3, 0, 10, 0]]
class TestSeqMaxPool2DInferenceLen0LoDLevel2(TestSeqMaxPool2DInference):
def set_lod(self):
return [[1, 0, 2, 2], [0, 3, 0, 10, 0]]
class TestSeqLastPool2D(TestSeqAvgPool2D):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 17))
else:
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 17))
out[i] = np.reshape(sub_x[-1, :], (3, 17))
......@@ -322,14 +400,20 @@ class TestSeqLastPool2DLen0(TestSeqLastPool2D):
return [[0, 3, 0, 1, 9, 0]]
class TestSeqLastPool2DLen0LoDLevel2(TestSeqLastPool2D):
def set_lod(self):
return [[1, 0, 2, 3], [0, 3, 0, 1, 9, 0]]
class TestSeqFirstPool2D(TestSeqAvgPool2D):
def compute(self, x, offset, out):
self.attrs = {"pad_value": 0.0, 'pooltype': "FIRST"}
for i in range(len(offset[0]) - 1):
if offset[0][i] == offset[0][i + 1]:
level = len(offset) - 1
for i in range(len(offset[level]) - 1):
if offset[level][i] == offset[level][i + 1]:
out[i] = self.attrs["pad_value"] * np.ones((3, 17))
else:
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
(-1, 3 * 17))
out[i] = np.reshape(sub_x[0, :], (3, 17))
......@@ -339,5 +423,10 @@ class TestSeqFirstPool2DLen0(TestSeqFirstPool2D):
return [[0, 3, 0, 3, 7, 0]]
class TestSeqFirstPool2DLen0LoDLevel2(TestSeqFirstPool2D):
def set_lod(self):
return [[1, 0, 2, 3], [0, 3, 0, 3, 7, 0]]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册