提交 8e3ecf5d 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #4814 from chengduoZH/Add_sequence_project_op

Add sequence_conv_op  and sequence_projection functor
......@@ -123,6 +123,7 @@ set(DEPS_OPS
sum_op
pool_op
pool_with_index_op
sequence_conv_op
lstm_op)
......@@ -134,6 +135,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......
......@@ -9,6 +9,7 @@ if(WITH_GPU)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
else()
......@@ -18,6 +19,7 @@ else()
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/context_project.h"
namespace paddle {
namespace operators {
namespace math {
template class ContextProjectFunctor<platform::CPUPlace, float>;
template class ContextProjectFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/math/context_project.h"
namespace paddle {
namespace operators {
namespace math {
template class ContextProjectFunctor<platform::GPUPlace, float>;
template class ContextProjectFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/im2col.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
/*
* \brief Context projection concatenate features in adjacent time steps in
* a sequence. The i-th row of the output is the concatenation of
* context_length rows of the input. The context_length rows are the
* consecutive rows from the i+shift_start row.
* \param in Input data.
* \param Shape The shape of Input data,
* [minibatch, number_of_input_features].
* \param type A float LoDTensor.
*
* \param padding_data Padding data.
* \param Shape The shape of Padding data,
* [up_pad + down_pad, number_of_input_features].
* \param type A float Tensor.
*
* \param col Col data.
* \param Shape The shape of Col data,
* [minibatch, context_length * number_of_input_features].
* \param type A float Tensor.
*
* For a mini-batch of 2 variable lengths sentences, containing 3, and 1
* time-steps:
*
* Assumed input (X) is a [4, M, N] float LoDTensor, and X->lod()[0] = [0, 3,
* 4].
* Besides, for the sake of simplicity, we assume M=1 and N=2.
*
* X = [[a1, a2;
* b1, b2;
* c1, c2]
* [d1, d2]]
*
* This is to say that input (X) has 4 words and the dimension of each word
* representation is 2.
*
* - Case1:
* If context_start is -1 and padding_trainable is false, we use zero to pad
* instead of learned weight to pad,
* and the context_lenth is 3, the output (Out) is:
*
* Out =[[0, 0, a1, a2, b1, b2;
* a1, a2, b1, b2, c1, c2;
* b1, b2, c1, c2, 0, 0 ]
* [0, 0, d1, d2, 0, 0 ]]
*
* - Case2:
* If context_start is -1 and padding_trainable is true, we use learned weight
* to pad,
* and the context_lenth is 3, the output (Out) is:
*
* Out = [[w1, w2, a1, a2, b1, b2;
* a1, a2, b1, b2, c1, c2;
* b1, b2, c1, c2, w3, w4]
* [w1, w2, d1, d2, w3, w4]]
*
*/
template <typename Place, typename T>
class ContextProjectFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::LoDTensor& in, framework::Tensor& padding_data,
framework::Tensor& col, bool padding_trainable,
int context_start, int context_length, int context_stride,
int up_pad, int down_pad, bool gradient, bool input_grad,
bool pad_grad) {
auto lod_level_0 = in.lod()[0];
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
col2im_ocf;
int input_row_begin, input_row_end;
int sequence_height, sequence_width;
sequence_width = in.dims()[1];
input_grad = gradient && input_grad;
pad_grad = gradient && pad_grad;
if (!gradient || input_grad) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
input_row_begin = (context_start > 0)
? static_cast<int>(lod_level_0[i]) + context_start
: static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]);
framework::Tensor out_t =
col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
if (input_row_begin < input_row_end) {
framework::Tensor in_t = in.Slice(input_row_begin, input_row_end);
std::vector<int64_t> output_shape(
{sequence_height, 1, 1, context_length,
sequence_width}); // output_height, output_width,
// input_channels, filter_height, filter_width
out_t.Resize(framework::make_ddim(output_shape));
std::vector<int64_t> input_shape(
{1, input_row_end - input_row_begin,
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
if (gradient) {
col2im_ocf(context, in_t, out_t,
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
} else {
im2col_ocf(context, in_t, out_t,
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
}
out_t.Resize({sequence_height, context_length * sequence_width});
}
}
}
if (!gradient || pad_grad) {
if (padding_trainable) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
framework::Tensor out_t =
col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
// add up trainable data
out_t.Resize({sequence_height * context_length, sequence_width});
if (up_pad > 0) { // add up pad
int padding_rows = std::min(
up_pad, static_cast<int>(lod_level_0[i + 1] - lod_level_0[i]));
for (int k = 0; k < padding_rows; ++k) {
int padding_size =
k + context_length < up_pad ? context_length : up_pad - k;
framework::Tensor out_t_sub = out_t.Slice(
k * context_length, k * context_length + padding_size);
framework::Tensor w_sub = padding_data.Slice(k, k + padding_size);
// in this block, using EigenVector<T>::Flatten is ok too.
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
if (gradient) {
w_sub_e.device(*context.GetEigenDevice<Place>()) =
w_sub_e + out_t_sub_e;
} else {
out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
}
}
}
if (down_pad > 0) { // add down pad
int down_pad_begin_row =
std::max(
0, (sequence_height - context_start - context_length) + 1) +
1;
int padding_begin = std::max(0, context_start - sequence_height);
int padding_size =
sequence_height - context_start >= context_length
? 1
: context_length - (sequence_height - context_start);
if (context_start >= sequence_height) padding_size = context_length;
int padding_idx = padding_begin;
for (int t = 0; t + down_pad_begin_row <= sequence_height;
++t, ++padding_size) {
if (context_start >= sequence_height)
padding_size = context_length;
if (padding_size > context_length) {
padding_size = context_length;
padding_idx++;
}
if (padding_begin > 0 || sequence_height == context_start)
padding_idx = padding_begin + t;
framework::Tensor out_t_sub = out_t.Slice(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
framework::Tensor w_sub = padding_data.Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
if (gradient) {
w_sub_e.device(*context.GetEigenDevice<Place>()) =
w_sub_e + out_t_sub_e;
} else {
out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
}
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
}
}
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_conv_op.h"
namespace paddle {
namespace operators {
class SequenceConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceConvOp should not be null.");
int context_length = ctx->Attrs().Get<int>("context_length");
bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
int context_start = ctx->Attrs().Get<int>("context_start");
auto in_dims = ctx->GetInputDim("X");
auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(filter_dims[0] == context_length * in_dims[1],
"Filter's height should be context_length * "
"number_of_input_features .");
if (padding_trainable) {
PADDLE_ENFORCE(
ctx->HasInput("PaddingData"),
"Input(PaddingData) of SequenceConvOp should not be null.");
framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int total_pad = up_pad + down_pad;
int input_width = static_cast<int>(in_dims[1]);
if (context_start == 0 && context_length == 1) {
PADDLE_THROW(
"If context_start is 0 and context_length is 1, padding_trainable "
"should be false.");
}
PADDLE_ENFORCE(padding_dim.size() == 2,
"Input(PaddingData) should be 2-D tensor.");
PADDLE_ENFORCE(
padding_dim[0] == total_pad && padding_dim[1] == input_width,
"Input(PaddingData)'s shape is not consistent with 'context_start' "
"and 'context_length'.");
}
in_dims[1] = filter_dims[1];
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", "Out");
}
};
class SequenceConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null.");
if (ctx->Attrs().Get<bool>("padding_trainable") &&
ctx->HasOutput(framework::GradVarName("PaddingData"))) {
ctx->SetOutputDim(framework::GradVarName("PaddingData"),
ctx->GetInputDim("PaddingData"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"),
ctx->GetInputDim("Filter"));
}
}
};
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(LoDTensor) the input(X) is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the input feature size.");
AddInput("PaddingData",
"(Tensor, optional) the input(PaddingData) is an optional "
"parameter, and it is learnable. "
"This is a tensor with shape (N, D), where N is the "
"top_pad + bottom_pad, D is the input feature size. In order to "
"ensure the equal length of sequence before and after "
"convolution, it is necessary to fill the top and bottom of each "
"sequence according to context_length, context_stride and "
"context_start")
.AsDispensable();
AddInput("Filter",
"(Tensor) the input(Filter) is an learnable parameter."
"This is a tensor with shape (N, D), where N is the "
"context_length, D is the output feature size.");
AddOutput(
"Out",
"(LoDTensor) the output(Out) is a LodTensor, which support "
"variable-time length output sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the output feature size.");
AddAttr<bool>("padding_trainable",
"(bool, default false) the padding data of SequenceConvOp "
"is trainable or not.")
.SetDefault(false);
AddAttr<int>("context_length",
"(int, default 3) the context_length of SequenceConvOp is the "
"height of the convolution kernel.")
.SetDefault(3)
.GreaterThan(0);
AddAttr<int>("context_start",
"(int, default 0) the context_start of SequenceConvOp "
"represents the beginning of the convolution of the number of "
"rows of sequence, which can be negative.")
.SetDefault(0);
AddAttr<int>("context_stride",
"(int, default 1) the context_stride of SequenceConvOp "
"represents the step length of convolution. "
"Currently, SequenceConvOp only supports"
"context_stride=1.")
.SetDefault(1)
.GreaterThan(0);
AddComment(R"DOC(
SequenceConvOp performs convolution operation on features of
context_length time-steps of each instance.
The convolution operation calculates the output based on the input, filter
and strides, paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape. In order to ensure the equal
length of sequence before and after convolution, it is necessary to fill
the top and bottom of each sequence according to context_length,
context_stride and context_start.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
sequence_conv_grad, ops::SequenceConvGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_conv_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_conv_grad,
ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/context_project.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class SequenceConvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
auto filter = *context.Input<Tensor>("Filter");
out->mutable_data<T>(context.GetPlace());
context.ShareLoD("X", "Out");
int context_start = context.Attr<int>("context_start");
int context_length = context.Attr<int>("context_length");
int context_stride = context.Attr<int>("context_stride");
bool padding_trainable = context.Attr<bool>("padding_trainable");
// InferShape by in_lod
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
const Tensor* padding_data = nullptr;
if (padding_trainable) {
padding_data = context.Input<Tensor>("PaddingData");
}
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width;
sequence_width = static_cast<int>(in->dims()[1]);
// Use col_shape in the im2col calculation.
framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
math::SetConstant<Place, T> set_zero;
// Because if padding_trainable is false, padding data should be zeros.
set_zero(context.device_context(), &col, static_cast<T>(0));
paddle::operators::math::ContextProjectFunctor<Place, T>
seq_project_functor;
LoDTensor* input = const_cast<LoDTensor*>(in);
Tensor* pad_data = const_cast<Tensor*>(padding_data);
seq_project_functor(context.device_context(), *input, *pad_data, col,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, false, false, false);
math::matmul<Place, T>(context.device_context(), col, false, filter, false,
static_cast<T>(1.0), out, static_cast<T>(0.0));
}
};
template <typename Place, typename T>
class SequenceConvGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* filter_g = context.Output<Tensor>(framework::GradVarName("Filter"));
auto* padding_data_g =
context.Output<Tensor>(framework::GradVarName("PaddingData"));
auto* in = context.Input<LoDTensor>("X");
auto* filter = context.Input<Tensor>("Filter");
int context_start = context.Attr<int>("context_start");
int context_length = context.Attr<int>("context_length");
int context_stride = context.Attr<int>("context_stride");
bool padding_trainable = context.Attr<bool>("padding_trainable");
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
auto lod_g_level_0 = in->lod()[0];
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]);
math::SetConstant<Place, T> set_zero;
// use col_shape in the im2col calculation
framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length};
Tensor col;
if (in_g || filter_g || (padding_trainable && padding_data_g)) {
col.mutable_data<T>(col_shape, context.GetPlace());
// Because if padding_trainable is false, padding data should be zeros.
set_zero(context.device_context(), &col, static_cast<T>(0));
math::matmul<Place, T>(context.device_context(), *out_g, false, *filter,
true, T(1.0), &col, T(1.0));
}
paddle::operators::math::ContextProjectFunctor<Place, T>
seq_project_functor;
if (in_g) {
in_g->mutable_data<T>(context.GetPlace());
in_g->set_lod(in->lod());
set_zero(context.device_context(), in_g, static_cast<T>(0));
seq_project_functor(context.device_context(), *in_g, *padding_data_g, col,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, true, true, false);
}
if (padding_trainable && padding_data_g) {
padding_data_g->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), padding_data_g, static_cast<T>(0));
LoDTensor* input = const_cast<LoDTensor*>(in);
seq_project_functor(context.device_context(), *input, *padding_data_g,
col, padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, true, false, true);
}
if (filter_g) {
filter_g->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), filter_g, static_cast<T>(0));
Tensor filter_grad = *filter_g;
LoDTensor out_grad = *out_g;
const Tensor* padding_data = nullptr;
if (padding_trainable) {
padding_data = context.Input<Tensor>("PaddingData");
}
sequence_width = static_cast<int>(in->dims()[1]);
LoDTensor* input = const_cast<LoDTensor*>(in);
Tensor* pad_data = const_cast<Tensor*>(padding_data);
seq_project_functor(context.device_context(), *input, *pad_data, col,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, false, false,
false);
math::matmul<Place, T>(context.device_context(), col, true, out_grad,
false, T(1.0), &filter_grad, T(1.0));
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
import random
from op_test import OpTest
class TestSeqProject(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = 'sequence_conv'
if self.context_length == 1 \
and self.context_start == 0 \
and self.padding_trainable:
print "If context_start is 0 " \
"and context_length is 1," \
" padding_trainable should be false."
return
# one level, batch size
x = np.random.uniform(0.1, 1, [self.input_size[0],
self.input_size[1]]).astype('float32')
w = np.random.uniform(0.1, 1, [
self.context_length * self.input_size[1], self.output_represention
]).astype('float32')
begin_pad = np.max([0, -self.context_start])
end_pad = np.max([0, self.context_start + self.context_length - 1])
total_pad = begin_pad + end_pad
padding_data = np.random.uniform(
0.1, 1, [total_pad, self.input_size[1]]).astype('float32')
self.pad_data = padding_data
self.inputs = {
'X': (x, self.lod),
'Filter': w,
}
self.inputs_val = ['X', 'Filter']
self.inputs_val_no_x = ['Filter']
self.inputs_val_no_f = ['X']
if total_pad != 0:
self.inputs['PaddingData'] = padding_data
self.inputs_val = ['X', 'PaddingData', 'Filter']
self.inputs_val_no_x = ['PaddingData', 'Filter']
self.inputs_val_no_f = ['PaddingData', 'X']
self.attrs = {
'context_start': self.context_start,
'context_length': self.context_length,
'padding_trainable': self.padding_trainable,
'context_stride': self.context_stride
}
out = np.zeros(
(self.input_size[0], self.output_represention)).astype('float32')
self.outputs = {'Out': out}
self.compute()
def compute(self):
x, lod = self.inputs['X']
filter = self.inputs['Filter']
pading_data = self.pad_data
out = np.zeros((self.input_size[0], self.context_length *
self.input_size[1])).astype('float32')
lod = lod[0]
begin_pad = np.max([0, -self.context_start])
for i in range(len(lod) - 1):
for j in range(self.context_length):
in_begin = lod[i] + self.context_start + j
in_end = lod[i + 1] + self.context_start + j
out_begin = lod[i]
out_end = lod[i + 1]
if in_begin < lod[i]:
pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]])
if self.padding_trainable:
sub_w = pading_data[j:j + pad_size, :]
out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:(
j + 1) * self.input_size[1]] = sub_w
out_begin = lod[i] + pad_size
in_begin = lod[i]
if in_end > lod[i + 1]:
pad_size = np.min(
[in_end - lod[i + 1], lod[i + 1] - lod[i]])
if self.padding_trainable:
sub_w = pading_data[begin_pad + self.context_start + j -
pad_size:begin_pad +
self.context_start + j, :]
out[lod[i + 1] - pad_size:lod[i + 1], j * self.
input_size[1]:(j + 1) * self.input_size[1]] = sub_w
in_end = lod[i + 1]
out_end = lod[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
self.input_size[1]] += in_sub
np.dot(out, filter, out=self.outputs['Out'])
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.padding_trainable:
self.check_grad(
set(self.inputs_val), 'Out', max_relative_error=0.05)
def test_check_grad_input(self):
self.check_grad(
['X'],
'Out',
max_relative_error=0.05,
no_grad_set=set(self.inputs_val_no_x))
def test_check_grad_padding_data(self):
if self.padding_trainable:
self.check_grad(
['PaddingData'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['X', 'Filter']))
def test_check_grad_Filter(self):
self.check_grad(
['Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(self.inputs_val_no_f))
def test_check_grad_input_filter(self):
if self.padding_trainable:
self.check_grad(
['X', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['PaddingData']))
def test_check_grad_padding_input(self):
if self.padding_trainable:
self.check_grad(
self.inputs_val_no_f,
'Out',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_padding_filter(self):
if self.padding_trainable:
self.check_grad(
self.inputs_val_no_x,
'Out',
max_relative_error=0.05,
no_grad_set=set(['X']))
def init_test_case(self):
self.input_row = 11
self.context_start = 0
self.context_length = 1
self.padding_trainable = False
self.context_stride = 1
self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]]
self.output_represention = 8 # output feature size
class TestSeqProjectCase1(TestSeqProject):
def init_test_case(self):
self.input_row = 11
self.context_start = -1
self.context_length = 3
self.padding_trainable = True
self.context_stride = 1
self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]]
self.output_represention = 8 # output feature size
class TestSeqProjectCase2(TestSeqProject):
def init_test_case(self):
self.input_row = 25
self.context_start = 2
self.context_length = 3
self.padding_trainable = True
self.context_stride = 1
self.input_size = [self.input_row, 23]
idx = range(self.input_size[0])
del idx[0]
self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
[self.input_size[0]]]
self.output_represention = 8 # output feature size
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册