/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16
#include <algorithm>
Y
Yi Wang 已提交
17 18 19
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/context_project.h"
#include "paddle/fluid/operators/math/math_function.h"
C
chengduoZH 已提交
20 21 22 23 24 25 26

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

Q
QI JUN 已提交
27
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
28
class SequenceConvKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
29 30 31 32
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
C
chengduoZH 已提交
33
    auto filter = *context.Input<Tensor>("Filter");
34

C
chengduoZH 已提交
35
    out->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
36

C
chengduoZH 已提交
37 38 39 40
    int context_start = context.Attr<int>("contextStart");
    int context_length = context.Attr<int>("contextLength");
    int context_stride = context.Attr<int>("contextStride");
    bool padding_trainable = context.Attr<bool>("paddingTrainable");
C
chengduoZH 已提交
41 42 43 44

    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
                      "Only support one level sequence now.");

C
chengduoZH 已提交
45
    const Tensor* padding_data = nullptr;
C
chengduoZH 已提交
46
    if (padding_trainable) {
C
chengduoZH 已提交
47
      padding_data = context.Input<Tensor>("PaddingData");
C
chengduoZH 已提交
48 49 50 51
    }

    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);
52
    auto sequence_width = static_cast<int64_t>(in->dims()[1]);
C
chengduoZH 已提交
53

C
chengduoZH 已提交
54
    framework::DDim col_shape = {in->dims()[0],
C
chengduoZH 已提交
55
                                 context_length * sequence_width};
C
chengduoZH 已提交
56
    Tensor col;
C
chengduoZH 已提交
57 58
    col.mutable_data<T>(col_shape, context.GetPlace());
    // Because if padding_trainable is false, padding data should be zeros.
Q
QI JUN 已提交
59 60
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
61
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
Q
QI JUN 已提交
62 63
    set_zero(dev_ctx, &col, static_cast<T>(0));
    math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
64

65
    seq_project_functor(dev_ctx, *in, padding_data, padding_trainable,
Q
QI JUN 已提交
66 67
                        context_start, context_length, context_stride, up_pad,
                        down_pad, &col);
68

Y
Yu Yang 已提交
69
    blas.MatMul(col, filter, out);
C
chengduoZH 已提交
70 71 72
  }
};

Q
QI JUN 已提交
73
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
74
class SequenceConvGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
75 76 77
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
C
chengduoZH 已提交
78
    auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
79
    auto* filter_g = context.Output<Tensor>(framework::GradVarName("Filter"));
C
chengduoZH 已提交
80
    auto* padding_data_g =
C
chengduoZH 已提交
81
        context.Output<Tensor>(framework::GradVarName("PaddingData"));
82
    auto* in = context.Input<LoDTensor>("X");
C
chengduoZH 已提交
83
    auto* filter = context.Input<Tensor>("Filter");
C
chengduoZH 已提交
84

C
chengduoZH 已提交
85 86 87 88
    int context_start = context.Attr<int>("contextStart");
    int context_length = context.Attr<int>("contextLength");
    int context_stride = context.Attr<int>("contextStride");
    bool padding_trainable = context.Attr<bool>("paddingTrainable");
C
chengduoZH 已提交
89

90
    PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
C
chengduoZH 已提交
91
                      "Only support one level sequence now.");
92
    auto lod_g_level_0 = in->lod()[0];
C
chengduoZH 已提交
93

C
chengduoZH 已提交
94 95
    int up_pad = std::max(0, -context_start);
    int down_pad = std::max(0, context_start + context_length - 1);
96
    auto sequence_width = static_cast<int64_t>(in->dims()[1]);
C
chengduoZH 已提交
97

Q
QI JUN 已提交
98 99
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
100
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
C
chengduoZH 已提交
101 102 103
    // use col_shape in the im2col calculation
    framework::DDim col_shape = {in->dims()[0],
                                 sequence_width * context_length};
C
chengduoZH 已提交
104
    Tensor col;
C
chengduoZH 已提交
105 106 107 108

    if (in_g || filter_g || (padding_trainable && padding_data_g)) {
      col.mutable_data<T>(col_shape, context.GetPlace());
      // Because if padding_trainable is false, padding data should be zeros.
Q
QI JUN 已提交
109
      set_zero(dev_ctx, &col, static_cast<T>(0));
Y
Yu Yang 已提交
110
      blas.MatMul(*out_g, false, *filter, true, &col);
C
chengduoZH 已提交
111
    }
Q
QI JUN 已提交
112 113
    math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
    math::ContextProjectGradFunctor<DeviceContext, T> seq_project_grad_functor;
C
chengduoZH 已提交
114

C
chengduoZH 已提交
115 116
    if (in_g) {
      in_g->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
117
      in_g->set_lod(in->lod());
Q
QI JUN 已提交
118
      set_zero(dev_ctx, in_g, static_cast<T>(0));
119

Q
QI JUN 已提交
120 121 122
      seq_project_grad_functor(dev_ctx, *in_g, padding_trainable, context_start,
                               context_length, context_stride, up_pad, down_pad,
                               false, true, padding_data_g, &col);
C
chengduoZH 已提交
123 124 125 126
    }

    if (padding_trainable && padding_data_g) {
      padding_data_g->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
127
      set_zero(dev_ctx, padding_data_g, static_cast<T>(0));
C
chengduoZH 已提交
128

C
chengduoZH 已提交
129
      LoDTensor* input = const_cast<LoDTensor*>(in);
Q
QI JUN 已提交
130 131 132
      seq_project_grad_functor(
          dev_ctx, *input, padding_trainable, context_start, context_length,
          context_stride, up_pad, down_pad, true, false, padding_data_g, &col);
C
chengduoZH 已提交
133
    }
C
chengduoZH 已提交
134 135 136

    if (filter_g) {
      filter_g->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
137
      set_zero(dev_ctx, filter_g, static_cast<T>(0));
C
chengduoZH 已提交
138

C
chengduoZH 已提交
139 140
      Tensor filter_grad = *filter_g;
      LoDTensor out_grad = *out_g;
C
chengduoZH 已提交
141

C
chengduoZH 已提交
142
      const Tensor* padding_data = nullptr;
C
chengduoZH 已提交
143
      if (padding_trainable) {
C
chengduoZH 已提交
144
        padding_data = context.Input<Tensor>("PaddingData");
C
chengduoZH 已提交
145 146
      }

147
      seq_project_functor(dev_ctx, *in, padding_data, padding_trainable,
Q
QI JUN 已提交
148 149
                          context_start, context_length, context_stride, up_pad,
                          down_pad, &col);
C
chengduoZH 已提交
150

Y
Yu Yang 已提交
151
      blas.MatMul(col, true, out_grad, false, &filter_grad);
C
chengduoZH 已提交
152
    }
C
chengduoZH 已提交
153 154 155 156 157
  }
};

}  // namespace operators
}  // namespace paddle