fusion_seqpool_concat_op.cc 6.3 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_seqpool_concat_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

// Infers the shape of "Out" from the duplicable input "X".
//
// Enforces that at least one input is given, that "Out" exists, that the
// "axis" attribute is 1, and that the first input is 2-D. The output shape is
// set to {-1, width_of_first_input * number_of_inputs}; the height stays -1
// because it is only known in Compute, where the input LoD is accessible.
void FusionSeqPoolConcatOp::InferShape(
    framework::InferShapeContext* ctx) const {
  // The check is ">= 1", so the message states the same bound.
  PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
                    platform::errors::InvalidArgument(
                        "Inputs(X) of FusionSeqPoolConcatOp should be greater "
                        "than or equal to 1, but received value is %d.",
                        ctx->Inputs("X").size()));
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FusionSeqPoolConcat");

  const int axis = ctx->Attrs().Get<int>("axis");
  PADDLE_ENFORCE_EQ(axis, 1, platform::errors::InvalidArgument(
                                 "FusionSeqPoolConcatOp only supports concat "
                                 "axis=1 yet, but received axis value is %d",
                                 axis));

  auto ins_dims = ctx->GetInputsDim("X");
  const size_t n = ins_dims.size();
  PADDLE_ENFORCE_GT(n, 0UL, platform::errors::InvalidArgument(
                                "Input tensors count should be greater than 0, "
                                "but received value is %d.",
                                n));
  if (n == 1) {
    LOG(WARNING) << "Only have one input, may waste memory";
  }

  // The output height should be confirmed in Compute,
  // since input lod is not accessible here.
  PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2,
                    platform::errors::InvalidArgument(
                        "The dims size of first input should be equal to 2, "
                        "but received value is %d.",
                        ins_dims[0].size()));

  // A negative width means "unknown"; keep it as -1 instead of scaling it by
  // n, which would produce a meaningless negative dimension.
  const int64_t in_width = ins_dims[0][axis];
  const int64_t out_width =
      in_width < 0 ? -1 : in_width * static_cast<int64_t>(n);
  ctx->SetOutputDim("Out", {-1, out_width});

  if (!ctx->IsRuntime()) {
    // when compiling, the LodLevel of Out is set to be 1, which is consistent
    // with that in running time.
    ctx->SetLoDLevel("Out", 1);
  }
}

// Builds the expected kernel type from the data type indicated for variable
// "X" and the place of the execution context.
framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  const auto x_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
  const auto& place = ctx.GetPlace();
  return framework::OpKernelType(x_data_type, place);
}

void FusionSeqPoolConcatOpMaker::Make() {
  AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
  AddOutput("Out", "(LoDTensor) Output tensor of concat operator.");
  AddAttr<std::string>("pooltype",
T
tensor-tang 已提交
73
                       "(string, default 'SUM') some of the pooling "
T
tensor-tang 已提交
74 75 76 77
                       "pooltype of SequencePoolOp.")
      .SetDefault("SUM")
      .InEnum({"AVERAGE", "SUM", "SQRT"});
  AddAttr<int>("axis",
T
tensor-tang 已提交
78 79
               "The axis along which the input tensors will be concatenated. "
               "Only supports concat axis=1 yet.")
T
tensor-tang 已提交
80 81 82 83 84 85 86 87 88 89 90 91
      .SetDefault(1);
  AddComment(R"DOC(
Fusion Sequence Pool of pooltype(sum, average and sqrt) and Concat Operator.
)DOC");
}

