From 8de04be786fe21a72b9be91dab963f5d7520885b Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 18 Oct 2017 17:14:38 +0800 Subject: [PATCH] Fix unitest --- paddle/framework/lod_tensor.cc | 29 +++++++ paddle/framework/lod_tensor.h | 7 ++ paddle/operators/seq_expand_op.h | 79 +++++-------------- .../v2/framework/tests/test_seq_expand.py | 30 ++----- 4 files changed, 64 insertions(+), 81 deletions(-) diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 5b7badf89c..1247daafc5 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -103,5 +103,34 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, lod_ = new_lod; } +Vector repeat_lod(Vector data, Vector starts, + Vector times, bool is_first) { + Vector result; + result.push_back(data[0]); + size_t p = 0, start = 0, end = 0; + if (is_first == true) { + for (size_t i = 0; i < times.size(); ++i) { + result.push_back(data.back() + times[i] * (data[i + 1] - data[i])); + } + } else { + for (size_t i = 0; i < times.size(); ++i) { + while (starts[i] != data[p] && p < data.size()) { + ++p; + } + start = p; + while (starts[i + 1] != data[p] && p < data.size()) { + ++p; + } + end = p + 1; + for (size_t j = 0; j < times[i]; ++j) { + for (size_t index = start; index < end - 1; ++index) { + result.push_back(result.back() + data[index + 1] - data[index]); + } + } + } + } + return result; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 4db36ee766..41c83a1164 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -15,6 +15,9 @@ #pragma once #include +#include "paddle/memory/memcpy.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" #ifdef PADDLE_WITH_CUDA #include #include @@ -122,5 +125,9 @@ class LoDTensor : public Tensor { private: LoD lod_; }; + +Vector repeat_lod(Vector data, Vector starts, + Vector times, bool is_first); + } // namespace framework } // namespace paddle diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h index cd1182c4f0..221393f909 100644 --- a/paddle/operators/seq_expand_op.h +++ b/paddle/operators/seq_expand_op.h @@ -22,54 +22,6 @@ namespace operators { using LoDTensor = framework::LoDTensor; -template -using vector = framework::Vector; - -vector repeat_lod(vector data, vector starts, - vector times, bool is_first) { - vector result; - result.push_back(data[0]); - size_t p = 0, start = 0, end = 0; - if (is_first == true) { - for (size_t i = 0; i < times.size(); ++i) { - result.push_back(data.back() + times[i] * (data[i + 1] - data[i])); - } - } else { - for (size_t i = 0; i < times.size(); ++i) { - while (starts[i] != data[p] && p < data.size()) { - ++p; - } - start = p; - while (starts[i + 1] != data[p] && p < data.size()) { - ++p; - } - end = p + 1; - for (size_t j = 0; j < times[i]; ++j) { - for (size_t index = start; index < end - 1; ++index) { - result.push_back(result.back() + data[index + 1] - data[index]); - } - } - } - } - return result; -} - -template -void repeat_data(const T* src, T* dst, size_t size, vector starts, - vector times, Place place) { - const T* src_p = src; - T* dst_p = dst; - size_t count = 0; - for (size_t i = 0; i < times.size(); ++i) { - count = size * (starts[i + 1] - starts[i]); - for (size_t j = 0; j < times[i]; ++j) { - memory::Copy(place, dst_p, place, src_p, sizeof(T) * count); - dst_p += count; - } - src_p += count; - } -} - template class SeqExpandKernel : public framework::OpKernel { public: @@ -81,7 +33,7 @@ class SeqExpandKernel : public framework::OpKernel { auto x_lod = x->lod(); if (x_lod.size() == 0) { - vector level; + framework::Vector level; for (int i = 0; i < x->dims()[0] + 1; ++i) { level.push_back(i); } @@ -91,7 +43,7 @@ class SeqExpandKernel : public framework::OpKernel { } size_t repeat = static_cast(context.Attr("repeat")); - vector repeats; + framework::Vector repeats; if (repeat != 0) { for (int i = 0; i < x_lod[0].size() - 1; ++i) { repeats.push_back(repeat); @@ -107,21 +59,32 @@ class SeqExpandKernel : public framework::OpKernel { repeats.push_back((y_lod[0][i + 1] - y_lod[0][i]) / (x_lod[0][i + 1] - x_lod[0][i])); } - out->Resize(x_dims); + out->Resize(y->dims()); } framework::LoD out_lod; - auto level0 = repeat_lod(x_lod[0], x_lod[0], repeats, true); + auto level0 = framework::repeat_lod(x_lod[0], x_lod[0], repeats, true); out_lod.push_back(level0); for (int i = 1; i < x_lod.size(); ++i) { - out_lod.push_back(repeat_lod(x_lod[i], x_lod[0], repeats, false)); + out_lod.push_back( + framework::repeat_lod(x_lod[i], x_lod[0], repeats, false)); } size_t element_len = framework::product(x_dims) / x_dims[0]; T* out_data = out->mutable_data(context.GetPlace()); + + // copy data Place place = boost::get(context.GetPlace()); - repeat_data(x_data, out_data, element_len, x_lod[0], repeats, - place); + size_t count = 0; + for (size_t i = 0; i < repeats.size(); ++i) { + count = element_len * (x_lod[0][i + 1] - x_lod[0][i]); + for (size_t j = 0; j < repeats[i]; ++j) { + memory::Copy(place, out_data, place, x_data, sizeof(T) * count); + out_data += count; + } + x_data += count; + } + out->set_lod(out_lod); } }; @@ -130,9 +93,9 @@ template class SeqExpandGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - // auto* d_out = context.Input(framework::GradVarName("Out")); - // auto* d_x = context.Output(framework::GradVarName("X")); - // d_x->mutable_data(context.GetPlace()); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); + d_x->mutable_data(context.GetPlace()); } }; diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py index 854148a8f1..2b9509413e 100644 --- a/python/paddle/v2/framework/tests/test_seq_expand.py +++ b/python/paddle/v2/framework/tests/test_seq_expand.py @@ -29,17 +29,13 @@ def repeat_array(array, starts, times): class TestSeqExpand(OpTest): def set_data(self): - self.op_type = 'seq_expand' - x = np.random.uniform(0.1, 1, [3, 2, 2]).astype('float32') - y = np.zeros((6, 2, 2)).astype('float32') - y_lod = [[0, 2, 3, 6]] - self.inputs = {'X': (x, None), 'Y': (y, y_lod)} + x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') + self.inputs = {'X': x_data} self.repeat = 2 def compute(self): - x_data, x_lod = self.inputs['X'] - print "x_data: %s" % x_data - print "x_lod: %s" % x_lod + x = self.inputs['X'] + x_data, x_lod = x if type(x) == tuple else (x, None) if not x_lod: x_lod = [[i for i in range(1 + x_data.shape[0])]] else: @@ -47,28 +43,16 @@ class TestSeqExpand(OpTest): if self.repeat: self.attrs = {'repeat': self.repeat} repeats = (len(x_lod[0]) - 1) * [self.repeat] - # get out shape - # out_shape = np.copy(x_data.shape) - # out_shape[0] = out_shape[0] * self.repeat else: y_data, y_lod = self.inputs['Y'] - print "y_lod: %s" % y_lod - #print "y_lod: %s" % y_lod - # get repeats repeats = [((y_lod[0][i + 1] - y_lod[0][i]) / (x_lod[0][i + 1] - x_lod[0][i])) for i in range(len(y_lod[0]) - 1)] - # get out shape - # out_shape = y_data.shape - # get out lod - out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [ repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:] ] - # copy data out = repeat_array(x_data.tolist(), x_lod[0], repeats) - self.outputs = {'Out': (out, out_lod)} - print "outputs: %s" % self.outputs + self.outputs = {'Out': out} def setUp(self): self.op_type = 'seq_expand' @@ -94,7 +78,7 @@ class TestSeqExpandCase1(TestSeqExpand): class TestSeqExpandCase2(TestSeqExpand): def set_data(self): x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') - self.inputs = {'X': (x_data, None)} + self.inputs = {'X': x_data} self.repeat = 2 @@ -103,7 +87,7 @@ class TestSeqExpandCase3(TestSeqExpand): x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32') y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32') y_lod = [[0, 1, 4, 8]] - self.inputs = {'X': (x_data, None), 'Y': (y_data, y_lod)} + self.inputs = {'X': x_data, 'Y': (y_data, y_lod)} self.repeat = None -- GitLab