sequence_conv_op.h 6.5 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/context_project.h"
#include "paddle/operators/math/math_function.h"
C
chengduoZH 已提交
20 21 22 23 24 25 26 27

namespace paddle {
namespace operators {

// Short aliases for the framework tensor types used by the kernels below.
using Tensor = framework::Tensor;
// LoDTensor is the tensor type on which the kernels query lod() / set_lod().
using LoDTensor = framework::LoDTensor;

template <typename Place, typename T>
C
chengduoZH 已提交
28
class SequenceConvKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
29 30 31 32
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
C
chengduoZH 已提交
33
    auto filter = *context.Input<Tensor>("Filter");
34

C
chengduoZH 已提交
35
    out->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
36
    context.ShareLoD("X", "Out");
C
chengduoZH 已提交
37

C
chengduoZH 已提交
38 39 40 41
    int context_start = context.Attr<int>("contextStart");
    int context_length = context.Attr<int>("contextLength");
    int context_stride = context.Attr<int>("contextStride");
    bool padding_trainable = context.Attr<bool>("paddingTrainable");
C
chengduoZH 已提交
42 43 44 45 46

    // InferShape by in_lod
    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
                      "Only support one level sequence now.");

C
chengduoZH 已提交
47
    const Tensor* padding_data = nullptr;
C
chengduoZH 已提交
48
    if (padding_trainable) {
C
chengduoZH 已提交
49
      padding_data = context.Input<Tensor>("PaddingData");
C
chengduoZH 已提交
50 51 52 53
    }

    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);
C
chengduoZH 已提交
54
    int sequence_width;
C
chengduoZH 已提交
55
    sequence_width = static_cast<int>(in->dims()[1]);
C
chengduoZH 已提交
56

C
chengduoZH 已提交
57
    // Use col_shape in the im2col calculation.
C
chengduoZH 已提交
58 59
    framework::DDim col_shape = {in->dims()[0],
                                 sequence_width * context_length};
C
chengduoZH 已提交
60
    Tensor col;
C
chengduoZH 已提交
61
    col.mutable_data<T>(col_shape, context.GetPlace());
C
chengduoZH 已提交
62
    math::SetConstant<Place, T> set_zero;
C
chengduoZH 已提交
63
    // Because if padding_trainable is false, padding data should be zeros.
C
chengduoZH 已提交
64
    set_zero(context.device_context(), &col, static_cast<T>(0));
65

C
chengduoZH 已提交
66
    paddle::operators::math::ContextProjectFunctor<Place, T>
C
chengduoZH 已提交
67
        seq_project_functor;
68

C
sss  
chengduoZH 已提交
69
    seq_project_functor(context.device_context(), *in, *padding_data, col,
C
chengduoZH 已提交
70
                        padding_trainable, context_start, context_length,
C
sss  
chengduoZH 已提交
71
                        context_stride, up_pad, down_pad);
72

C
chengduoZH 已提交
73
    math::matmul<Place, T>(context.device_context(), col, false, filter, false,
C
chengduoZH 已提交
74
                           static_cast<T>(1.0), out, static_cast<T>(0.0));
C
chengduoZH 已提交
75 76 77 78
  }
};

template <typename Place, typename T>
C
chengduoZH 已提交
79
class SequenceConvGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
80 81 82 83
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
C
chengduoZH 已提交
84
    auto* filter_g = context.Output<Tensor>(framework::GradVarName("Filter"));
C
chengduoZH 已提交
85
    auto* padding_data_g =
C
chengduoZH 已提交
86
        context.Output<Tensor>(framework::GradVarName("PaddingData"));
87
    auto* in = context.Input<LoDTensor>("X");
C
chengduoZH 已提交
88
    auto* filter = context.Input<Tensor>("Filter");
C
chengduoZH 已提交
89

C
chengduoZH 已提交
90 91 92 93
    int context_start = context.Attr<int>("contextStart");
    int context_length = context.Attr<int>("contextLength");
    int context_stride = context.Attr<int>("contextStride");
    bool padding_trainable = context.Attr<bool>("paddingTrainable");
C
chengduoZH 已提交
94

95
    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