// CPU kernel for fusion_seqpool_concat.
//
// Every input in "X" is pooled sequence by sequence (as delimited by level 0
// of its LoD) with the jit seqpool function selected by the "pooltype"
// attribute (SUM by default, AVERAGE or SQRT otherwise). For input i and
// sequence j, the pooled row of width w is written to row j of "Out" at column
// offset i * w, so the pooled results of all inputs end up side by side.
//
// "Out" is resized to {batch_size, out->dims()[1]} and gets a one-level LoD of
// 0, 1, ..., batch_size, where batch_size is the number of sequences in the
// first input.
//
// Raises InvalidArgument when an input has no LoD level, has zero height, or
// does not match the first input in width or batch size.
template <typename T>
class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    const std::string& pooltype = ctx.Attr<std::string>("pooltype");

    // Bind by reference: the LoD is only read, there is no need to copy it.
    const auto& x0_lod = ins[0]->lod();
    PADDLE_ENFORCE_GE(x0_lod.size(), 1UL,
                      platform::errors::InvalidArgument(
                          "The LoD level of the first input should be at "
                          "least 1, but received value is %d.",
                          x0_lod.size()));
    // Guards the unsigned "size() - 1" below against wrapping around.
    PADDLE_ENFORCE_GE(x0_lod[0].size(), 1UL,
                      platform::errors::InvalidArgument(
                          "The first LoD level of the first input should "
                          "contain at least 1 offset, but received %d.",
                          x0_lod[0].size()));

    auto x0_dims = ins[0]->dims();
    // Copied on purpose: the width is read again after Resize below.
    auto y_dims = out->dims();
    const size_t bs = x0_lod[0].size() - 1;
    out->Resize({static_cast<int64_t>(bs), y_dims[1]});

    // After pooling, every sequence collapses to exactly one output row.
    framework::LoD y_lod(1);
    y_lod[0].resize(bs + 1);
    for (size_t i = 0; i <= bs; ++i) {
      y_lod[0][i] = i;
    }
    out->set_lod(y_lod);
    auto place = ctx.GetPlace();
    T* y_data = out->mutable_data<T>(place);

    // Guards the division that derives the per-row width.
    PADDLE_ENFORCE_GT(x0_dims[0], 0,
                      platform::errors::InvalidArgument(
                          "The height of the first input should be greater "
                          "than 0, but received value is %d.",
                          x0_dims[0]));
    int w = ins[0]->numel() / x0_dims[0];
    PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
                      platform::errors::InvalidArgument(
                          "The output of dims[1] should be dividable of w, but "
                          "dims[1] is %d, w is %d.",
                          y_dims[1], w));

    jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
    if (pooltype == "AVERAGE") {
      attr.type = jit::SeqPoolType::kAvg;
    } else if (pooltype == "SQRT") {
      attr.type = jit::SeqPoolType::kSqrt;
    }
    auto seqpool =
        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
            attr);

    const size_t n = ins.size();
    // Distance between two consecutive output rows, in elements.
    const size_t dst_step_size = n * w;
    for (size_t i = 0; i < n; ++i) {
      auto x_dims = ins[i]->dims();
      const auto& in_lod = ins[i]->lod();
      PADDLE_ENFORCE_GE(in_lod.size(), 1UL,
                        platform::errors::InvalidArgument(
                            "The LoD level of the %d-th input should be at "
                            "least 1, but received value is %d.",
                            i, in_lod.size()));
      const auto& x_lod = in_lod[0];
      PADDLE_ENFORCE_GT(x_dims[0], 0,
                        platform::errors::InvalidArgument(
                            "The height of the %d-th input should be greater "
                            "than 0, but received value is %d.",
                            i, x_dims[0]));

      const int x_w = static_cast<int>(ins[i]->numel() / x_dims[0]);
      PADDLE_ENFORCE_EQ(
          x_w, w,
          platform::errors::InvalidArgument(
              "Width of all inputs should be equal, but the width of the %d-th "
              "input %d is not equal to the previous %d",
              i, x_w, w));
      PADDLE_ENFORCE_EQ(
          x_lod.size(), bs + 1,
          platform::errors::InvalidArgument(
              "Batchsize of all inputs should be equal, but the value of the "
              "%d-th %d is not equal to the previous %d.",
              i, x_lod.size(), bs + 1));

      const T* src = ins[i]->data<T>();
      // Input i owns the column range [i * w, (i + 1) * w) of every row.
      T* dst = y_data + i * w;
      for (size_t j = 0; j < bs; ++j) {
        attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
        seqpool(src, dst, &attr);
        dst += dst_step_size;
        src += attr.h * attr.w;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers the operator together with its proto maker.
REGISTER_OPERATOR(fusion_seqpool_concat, ops::FusionSeqPoolConcatOp,
                  ops::FusionSeqPoolConcatOpMaker);

// Registers the CPU kernel instantiations for float and double.
REGISTER_OP_CPU_KERNEL(fusion_seqpool_concat,
                       ops::FusionSeqPoolConcatKernel<float>,
                       ops::FusionSeqPoolConcatKernel<double>);