Commit f2ccef26 authored by chengduoZH

Add sequence_conv_op

Parent 0ab2c436
@@ -115,7 +115,8 @@ set(DEPS_OPS
    softmax_with_cross_entropy_op
    sum_op
    pool_op
-    pool_with_index_op)
+    pool_with_index_op
+    sequence_conv_op)

op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
@@ -126,6 +127,8 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
+op_library(sequence_conv_op DEPS sequence_project)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
@@ -12,34 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "paddle/operators/sequence_project_op.h"
+#include "paddle/operators/sequence_conv_op.h"

namespace paddle {
namespace operators {

-class SequenceProjectOp : public framework::OperatorWithKernel {
+class SequenceConvOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SequenceProjectOp should not be null.");
+                   "Input(X) of SequenceConvOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Filter"),
+                   "Input(Filter) of SequenceConvOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SequenceProjectOp should not be null.");
+                   "Output(Out) of SequenceConvOp should not be null.");
    // PaddingData mast be not empty. Otherwise(EnforceNotMet: enforce numel() >
    // 0 failed, 0 <= 0)
-    PADDLE_ENFORCE(
-        ctx->HasInput("PaddingData"),
-        "Input(PaddingData) of SequenceProjectOp should not be null.");
-    auto in_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE(in_dims.size() == 2, "Input(X) should be 2-D tensor.");
+    PADDLE_ENFORCE(ctx->HasInput("PaddingData"),
+                   "Input(PaddingData) of SequenceConvOp should not be null.");

    int context_length = ctx->Attrs().Get<int>("context_length");
    bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
    int context_start = ctx->Attrs().Get<int>("context_start");

+    auto in_dims = ctx->GetInputDim("X");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
+                   "Input(X, Filter) should be 2-D tensor.");
+    PADDLE_ENFORCE(
+        filter_dims[0] == context_length && filter_dims[1] == in_dims[1],
+        "Filter's shape should be (context_length x "
+        "number_of_input_features).");

    if (padding_trainable) {
      framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
      int up_pad = std::max(0, -context_start);
@@ -60,12 +67,12 @@ class SequenceProjectOp : public framework::OperatorWithKernel {
                     "and 'context_length'.");
    }

-    in_dims[1] = in_dims[1] * context_length;
+    in_dims[1] = 1;
    ctx->SetOutputDim("Out", in_dims);
  }
};

-class SequenceProjectGradOp : public framework::OperatorWithKernel {
+class SequenceConvGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -77,60 +84,66 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
    if (ctx->Attrs().Get<bool>("padding_trainable") &&
        ctx->HasOutput(framework::GradVarName("PaddingData"))) {
-      auto padding_dims = ctx->GetInputDim("PaddingData");
-      ctx->SetOutputDim(framework::GradVarName("PaddingData"), padding_dims);
+      ctx->SetOutputDim(framework::GradVarName("PaddingData"),
+                        ctx->GetInputDim("PaddingData"));
    }
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
+    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+      ctx->SetOutputDim(framework::GradVarName("Filter"),
+                        ctx->GetInputDim("Filter"));
+    }
  }
};

-class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
+class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
-  SequenceProjectOpMaker(framework::OpProto* proto,
+  SequenceConvOpMaker(framework::OpProto* proto,
                      framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
-             "(A float LoDTensor) the input of SequenceProjectOp, a vector of "
+             "(A float LoDTensor) the input of SequenceConvOp, a vector of "
             "2-D matrix of size (minibatch, number_of_input_features).");
-    AddOutput("Out",
-              "(A float LoDTensor) the output of SequenceProjectOp, a vector "
-              "of 2-D matrix of size (minibatch, number_of_input_features x "
-              "context_length).");
    AddInput("PaddingData",
-             "(A float LoDTensor) the input of SequenceProjectOp, a vector of "
+             "(A float LoDTensor) the input of SequenceConvOp, a vector of "
             "2-D matrix of size (up_pad + down_pad, "
             "number_of_input_features). ");
+    AddInput("Filter",
+             "(A float LoDTensor) the input of SequenceConvOp, a vector of "
+             "2-D matrix of size (context_length x number_of_input_features).");
+    AddOutput("Out",
+              "(A float LoDTensor) the output of SequenceConvOp, a vector "
+              "of 2-D matrix of size (minibatch, 1).");
    AddAttr<bool>("padding_trainable",
-                  "(bool, default false) the padding data of SequenceProjectOp "
+                  "(bool, default false) the padding data of SequenceConvOp "
                  "is trainable or not.")
        .SetDefault(false);
    AddAttr<int>("context_length",
-                 "(int, default 3) the context_length of SequenceProjectOp.")
+                 "(int, default 3) the context_length of SequenceConvOp.")
        .SetDefault(3)
        .GreaterThan(0);
    AddAttr<int>("context_start",
-                 "(int, default 0) the context_start of SequenceProjectOp.")
+                 "(int, default 0) the context_start of SequenceConvOp.")
        .SetDefault(0);
    AddAttr<int>("context_stride",
-                 "(int, default 1) the context_stride of SequenceProjectOp. "
+                 "(int, default 1) the context_stride of SequenceConvOp. "
                 "Currently, sequence_project_op only support "
                 "context_stride=1.")
        .SetDefault(1)
        .GreaterThan(0);
    AddComment(R"DOC(
-    SequenceProjectOp projects features of context_length time-steps of each instance.
+    SequenceConvOp projects features of context_length time-steps of each instance.

    For a mini-batch of 2 variable lengths sentences, containing 3, and 1 time-steps:

    Assumed input (X) is a [4, M, N] float LoDTensor, and X->lod()[0] = [0, 3, 4].
    Besides, for the sake of simplicity, we assume M=1 and N=2.

-    X = [[a1, a2,
-          b1, b2.
+    X = [[a1, a2;
+          b1, b2;
          c1, c2]
         [d1, d2]]
@@ -141,19 +154,19 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
    If context_start is -1 and padding_trainable is false, we use zero to pad instead of learned weight to pad,
    and the context_lenth is 3, the output (Out) is:

-    Out = [0, 0, a1, a2, b1, b2;
-           a1, a2, b1, b2, c1, c2;
-           b1, b2, c1, c2, 0, 0;
-           0, 0, d1, d2, 0, 0]
+    Out =[[0, 0, a1, a2, b1, b2;
+           a1, a2, b1, b2, c1, c2;
+           b1, b2, c1, c2, 0, 0 ]
+          [0, 0, d1, d2, 0, 0 ]]

    - Case2:
      If context_start is -1 and padding_trainable is true, we use learned weight to pad,
      and the context_lenth is 3, the output (Out) is:

-    Out = [w1, w2, a1, a2, b1, b2;
-           a1, a2, b1, b2, c1, c2;
-           b1, b2, c1, c2, w3, w4;
-           w1, w2, d1, d2, w3, w4]
+    Out = [[w1, w2, a1, a2, b1, b2;
+            a1, a2, b1, b2, c1, c2;
+            b1, b2, c1, c2, w3, w4]
+           [w1, w2, d1, d2, w3, w4]]

    )DOC");
  }
@@ -163,13 +176,11 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
}  // namespace paddle

namespace ops = paddle::operators;
-REGISTER_OP(sequence_project, ops::SequenceProjectOp,
-            ops::SequenceProjectOpMaker, sequence_project_grad,
-            ops::SequenceProjectGradOp);
+REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
+            sequence_conv_grad, ops::SequenceConvGradOp);
REGISTER_OP_CPU_KERNEL(
-    sequence_project,
-    ops::SequenceProjectKernel<paddle::platform::CPUPlace, float>);
+    sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
-    sequence_project_grad,
-    ops::SequenceProjectGradKernel<paddle::platform::CPUPlace, float>);
+    sequence_conv_grad,
+    ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);
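
The example in the DOC string above can be checked by hand. The following is a small NumPy sketch, for illustration only and not part of this commit, that reproduces "Case1" (context_start = -1, context_length = 3, zero padding); the helper name `context_project` is made up here.

    import numpy as np

    def context_project(x, lod, context_start, context_length):
        """Build the (T, context_length * N) context matrix for one LoD level."""
        T, N = x.shape
        proj = np.zeros((T, context_length * N), dtype=x.dtype)
        for i in range(len(lod) - 1):
            begin, end = lod[i], lod[i + 1]
            for t in range(begin, end):           # every time-step of the sequence
                for j in range(context_length):   # every context offset
                    src = t + context_start + j
                    if begin <= src < end:        # out-of-range rows stay zero (the padding)
                        proj[t, j * N:(j + 1) * N] = x[src]
        return proj

    x = np.array([[1., 2.], [3., 4.], [5., 6.],   # a, b, c  (first sequence)
                  [7., 8.]], dtype='float32')     # d        (second sequence)
    proj = context_project(x, [0, 3, 4], context_start=-1, context_length=3)
    # proj[0] == [0, 0, a1, a2, b1, b2] and proj[3] == [0, 0, d1, d2, 0, 0], as in Case1.

    # SequenceConvOp additionally multiplies by Filter, flattened to
    # (context_length * N, 1), so Out has shape (T, 1):
    filter_ = np.random.rand(3 * 2, 1).astype('float32')
    out = np.dot(proj, filter_)                   # shape (4, 1)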
@@ -14,12 +14,11 @@
#define EIGEN_USE_GPU

-#include "paddle/operators/sequence_project_op.h"
+#include "paddle/operators/sequence_conv_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
-    sequence_project,
-    ops::SequenceProjectKernel<paddle::platform::GPUPlace, float>);
+    sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
-    sequence_project_grad,
-    ops::SequenceProjectGradKernel<paddle::platform::GPUPlace, float>);
+    sequence_conv_grad,
+    ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);
@@ -15,46 +15,39 @@ limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
-#include "paddle/operators/strided_memcpy.h"
+#include "paddle/operators/math/sequence_project.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+// template <typename T, int MajorType = Eigen::RowMajor,
+//           typename IndexType = Eigen::DenseIndex>
+// using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
-class SequenceProjectKernel : public framework::OpKernel<T> {
+class SequenceConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    // Because if padding_trainable is false, padding data should be zeros.
-    auto temp = framework::EigenVector<T>::Flatten(*out);
-    temp.device(context.GetEigenDevice<Place>()) =
-        temp.constant(static_cast<T>(0));
-    auto place = context.GetEigenDevice<Place>();
+    auto filter = *context.Input<LoDTensor>("Filter");
+    out->mutable_data<T>(context.GetPlace());

    int context_start = context.Attr<int>("context_start");
    int context_length = context.Attr<int>("context_length");
-    bool padding_trainable = context.Attr<bool>("padding_trainable");
    int context_stride = context.Attr<int>("context_stride");
+    bool padding_trainable = context.Attr<bool>("padding_trainable");

    // InferShape by in_lod
    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
                      "Only support one level sequence now.");
-    auto lod_level_0 = in->lod()[0];

    const LoDTensor* padding_data = nullptr;
    if (padding_trainable) {
@@ -63,117 +56,51 @@ class SequenceProjectKernel : public framework::OpKernel<T> {
    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);

-    int sequence_height, sequence_width;
-    int input_row_begin, input_row_end;
+    int sequence_width;
    sequence_width = static_cast<int>(in->dims()[1]);

-    paddle::operators::math::Im2ColFunctor<
-        paddle::operators::math::ColFormat::kOCF, Place, float>
-        im2col_ocf;
-
-    for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
-      input_row_begin = (context_start > 0)
-                            ? static_cast<int>(lod_level_0[i]) + context_start
-                            : static_cast<int>(lod_level_0[i]);
-      input_row_end = static_cast<int>(lod_level_0[i + 1]);
-
-      Tensor out_t = out->Slice(static_cast<int>(lod_level_0[i]),
-                                static_cast<int>(lod_level_0[i + 1]));
-
-      sequence_height = static_cast<int>(out_t.dims()[0]);
-
-      std::vector<int64_t> output_shape(
-          {sequence_height, 1, 1, context_length,
-           sequence_width});  // output_height, output_width,
-                              // input_channels, filter_height, filter_width
-      out_t.Resize(framework::make_ddim(output_shape));
-
-      if (input_row_begin < input_row_end) {
-        Tensor in_t = in->Slice(input_row_begin, input_row_end);
-        std::vector<int64_t> input_shape(
-            {1, input_row_end - input_row_begin,
-             sequence_width});  // input_channels, input_height, input_width
-        in_t.Resize(framework::make_ddim(input_shape));
-
-        im2col_ocf(context.device_context(), in_t, out_t,
-                   /*stride_height*/ context_stride, /*stride_width*/ 0, up_pad,
-                   down_pad);
-      }
-
-      if (padding_trainable) {
-        // add up trainable data
-        out_t.Resize(framework::make_ddim(
-            {sequence_height * context_length, sequence_width}));
-
-        if (up_pad > 0) {  // add up pad
-          int padding_rows = std::min(
-              up_pad, static_cast<int>(lod_level_0[i + 1] - lod_level_0[i]));
-
-          for (int k = 0; k < padding_rows; ++k) {
-            int padding_size =
-                k + context_length < up_pad ? context_length : up_pad - k;
-            Tensor out_t_sub = out_t.Slice(k * context_length,
-                                           k * context_length + padding_size);
-            Tensor w_sub = padding_data->Slice(k, k + padding_size);
-            // in this block, using EigenVector<T>::Flatten is ok too.
-            auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
-            auto w_sub_e = EigenMatrix<T>::From(w_sub);
-            out_t_sub_e.device(place) = w_sub_e;
-          }
-        }
-        if (down_pad > 0) {  // add down pad
-          int down_pad_begin_row =
-              std::max(0,
-                       (sequence_height - context_start - context_length) + 1) +
-              1;
-          int padding_begin = std::max(0, context_start - sequence_height);
-          int padding_size =
-              sequence_height - context_start >= context_length
-                  ? 1
-                  : context_length - (sequence_height - context_start);
-          if (context_start >= sequence_height) padding_size = context_length;
-          int padding_idx = padding_begin;
-          for (int t = 0; t + down_pad_begin_row <= sequence_height;
-               ++t, ++padding_size) {
-            if (context_start >= sequence_height) padding_size = context_length;
-            if (padding_size > context_length) {
-              padding_size = context_length;
-              padding_idx++;
-            }
-            if (padding_begin > 0 || sequence_height == context_start)
-              padding_idx = padding_begin + t;
-            Tensor out_t_sub = out_t.Slice(
-                (down_pad_begin_row + t) * context_length - padding_size,
-                (down_pad_begin_row + t) * context_length);
-            Tensor w_sub = padding_data->Slice(
-                up_pad + padding_idx, up_pad + padding_idx + padding_size);
-            auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
-            auto w_sub_e = EigenMatrix<T>::From(w_sub);
-            out_t_sub_e.device(place) = w_sub_e;
-          }
-        }
-      }
-      out_t.Resize(framework::make_ddim(
-          {sequence_height, context_length * sequence_width}));
-    }
+    // use col_shape in the im2col calculation
+    framework::DDim col_shape = {in->dims()[0],
+                                 sequence_width * context_length};
+    LoDTensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // Because if padding_trainable is false, padding data should be zeros.
+    auto temp = framework::EigenVector<T>::Flatten(col);
+    temp.device(context.GetEigenDevice<Place>()) =
+        temp.constant(static_cast<T>(0));
+
+    paddle::operators::math::SequenceProjectFunctor<Place, T>
+        seq_project_functor;
+
+    seq_project_functor(context.device_context(), in, padding_data, &col,
+                        padding_trainable, context_start, context_length,
+                        context_stride, up_pad, down_pad);
+
+    filter.Resize(framework::make_ddim({context_length * sequence_width, 1}));
+    math::matmul<Place, T>(context.device_context(), col, false, filter, false,
+                           T(1.0), out, T(0.0));
  }
};

template <typename Place, typename T>
-class SequenceProjectGradKernel : public framework::OpKernel<T> {
+class SequenceConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
+    auto* filter_g =
+        context.Output<LoDTensor>(framework::GradVarName("Filter"));
    auto* padding_data_g =
        context.Output<LoDTensor>(framework::GradVarName("PaddingData"));
    auto* in = context.Input<LoDTensor>("X");
+    auto* filter = context.Input<LoDTensor>("Filter");

    auto place = context.GetEigenDevice<Place>();
    int context_start = context.Attr<int>("context_start");
    int context_length = context.Attr<int>("context_length");
-    bool padding_trainable = context.Attr<bool>("padding_trainable");
    int context_stride = context.Attr<int>("context_stride");
+    bool padding_trainable = context.Attr<bool>("padding_trainable");

    // InferShape by in_lod
    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
@@ -187,15 +114,31 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
    sequence_width = static_cast<int>(in->dims()[1]);

-    paddle::operators::math::Col2ImFunctor<
-        paddle::operators::math::ColFormat::kOCF, Place, float>
-        col2im_ocf;
+    // use col_shape in the im2col calculation
+    framework::DDim col_shape = {in->dims()[0],
+                                 sequence_width * context_length};
+    LoDTensor col;
+
+    if (in_g || filter_g || (padding_trainable && padding_data_g)) {
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      // Because if padding_trainable is false, padding data should be zeros.
+      auto temp = framework::EigenVector<T>::Flatten(col);
+      temp.device(context.GetEigenDevice<Place>()) =
+          temp.constant(static_cast<T>(0));
+
+      math::matmul<Place, T>(context.device_context(), *out_g, false, *filter,
+                             true, T(1.0), &col, T(1.0));
+    }

    if (in_g) {
      in_g->mutable_data<T>(context.GetPlace());
      math::SetConstant<Place, T> functor;
      functor(context.device_context(), in_g, 0);
+
+      paddle::operators::math::Col2ImFunctor<
+          paddle::operators::math::ColFormat::kOCF, Place, float>
+          col2im_ocf;

      for (int i = 0; i < static_cast<int>(lod_g_level_0.size()) - 1; ++i) {
        input_row_begin =
            (context_start > 0)
@@ -203,10 +146,10 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
                : static_cast<int>(lod_g_level_0[i]);
        input_row_end = static_cast<int>(lod_g_level_0[i + 1]);

-        Tensor out_g_t = out_g->Slice(static_cast<int>(lod_g_level_0[i]),
-                                      static_cast<int>(lod_g_level_0[i + 1]));
+        Tensor col_t = col.Slice(static_cast<int>(lod_g_level_0[i]),
+                                 static_cast<int>(lod_g_level_0[i + 1]));

-        sequence_height = static_cast<int>(out_g_t.dims()[0]);
+        sequence_height = static_cast<int>(col_t.dims()[0]);

        if (input_row_begin < input_row_end) {
          Tensor in_t = in_g->Slice(input_row_begin, input_row_end);
@@ -214,19 +157,19 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
          std::vector<int64_t> output_shape(
              {sequence_height, 1, 1, context_length,
               sequence_width});  // output_height, output_width,
                                  // input_channels, filter_height, filter_width
-          out_g_t.Resize(framework::make_ddim(output_shape));
+          col_t.Resize(framework::make_ddim(output_shape));

          std::vector<int64_t> input_shape(
              {1, input_row_end - input_row_begin,
               sequence_width});  // input_channels, input_height, input_width
          in_t.Resize(framework::make_ddim(input_shape));

-          col2im_ocf(context.device_context(), in_t, out_g_t,
+          col2im_ocf(context.device_context(), in_t, col_t,
                     /*stride_height*/ context_stride, /*stride_width*/ 0,
                     up_pad, down_pad);
        }
-        out_g_t.Resize(framework::make_ddim(
+        col_t.Resize(framework::make_ddim(
            {sequence_height, context_length * sequence_width}));
      }
    }
@@ -244,12 +187,12 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
                : static_cast<int>(lod_g_level_0[i]);
        input_row_end = static_cast<int>(lod_g_level_0[i + 1]);

-        Tensor out_g_t = out_g->Slice(static_cast<int>(lod_g_level_0[i]),
-                                      static_cast<int>(lod_g_level_0[i + 1]));
+        Tensor col_t = col.Slice(static_cast<int>(lod_g_level_0[i]),
+                                 static_cast<int>(lod_g_level_0[i + 1]));

-        sequence_height = static_cast<int>(out_g_t.dims()[0]);
+        sequence_height = static_cast<int>(col_t.dims()[0]);

-        out_g_t.Resize(framework::make_ddim(
+        col_t.Resize(framework::make_ddim(
            {sequence_height * context_length, sequence_width}));

        if (up_pad > 0) {  // add up pad
@@ -260,8 +203,8 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
          for (int k = 0; k < padding_rows; ++k) {
            int padding_size =
                k + context_length < up_pad ? context_length : up_pad - k;
-            Tensor out_t_sub = out_g_t.Slice(k * context_length,
-                                             k * context_length + padding_size);
+            Tensor out_t_sub = col_t.Slice(k * context_length,
+                                           k * context_length + padding_size);
            Tensor w_sub = padding_data_g->Slice(k, k + padding_size);
            // in this block, using EigenVector<T>::Flatten is ok too.
            auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
@@ -290,7 +233,7 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
            }
            if (padding_begin > 0 || sequence_height == context_start)
              padding_idx = padding_begin + t;
-            Tensor out_t_sub = out_g_t.Slice(
+            Tensor out_t_sub = col_t.Slice(
                (down_pad_begin_row + t) * context_length - padding_size,
                (down_pad_begin_row + t) * context_length);
            Tensor w_sub = padding_data_g->Slice(
@@ -300,10 +243,40 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
            w_sub_e.device(place) = w_sub_e + out_t_sub_e;
          }
        }
-        out_g_t.Resize(framework::make_ddim(
+        col_t.Resize(framework::make_ddim(
            {sequence_height, context_length * sequence_width}));
      }
    }
+
+    if (filter_g) {
+      filter_g->mutable_data<T>(context.GetPlace());
+      math::SetConstant<Place, T> functor;
+      functor(context.device_context(), filter_g, 0);
+
+      Tensor filter_grad_ = *filter_g;
+      Tensor out_grad_ = *out_g;
+
+      const LoDTensor* padding_data = nullptr;
+      if (padding_trainable) {
+        padding_data = context.Input<LoDTensor>("PaddingData");
+      }
+
+      sequence_width = static_cast<int>(in->dims()[1]);
+
+      paddle::operators::math::SequenceProjectFunctor<Place, T>
+          seq_project_functor;
+
+      seq_project_functor(context.device_context(), in, padding_data, &col,
+                          padding_trainable, context_start, context_length,
+                          context_stride, up_pad, down_pad);
+
+      filter_grad_.Resize(
+          framework::make_ddim({context_length * sequence_width, 1}));
+
+      math::matmul<Place, T>(context.device_context(), col, true, out_grad_,
+                             false, T(1.0), &filter_grad_, T(1.0));
+    }
  }
};
......
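
Once the im2col ("sequence project") matrix `col` of shape (T, context_length * N) is built, the kernels above reduce to two GEMMs. The following NumPy sketch is an illustration of that data flow under those assumptions, not the committed code.

    import numpy as np

    T, N, context_length = 4, 2, 3
    col = np.random.rand(T, context_length * N).astype('float32')       # from SequenceProjectFunctor
    filter_ = np.random.rand(context_length * N, 1).astype('float32')   # Filter, reshaped

    out = np.dot(col, filter_)                 # forward: matmul(col, filter) -> (T, 1)

    d_out = np.ones_like(out)                  # pretend gradient of Out
    d_col = np.dot(d_out, filter_.T)           # matmul(out_g, filter^T); col2im then scatters it into dX
    d_filter = np.dot(col.T, d_out)            # matmul(col^T, out_g) -> dFilter

    assert d_col.shape == col.shape and d_filter.shape == filter_.shape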
import unittest
import numpy as np
import random
from op_test import OpTest


class TestSeqProject(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = 'sequence_project'

        if self.context_length == 1 and self.context_start == 0 and self.padding_trainable:
            print "If context_start is 0 and context_length is 1, padding_trainable should be false."
            return

        # one level, batch size
        x = np.random.uniform(
            0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32')

        self.begin_pad = np.max([0, -self.context_start])
        self.end_pad = np.max([0, self.context_start + self.context_length - 1])
        self.total_pad = self.begin_pad + self.end_pad
        if self.total_pad == 0:
            self.total_pad = 1
        # PaddingData mast be not empty. Otherwise(EnforceNotMet: enforce numel() > 0 failed, 0 <= 0)
        padding_data = np.random.uniform(
            0.1, 1, [self.total_pad, self.input_size[1]]).astype('float32')

        self.inputs = {
            'X': (x, self.lod),
            'PaddingData': (padding_data, [[0, self.total_pad]])
        }
        self.attrs = {
            'context_start': self.context_start,
            'context_length': self.context_length,
            'padding_trainable': self.padding_trainable,
            'context_stride': self.context_stride
        }
        out = np.zeros((self.input_size[0], self.input_size[1] *
                        self.context_length)).astype('float32')
        self.outputs = {'Out': out}
        self.compute()
    def compute(self):
        x, lod = self.inputs['X']
        pading_data, _ = self.inputs['PaddingData']
        out = self.outputs['Out']
        lod = lod[0]
        begin_pad = np.max([0, -self.context_start])

        for i in range(len(lod) - 1):
            for j in range(self.context_length):
                in_begin = lod[i] + self.context_start + j
                in_end = lod[i + 1] + self.context_start + j
                out_begin = lod[i]
                out_end = lod[i + 1]
                if in_begin < lod[i]:
                    pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]])
                    if self.padding_trainable:
                        sub_w = pading_data[j:j + pad_size, :]
                        out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:(
                            j + 1) * self.input_size[1]] = sub_w
                    out_begin = lod[i] + pad_size
                    in_begin = lod[i]

                if in_end > lod[i + 1]:
                    pad_size = np.min(
                        [in_end - lod[i + 1], lod[i + 1] - lod[i]])
                    if self.padding_trainable:
                        sub_w = pading_data[begin_pad + self.context_start + j -
                                            pad_size:begin_pad +
                                            self.context_start + j, :]
                        out[lod[i + 1] - pad_size:lod[i + 1], j * self.
                            input_size[1]:(j + 1) * self.input_size[1]] = sub_w
                    in_end = lod[i + 1]
                    out_end = lod[i + 1] - pad_size
                if in_end <= in_begin:
                    continue

                in_sub = x[in_begin:in_end, :]
                out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
                    self.input_size[1]] += in_sub
    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        if self.padding_trainable:
            self.check_grad(
                set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)

    def test_check_grad_no_filter(self):
        self.check_grad(
            ['X'],
            'Out',
            max_relative_error=0.05,
            no_grad_set=set(['PaddingData']))

    def test_check_grad_no_input(self):
        if self.padding_trainable:
            self.check_grad(
                ['PaddingData'],
                'Out',
                max_relative_error=0.05,
                no_grad_set=set(['X']))

    def init_test_case(self):
        self.op_type = "sequence_project"
        self.input_row = 11
        self.context_start = 0
        self.context_length = 1
        self.padding_trainable = False
        self.context_stride = 1

        self.input_size = [self.input_row, 23]
        self.lod = [[0, 4, 5, 8, self.input_row]]
class TestSeqProjectCase1(TestSeqProject):
    def init_test_case(self):
        self.op_type = "sequence_project"
        self.input_row = 11
        self.context_start = -1
        self.context_length = 3
        self.padding_trainable = True
        self.context_stride = 1

        self.input_size = [self.input_row, 23]
        self.lod = [[0, 4, 5, 8, self.input_row]]


class TestSeqProjectCase2(TestSeqProject):
    def init_test_case(self):
        self.op_type = "sequence_project"
        self.input_row = 25
        self.context_start = 2
        self.context_length = 3
        self.padding_trainable = True
        self.context_stride = 1

        self.input_size = [self.input_row, 23]
        idx = range(self.input_size[0])
        del idx[0]
        self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
                    [self.input_size[0]]]
'''
class TestSeqProjectCases(TestSeqProject):
    def setUp(self):
        self.init_test_case()
        self.op_type = 'sequence_project'

        num = 0
        for context_start in [-5, -3, -1, 0, 3]:
            for context_length in [1, 2, 5, 7]:
                for batch_size in [1, 2, 5, 7]:
                    for padding_trainable in [False, True]:

                        if context_length == 1 and context_start == 0 and padding_trainable:
                            continue

                        self.context_start = context_start
                        self.context_length = context_length
                        self.padding_trainable = padding_trainable
                        self.input_size = [batch_size, 23]
                        x = np.random.uniform(0.1, 1,
                                              self.input_size).astype('float32')
                        self.lod = [[0, self.input_size[0]]]
                        if self.input_size[0] > 2:
                            idx = range(self.input_size[0])
                            del idx[0]
                            self.lod = [
                                [0] + np.sort(random.sample(idx, 2)).tolist() +
                                [self.input_size[0]]
                            ]

                        self.begin_pad = np.max([0, -self.context_start])
                        self.end_pad = np.max([0, self.context_start + self.context_length - 1])
                        self.total_pad = self.begin_pad + self.end_pad
                        if self.total_pad == 0:
                            self.total_pad = 1
                        # PaddingData mast be not empty. Otherwise(EnforceNotMet: enforce numel() > 0 failed, 0 <= 0)
                        padding_data = np.random.uniform(
                            0.1, 1, [self.total_pad, self.input_size[1]]).astype('float32')

                        self.inputs = {
                            'X': (x, self.lod),
                            'PaddingData': (padding_data, [[0, self.total_pad]])
                        }
                        self.attrs = {
                            'context_start': self.context_start,
                            'context_length': self.context_length,
                            'padding_trainable': self.padding_trainable,
                            'context_stride': self.context_stride
                        }
                        out = np.zeros((self.input_size[0], self.input_size[1] *
                                        self.context_length)).astype('float32')
                        self.outputs = {'Out': out}

                        print num
                        print self.attrs
                        print batch_size
                        print padding_trainable
                        print "$$$$$$$$$$$$$"

                        self.compute()
                        self.test_check_output()

                        num += 1
'''
if __name__ == '__main__':
    unittest.main()
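
The padding bookkeeping these tests (and the op's InferShape) rely on can be summarized in a few lines. This is an illustration only; the helper name `pad_rows` is made up here.

    # PaddingData must supply up_pad rows for the left context and
    # down_pad rows for the right context of every sequence.
    def pad_rows(context_start, context_length):
        up_pad = max(0, -context_start)                        # rows missing on the left
        down_pad = max(0, context_start + context_length - 1)  # rows missing on the right
        return up_pad, down_pad, up_pad + down_pad

    assert pad_rows(0, 1) == (0, 0, 0)    # TestSeqProject (total_pad is then forced to 1)
    assert pad_rows(-1, 3) == (1, 1, 2)   # TestSeqProjectCase1
    assert pad_rows(2, 3) == (0, 4, 4)    # TestSeqProjectCase2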