C
chengduoZH 已提交
96
                      "Only support one level sequence now.");
97
    auto lod_g_level_0 = in->lod()[0];
C
chengduoZH 已提交
98

C
chengduoZH 已提交
99 100
    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);
C
chengduoZH 已提交
101
    int sequence_width = static_cast<int>(in->dims()[1]);
C
chengduoZH 已提交
102

C
chengduoZH 已提交
103
    math::SetConstant<Place, T> set_zero;
C
chengduoZH 已提交
104 105 106
    // use col_shape in the im2col calculation
    framework::DDim col_shape = {in->dims()[0],
                                 sequence_width * context_length};
C
chengduoZH 已提交
107
    Tensor col;
C
chengduoZH 已提交
108 109 110 111

    if (in_g || filter_g || (padding_trainable && padding_data_g)) {
      col.mutable_data<T>(col_shape, context.GetPlace());
      // Because if padding_trainable is false, padding data should be zeros.
C
chengduoZH 已提交
112
      set_zero(context.device_context(), &col, static_cast<T>(0));
C
chengduoZH 已提交
113 114 115
      math::matmul<Place, T>(context.device_context(), *out_g, false, *filter,
                             true, T(1.0), &col, T(1.0));
    }
C
chengduoZH 已提交
116
    paddle::operators::math::ContextProjectFunctor<Place, T>
C
chengduoZH 已提交
117
        seq_project_functor;
C
sss  
chengduoZH 已提交
118 119
    paddle::operators::math::ContextProjectGradFunctor<Place, T>
        seq_project_grad_functor;
C
chengduoZH 已提交
120

C
chengduoZH 已提交
121 122
    if (in_g) {
      in_g->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
123
      in_g->set_lod(in->lod());
C
chengduoZH 已提交
124
      set_zero(context.device_context(), in_g, static_cast<T>(0));
125

C
sss  
chengduoZH 已提交
126 127 128 129
      seq_project_grad_functor(context.device_context(), *in_g, *padding_data_g,
                               col, padding_trainable, context_start,
                               context_length, context_stride, up_pad, down_pad,
                               true, false);
C
chengduoZH 已提交
130 131 132 133
    }

    if (padding_trainable && padding_data_g) {
      padding_data_g->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
134
      set_zero(context.device_context(), padding_data_g, static_cast<T>(0));
C
chengduoZH 已提交
135

C
chengduoZH 已提交
136
      LoDTensor* input = const_cast<LoDTensor*>(in);
C
sss  
chengduoZH 已提交
137 138 139 140
      seq_project_grad_functor(context.device_context(), *input,
                               *padding_data_g, col, padding_trainable,
                               context_start, context_length, context_stride,
                               up_pad, down_pad, false, true);
C
chengduoZH 已提交
141
    }
C
chengduoZH 已提交
142 143 144

    if (filter_g) {
      filter_g->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
145
      set_zero(context.device_context(), filter_g, static_cast<T>(0));
C
chengduoZH 已提交
146

C
chengduoZH 已提交
147 148
      Tensor filter_grad = *filter_g;
      LoDTensor out_grad = *out_g;
C
chengduoZH 已提交
149

C
chengduoZH 已提交
150
      const Tensor* padding_data = nullptr;
C
chengduoZH 已提交
151
      if (padding_trainable) {
C
chengduoZH 已提交
152
        padding_data = context.Input<Tensor>("PaddingData");
C
chengduoZH 已提交
153 154
      }

C
sss  
chengduoZH 已提交
155
      seq_project_functor(context.device_context(), *in, *padding_data, col,
C
chengduoZH 已提交
156
                          padding_trainable, context_start, context_length,
C
sss  
chengduoZH 已提交
157
                          context_stride, up_pad, down_pad);
C
chengduoZH 已提交
158

C
chengduoZH 已提交
159 160
      math::matmul<Place, T>(context.device_context(), col, true, out_grad,
                             false, T(1.0), &filter_grad, T(1.0));
C
chengduoZH 已提交
161
    }
C
chengduoZH 已提交
162 163 164 165 166
  }
};

}  // namespace operators
}  // namespace paddle