提交 1b01f1ea 编写于 作者: L Luo Tao

implement framework of seq_pool_op and its unitest

上级 d4d4580d
...@@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/sequence_avg_pool_op.h" #include "paddle/operators/sequence_pool_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SequenceAvgPoolOp : public framework::OperatorWithKernel { class SequencePoolOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
ctx.InputVar("X"), "Input(X) of SequenceAvgPoolOp should not be null."); "Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"), ctx.OutputVar("Out"),
"Output(Out) of SequenceAvgPoolOp should not be null."); "Output(Out) of SequencePoolOp should not be null.");
auto* x = ctx.Input<framework::LoDTensor>("X"); auto* x = ctx.Input<framework::LoDTensor>("X");
auto dims = x->dims(); auto dims = x->dims();
...@@ -42,21 +42,44 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel { ...@@ -42,21 +42,44 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel {
} }
}; };
class SequenceAvgPoolOpMaker : public framework::OpProtoAndCheckerMaker { class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceAvgPoolOpMaker(framework::OpProto* proto, SequencePoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SequenceAvgPoolOp."); AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp");
AddOutput("Out", "The output of SequenceAvgPoolOp."); AddOutput("Out",
"A LoDTensor, the variable-length output of SequencePoolOp.");
AddAttr<int>(
"strategy",
"(int, default AVERAGE) the pooling strategy of SequencePoolOp.")
.SetDefault(AVERAGE)
.InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST});
AddComment(R"DOC( AddComment(R"DOC(
SequenceAvgPoolOp averages features of all time-steps of each instance. SequencePoolOp pools features of all time-steps of each instance.
More detailed comments will be added later.
For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 words:
X = [[1, 3], [2, 4, 6], [5, 1]],
and X->lod()[0] = [0, 2, 5, 7]
then, for different strategy, we get:
- AVERAGE: Out = [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
- SUM: Out = [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
- SQRT: Out = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), 6.93=(2+4+6)/sqrt(3),
4.24=(5+1)/sqrt(2)
- MAX: Out = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
- LAST: Out = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
- FIRST: Out = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
and X->lod() is nullptr.
)DOC"); )DOC");
} }
}; };
class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { class SequencePoolGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -84,12 +107,10 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { ...@@ -84,12 +107,10 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sequence_avg_pool, ops::SequenceAvgPoolOp, REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
ops::SequenceAvgPoolOpMaker, sequence_avg_pool_grad, sequence_pool_grad, ops::SequencePoolGradOp);
ops::SequenceAvgPoolGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_avg_pool, sequence_pool, ops::SequencePoolKernel<paddle::platform::CPUPlace, float>);
ops::SequenceAvgPoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_avg_pool_grad, sequence_pool_grad,
ops::SequenceAvgPoolGradKernel<paddle::platform::CPUPlace, float>); ops::SequencePoolGradKernel<paddle::platform::CPUPlace, float>);
...@@ -14,12 +14,11 @@ ...@@ -14,12 +14,11 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/sequence_avg_pool_op.h" #include "paddle/operators/sequence_pool_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
sequence_avg_pool, sequence_pool, ops::SequencePoolKernel<paddle::platform::GPUPlace, float>);
ops::SequenceAvgPoolKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
sequence_avg_pool_grad, sequence_pool_grad,
ops::SequenceAvgPoolGradKernel<paddle::platform::GPUPlace, float>); ops::SequencePoolGradKernel<paddle::platform::GPUPlace, float>);
...@@ -28,54 +28,85 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,54 +28,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
enum SeqPoolType {
AVERAGE = 0,
SUM = 1,
SQRT = 2, // square_root_n
MAX = 3,
LAST = 4,
FIRST = 5
};
template <typename Place, typename T> template <typename Place, typename T>
class SequenceAvgPoolKernel : public framework::OpKernel { class SequencePoolKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out"); auto* out = context.Output<LoDTensor>("Out");
int strategy = context.Attr<int>("strategy");
auto dims = in->dims(); auto dims = in->dims();
auto lod = in->lod(); auto lod = in->lod()[0];
int64_t w = in->numel() / dims[0]; int64_t w = in->numel() / dims[0];
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]), Tensor in_t =
static_cast<int>(lod[0][i + 1])); in->Slice<T>(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Tensor out_t = out->Slice<T>(i, i + 1); Tensor out_t = out->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]); int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w})); auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t); auto out_e = EigenVector<T>::Flatten(out_t);
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
switch (strategy) {
case AVERAGE:
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
break;
case SUM:
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
break;
default:
LOG(FATAL) << "unsupported pooling strategy";
}
} }
} }
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SequenceAvgPoolGradKernel : public framework::OpKernel { class SequencePoolGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out")); auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X")); auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
int strategy = context.Attr<int>("strategy");
auto dims = in->dims(); auto dims = in->dims();
auto lod = in->lod(); auto lod = in->lod()[0];
int64_t w = in->numel() / dims[0]; int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace()); in_g->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]), auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[i]),
static_cast<int>(lod[0][i + 1])); static_cast<int>(lod[i + 1]));
auto out_g_t = out_g->Slice<T>(i, i + 1); auto out_g_t = out_g->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]); int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w}); auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w}); auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
Eigen::DSizes<int, 2> bcast(h, 1); Eigen::DSizes<int, 2> bcast(h, 1);
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
switch (strategy) {
case AVERAGE:
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
break;
case SUM:
in_g_e.device(place) = (out_g_e).broadcast(bcast);
break;
default:
LOG(FATAL) << "unsupported pooling strategy";
}
} }
} }
}; };
......
...@@ -3,20 +3,37 @@ import numpy as np ...@@ -3,20 +3,37 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
class TestSeqAvgPool1D(OpTest): class SeqPoolType(OpTest):
def setUp(self): AVERAGE = 0
self.op_type = 'sequence_avg_pool' SUM = 1
SQRT = 2
MAX = 3
LAST = 4
FIRST = 5
class TestSeqAvgPool(OpTest):
def set_data(self):
self.op_type = 'sequence_pool'
# one level, batch size is 4 # one level, batch size is 4
x = np.random.uniform(0.1, 1, [11, 23]).astype('float32') x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
lod = [[0, 4, 5, 8, 11]] lod = [[0, 4, 5, 8, 11]]
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 23)).astype('float32') out = np.zeros((4, 23)).astype('float32')
self.outputs = {'Out': out}
def compute(self):
self.attrs = {'strategy': SeqPoolType.AVERAGE}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4): for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :] sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.mean(axis=0) out[i] = sub_x.mean(axis=0)
self.inputs = {'X': (x, lod)} def setUp(self):
self.outputs = {'Out': out} self.set_data()
self.compute()
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -25,26 +42,44 @@ class TestSeqAvgPool1D(OpTest): ...@@ -25,26 +42,44 @@ class TestSeqAvgPool1D(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestSeqAvgPool2D(OpTest): class TestSeqAvgPool2D(TestSeqAvgPool):
def setUp(self): def set_data(self):
self.op_type = 'sequence_avg_pool' self.op_type = 'sequence_pool'
# one level, batch size is 4 # one level, batch size is 4
x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32') x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
lod = [[0, 4, 5, 8, 13]] lod = [[0, 4, 5, 8, 13]]
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 3, 17)).astype('float32') out = np.zeros((4, 3, 17)).astype('float32')
self.outputs = {'Out': out}
def compute(self):
self.attrs = {'strategy': SeqPoolType.AVERAGE}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4): for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
self.inputs = {'X': (x, lod)}
self.outputs = {'Out': out}
def test_check_output(self): class TestSeqSumPool(TestSeqAvgPool):
self.check_output() def compute(self):
self.attrs = {'strategy': SeqPoolType.SUM}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.sum(axis=0)
def test_check_grad(self):
self.check_grad(["X"], "Out") class TestSeqSumPool2D(TestSeqAvgPool2D):
def compute(self):
self.attrs = {'strategy': SeqPoolType.SUM}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册