From 339e655aeccba6bb109b3ec854e3a57296f558b5 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 19 Oct 2018 16:03:06 +0800
Subject: [PATCH] refine and add seqconv elementwiseadd relu op test

---
 .../fusion_seqconv_eltadd_relu_op.cc          | 40 ++++----
 .../test_fusion_seqconv_eltadd_relu_op.py     | 94 ++++++++++++++++++
 .../fluid/tests/unittests/test_seq_conv.py    | 99 +++++++++----------
 3 files changed, 164 insertions(+), 69 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py

diff --git a/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc
index efeb18e161a..b0910dc19ed 100644
--- a/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc
+++ b/paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc
@@ -40,6 +40,7 @@ void FusionSeqConvEltAddReluOp::InferShape(
   auto x_dims = ctx->GetInputDim("X");
   auto w_dims = ctx->GetInputDim("Filter");
+  int context_length = ctx->Attrs().Get<int>("contextLength");
   PADDLE_ENFORCE(
       ctx->Attrs().Get<int>("contextStride") == 1,
       "Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
@@ -47,10 +48,11 @@
                  "Input(X, Filter) should be 2-D tensor.");
   PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
                  "Input(X, Filter) should be 2-D tensor.");
-  PADDLE_ENFORCE(
-      w_dims[0] == ctx->Attrs().Get<int>("contextLength") * x_dims[1],
-      "Filter's height should be context_length * "
-      "input_hidden_size .");
+  PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
+                 "Filter's height should be context_length * "
+                 "input_hidden_size .");
+  PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
+                    "contextStart size should be smaller than contextLength.");
 
   ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
   ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
@@ -156,9 +158,8 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
       T* dst_data = col_data + st * col_mat_w;
       int seq_len = ed - st;
       if (seq_len > up_pad + down_pad) {
-        // zero all up_pad
+        // zero all up_pad and fill data
         std::memset(dst_data, 0, up_pad * col_mat_w_sz);
-        // fill up_pad data
         dst_data = dst_data + up_pad * src_mat_w;
         int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz;
         for (int j = 0; j < up_pad; ++j) {
@@ -173,9 +174,8 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
         dst_data += col_mat_w;
         src_data += src_mat_w;
       }
-      // zero all down_pad
+      // zero all down_pad and fill data
       std::memset(dst_data, 0, down_pad * col_mat_w_sz);
-      // fill down_pad data
       copy_size -= src_mat_w_sz;
       for (int j = 0; j < down_pad; ++j) {
         std::memcpy(dst_data, src_data, copy_size);
@@ -186,27 +186,29 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
       } else {
         PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
         std::memset(dst_data, 0, seq_len * col_mat_w_sz);
+        dst_data = dst_data + up_pad * src_mat_w;
         int zero_sz = up_pad * src_mat_w_sz;
-        int seq_len_size = seq_len * src_mat_w_sz;
+        int cur_src_sz = seq_len * src_mat_w_sz;
         for (int j = 0; j < std::min(up_pad, seq_len); ++j) {
-          int copy_size = std::min(seq_len_size, col_mat_w_sz - zero_sz);
-          std::memcpy(dst_data + zero_sz / sizeof(T), src_data, copy_size);
-          dst_data += col_mat_w;
+          int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
+          std::memcpy(dst_data, src_data, copy_size);
+          dst_data += (col_mat_w - src_mat_w);
           zero_sz -= src_mat_w_sz;
         }
+        // from bottom
+        dst_data = col_data + ed * col_mat_w;
+        src_data = x_data + st * src_mat_w;
         zero_sz = down_pad * src_mat_w_sz;
-        dst_data = col_data + (ed - 1) * col_mat_w;
-        src_data = x_data + (ed - up_pad - 1) * src_mat_w;
-        for (int j = 0; j < std::min(0, seq_len - up_pad); ++j) {
-          int copy_size = std::min(seq_len_size, col_mat_w_sz - zero_sz);
-          std::memcpy(dst_data, src_data, copy_size);
+        for (int j = 1; j <= std::min(down_pad, seq_len); ++j) {
+          int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
+          std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T),
+                      src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w,
+                      copy_size);
           dst_data -= col_mat_w;
-          src_data += src_mat_w;
           zero_sz -= src_mat_w_sz;
         }
       }
     }
-
     auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
     math::FCCompute<platform::CPUDeviceContext, T>(blas, x_dims[0], w_dims[1],
                                                    w_dims[0],
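
Note on the new PADDLE_ENFORCE_GT: following sequence_conv's convention, the kernel zero-pads up_pad = max(0, -contextStart) rows above and down_pad = max(0, contextStart + contextLength - 1) rows below each sequence, so a window with contextStart + contextLength <= 0 would never touch a real input row. A minimal sketch of that arithmetic (the helper name padding_sizes is illustrative, not part of the patch):

def padding_sizes(context_start, context_length):
    # Zero rows above/below each sequence, per sequence_conv's convention
    # (an assumption of this sketch, not code copied from the op).
    up_pad = max(0, -context_start)
    down_pad = max(0, context_start + context_length - 1)
    # Mirrors the new PADDLE_ENFORCE_GT in InferShape: the context window
    # must reach at least one real input row.
    assert context_length + context_start > 0
    return up_pad, down_pad

print(padding_sizes(-2, 3))  # (2, 0)
print(padding_sizes(-1, 4))  # (1, 2)
print(padding_sizes(0, 4))   # (0, 3)
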
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py b/python/paddle/fluid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py
new file mode 100644
index 00000000000..ba6f1415b1c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import random
+from op_test import OpTest
+from test_seq_conv import seqconv
+
+
+class TestSeqConvEltAddRelu(OpTest):
+    def set_conf(self):
+        pass
+
+    def setUp(self):
+        self.op_type = 'fusion_seqconv_eltadd_relu'
+        self.lod = [[6, 4]]
+        self.in_fea_size = 16
+        self.out_fea_size = 8
+        self.context_length = 4
+        self.context_stride = 1
+        self.context_start = 0
+        self.set_conf()
+
+        assert self.context_stride == 1
+
+        T = sum(self.lod[0])
+        x = np.random.uniform(-1, 1, [T, self.in_fea_size]).astype('float32')
+        w = np.random.uniform(
+            -1, 1, [self.in_fea_size * self.context_length,
+                    self.out_fea_size]).astype('float32')
+        b = np.random.uniform(-2, 1, [1, self.out_fea_size]).astype('float32')
+        out = seqconv(x, self.lod, w, self.context_length, self.context_start)
+        out = np.maximum(out + b, 0)
+
+        self.inputs = {'X': (x, self.lod), 'Filter': w, 'Bias': b}
+        self.attrs = {
+            'contextStart': self.context_start,
+            'contextLength': self.context_length,
+            'contextStride': self.context_stride
+        }
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[10]]
+
+
+class TestSeqConvEltAddReluBS1Case2(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[2]]
+
+
+class TestSeqConvEltAddReluCase1(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[3, 5, 1, 6]]
+        self.context_length = 3
+        self.context_start = -2
+
+
+class TestSeqConvEltAddReluCase2(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[10, 1, 2, 4, 1, 5, 6]]
+        self.in_fea_size = 2
+        self.context_length = 4
+        self.context_start = -1
+
+
+class TestSeqConvEltAddReluCase3(TestSeqConvEltAddRelu):
+    def set_conf(self):
+        self.lod = [[10, 1, 2, 4, 1, 5, 6]]
+        self.context_length = 5
+        self.context_start = -4
+
+
+if __name__ == '__main__':
+    unittest.main()
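
The expected output in this new test is just ReLU(seqconv(X, Filter) + Bias). For readers without the Paddle test harness, a self-contained single-sequence sketch of the same math (a naive gather instead of the kernel's memcpy scheme; it ignores LoD boundaries, which the real op respects per sequence, and all names are illustrative):

import numpy as np

def naive_fused_ref(x, w, b, context_length, context_start):
    # Gather each row's context window, zero-padding outside the sequence;
    # the kernel builds the same matrix with memset/memcpy, then one GEMM.
    T, M = x.shape
    col = np.zeros((T, context_length * M), dtype=x.dtype)
    for t in range(T):
        for j in range(context_length):
            src = t + context_start + j
            if 0 <= src < T:
                col[t, j * M:(j + 1) * M] = x[src]
    return np.maximum(col.dot(w) + b, 0)  # elementwise add + relu

x = np.random.uniform(-1, 1, [10, 16]).astype('float32')
w = np.random.uniform(-1, 1, [4 * 16, 8]).astype('float32')
b = np.random.uniform(-2, 1, [1, 8]).astype('float32')
print(naive_fused_ref(x, w, b, 4, 0).shape)  # (10, 8)
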
diff --git a/python/paddle/fluid/tests/unittests/test_seq_conv.py b/python/paddle/fluid/tests/unittests/test_seq_conv.py
index dcc86382e52..2285e949676 100644
--- a/python/paddle/fluid/tests/unittests/test_seq_conv.py
+++ b/python/paddle/fluid/tests/unittests/test_seq_conv.py
@@ -20,6 +20,53 @@ import random
 from op_test import OpTest
 
 
+def seqconv(x,
+            lod,
+            filter,
+            context_length,
+            context_start,
+            padding_trainable=False,
+            padding_data=None):
+    [T, M] = x.shape
+    col = np.zeros((T, context_length * M)).astype('float32')
+    offset = [0]
+    for seq_len in lod[0]:
+        offset.append(offset[-1] + seq_len)
+    begin_pad = np.max([0, -context_start])
+    for i in range(len(offset) - 1):
+        for j in range(context_length):
+            in_begin = offset[i] + context_start + j
+            in_end = offset[i + 1] + context_start + j
+            out_begin = offset[i]
+            out_end = offset[i + 1]
+            if in_begin < offset[i]:
+                pad_size = np.min(
+                    [offset[i] - in_begin, offset[i + 1] - offset[i]])
+                if padding_trainable:
+                    sub_w = padding_data[j:j + pad_size, :]
+                    col[offset[i]:offset[i] + pad_size, j * M:(j + 1) *
+                        M] = sub_w
+                out_begin = offset[i] + pad_size
+                in_begin = offset[i]
+
+            if in_end > offset[i + 1]:
+                pad_size = np.min(
+                    [in_end - offset[i + 1], offset[i + 1] - offset[i]])
+                if padding_trainable:
+                    sub_w = padding_data[begin_pad + context_start + j -
+                                         pad_size:begin_pad + context_start +
+                                         j, :]
+                    col[offset[i + 1] - pad_size:offset[i + 1], j * M:(j + 1) *
+                        M] = sub_w
+                in_end = offset[i + 1]
+                out_end = offset[i + 1] - pad_size
+            if in_end <= in_begin:
+                continue
+            in_sub = x[in_begin:in_end, :]
+            col[out_begin:out_end, j * M:(j + 1) * M] += in_sub
+    return np.dot(col, filter)
+
+
 class TestSeqProject(OpTest):
     def setUp(self):
         self.init_test_case()
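
With seqconv promoted to module level, both test files now share one reference implementation. A quick usage sketch (assumes op_test and test_seq_conv are importable, i.e. Paddle's unittest directory is on sys.path):

import numpy as np
from test_seq_conv import seqconv

x = np.random.rand(5, 2).astype('float32')      # T=5 rows, hidden size M=2
lod = [[3, 2]]                                  # two sequences: lengths 3, 2
w = np.random.rand(3 * 2, 4).astype('float32')  # context_length=3 -> 6 filter rows
out = seqconv(x, lod, w, context_length=3, context_start=-1)
print(out.shape)  # (5, 4)
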
@@ -66,57 +113,9 @@ class TestSeqProject(OpTest):
             'paddingTrainable': self.padding_trainable,
             'contextStride': self.context_stride
         }
-        out = np.zeros(
-            (self.input_size[0], self.output_represention)).astype('float32')
+        out = seqconv(x, self.lod, w, self.context_length, self.context_start,
+                      self.padding_trainable, self.pad_data)
         self.outputs = {'Out': out}
-        self.compute()
-
-    def compute(self):
-        x, lod = self.inputs['X']
-        filter = self.inputs['Filter']
-        pading_data = self.pad_data
-        out = np.zeros((self.input_size[0], self.context_length *
-                        self.input_size[1])).astype('float32')
-        offset = [0]
-        for seq_len in lod[0]:
-            offset.append(offset[-1] + seq_len)
-        begin_pad = np.max([0, -self.context_start])
-
-        for i in range(len(offset) - 1):
-            for j in range(self.context_length):
-                in_begin = offset[i] + self.context_start + j
-                in_end = offset[i + 1] + self.context_start + j
-                out_begin = offset[i]
-                out_end = offset[i + 1]
-                if in_begin < offset[i]:
-                    pad_size = np.min(
-                        [offset[i] - in_begin, offset[i + 1] - offset[i]])
-                    if self.padding_trainable:
-                        sub_w = pading_data[j:j + pad_size, :]
-                        out[offset[i]:offset[i] + pad_size, j * self.input_size[
-                            1]:(j + 1) * self.input_size[1]] = sub_w
-                    out_begin = offset[i] + pad_size
-                    in_begin = offset[i]
-
-                if in_end > offset[i + 1]:
-                    pad_size = np.min(
-                        [in_end - offset[i + 1], offset[i + 1] - offset[i]])
-                    if self.padding_trainable:
-                        sub_w = pading_data[begin_pad + self.context_start + j -
-                                            pad_size:begin_pad +
-                                            self.context_start + j, :]
-                        out[offset[i + 1] - pad_size:offset[i + 1], j * self.
-                            input_size[1]:(j + 1) * self.input_size[1]] = sub_w
-                    in_end = offset[i + 1]
-                    out_end = offset[i + 1] - pad_size
-                if in_end <= in_begin:
-                    continue
-
-                in_sub = x[in_begin:in_end, :]
-                out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
-                    self.input_size[1]] += in_sub
-
-        np.dot(out, filter, out=self.outputs['Out'])
 
     def test_check_output(self):
         self.check_output()
--
GitLab
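
The refactor above deletes the in-class compute() and routes TestSeqProject through the shared seqconv helper, including the trainable-padding path. A quick shape check of that path (sketch only, under the same import assumption as above; sizes mirror the kind of configuration init_test_case sets up):

import numpy as np
from test_seq_conv import seqconv

T, M, context_length, context_start = 11, 3, 4, -1
lod = [[4, 1, 3, 3]]  # four sequences summing to T
x = np.random.uniform(0.1, 1, [T, M]).astype('float32')
w = np.random.uniform(0.1, 1, [context_length * M, 8]).astype('float32')
begin_pad = max(0, -context_start)
end_pad = max(0, context_start + context_length - 1)
pad_data = np.random.uniform(
    0.1, 1, [begin_pad + end_pad, M]).astype('float32')
out = seqconv(x, lod, w, context_length, context_start,
              padding_trainable=True, padding_data=pad_data)
print(out.shape)  # (11, 8), matching what the old in-class compute() produced
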