// partial_sum_op.h
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Shorthand for framework::Tensor, used by the kernels below.
using Tensor = framework::Tensor;

/**
 * CPU forward kernel of partial_sum.
 *
 * For every tensor in input "X", the column slice
 * [start_index, start_index + length) of each row is added element-wise into
 * "Out". Inputs are indexed as row-major dims()[0] x dims()[1] buffers;
 * "Out" is written as a row-major batch_size x length buffer, where
 * batch_size is dims()[0] of the first input. Attribute "length" == -1 means
 * "to the end of the row" and is resolved against the first input's width.
 *
 * The loops dereference host pointers directly, so only CPU places are
 * supported. DeviceContext is not referenced by this kernel.
 */
template <typename DeviceContext, typename T>
class PartialSumKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<Tensor>("X");
    Tensor* out = ctx.Output<Tensor>("Out");
    // ins[0] supplies the batch size and default width, so an empty list or a
    // null first input must be rejected before it is dereferenced.
    PADDLE_ENFORCE_EQ(ins.empty(), false,
                      platform::errors::InvalidArgument(
                          "The input list of partial_sum should not be empty."));
    PADDLE_ENFORCE_EQ(
        ins[0] != nullptr, true,
        platform::errors::InvalidArgument("The input should not be null."));

    auto place = ctx.GetPlace();  // CPUPlace only now

    // Widen the int attributes to int64_t so offsets such as bs_id * length
    // cannot overflow a 32-bit int on large tensors.
    const int64_t start_index = ctx.Attr<int>("start_index");
    int64_t length = ctx.Attr<int>("length");
    const int64_t batch_size = ins[0]->dims()[0];
    if (length == -1) {
      length = ins[0]->dims()[1] - start_index;
    }
    // A negative start_index reads before the row start, and a negative
    // length would turn into a huge element count when zero-filling Out.
    PADDLE_ENFORCE_EQ(
        start_index >= 0, true,
        platform::errors::InvalidArgument(
            "The start_index of partial_sum should be non-negative, but "
            "received %d.",
            start_index));
    PADDLE_ENFORCE_EQ(
        length >= 0, true,
        platform::errors::InvalidArgument(
            "The length of partial_sum should be -1 or resolve to a "
            "non-negative value, but got %d.",
            length));

    T* out_t = out->mutable_data<T>(place);
    std::fill_n(out_t, batch_size * length, static_cast<T>(0));

    for (size_t i = 0; i < ins.size(); ++i) {
      PADDLE_ENFORCE_EQ(
          ins[i] != nullptr, true,
          platform::errors::InvalidArgument("The input should not be null."));
      const int64_t rows = ins[i]->dims()[0];
      const int64_t total_len = ins[i]->dims()[1];
      // Reject slices that would read past the end of this input.
      PADDLE_ENFORCE_EQ(
          rows >= batch_size, true,
          platform::errors::InvalidArgument(
              "Input %d of partial_sum has %d rows, fewer than the %d rows "
              "of the first input.",
              i, rows, batch_size));
      PADDLE_ENFORCE_EQ(
          start_index + length <= total_len, true,
          platform::errors::InvalidArgument(
              "start_index (%d) + length (%d) of partial_sum exceeds the "
              "width %d of input %d.",
              start_index, length, total_len, i));

      const T* in_t = ins[i]->data<T>();
      for (int64_t bs_id = 0; bs_id < batch_size; ++bs_id) {
        const T* in_row = in_t + bs_id * total_len + start_index;
        T* out_row = out_t + bs_id * length;
        for (int64_t k = 0; k < length; ++k) {
          out_row[k] += in_row[k];
        }
      }
    }
  }
};

/**
 * CPU backward kernel of partial_sum.
 *
 * Every requested input gradient X@GRAD[i] is zero-filled. Then, for each
 * row, the column slice [start_index, start_index + length) is overwritten
 * with the matching row of Out@GRAD. The forward kernel only sums the slice,
 * so positions outside it keep gradient 0. Attribute "length" == -1 is
 * resolved against the first input's width.
 */
template <typename T>
class PartialSumGradientOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
    auto outs =
        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));

    // ins[0] supplies the batch size and default width, so an empty list or a
    // null first input must be rejected before it is dereferenced.
    PADDLE_ENFORCE_EQ(ins.empty(), false,
                      platform::errors::InvalidArgument(
                          "The input list of partial_sum should not be empty."));
    PADDLE_ENFORCE_EQ(
        ins[0] != nullptr, true,
        platform::errors::InvalidArgument("The input should not be null."));
    // outs[i] is paired with ins[i] below, so ins must be at least as long.
    PADDLE_ENFORCE_EQ(
        outs.size() <= ins.size(), true,
        platform::errors::InvalidArgument(
            "The number of X@GRAD outputs (%d) of partial_sum_grad should not "
            "exceed the number of inputs X (%d).",
            outs.size(), ins.size()));

    // Widen the int attributes to int64_t so offsets such as bs_id * length
    // cannot overflow a 32-bit int on large tensors.
    const int64_t start_index = ctx.Attr<int>("start_index");
    int64_t length = ctx.Attr<int>("length");
    const int64_t batch_size = ins[0]->dims()[0];
    if (length == -1) {
      length = ins[0]->dims()[1] - start_index;
    }
    PADDLE_ENFORCE_EQ(
        start_index >= 0, true,
        platform::errors::InvalidArgument(
            "The start_index of partial_sum should be non-negative, but "
            "received %d.",
            start_index));
    PADDLE_ENFORCE_EQ(
        length >= 0, true,
        platform::errors::InvalidArgument(
            "The length of partial_sum should be -1 or resolve to a "
            "non-negative value, but got %d.",
            length));

    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    const T* out_grad_t = out_grad->data<T>();

    for (size_t i = 0; i < outs.size(); ++i) {
      // MultiOutput can yield a null entry for an output variable that does
      // not exist; skip it instead of dereferencing it.
      // NOTE(review): presumably this is an input whose gradient is not
      // required — confirm against the grad op maker.
      if (outs[i] == nullptr) {
        continue;
      }
      PADDLE_ENFORCE_EQ(
          ins[i] != nullptr, true,
          platform::errors::InvalidArgument("The input should not be null."));
      const int64_t rows = ins[i]->dims()[0];
      const int64_t total_len = ins[i]->dims()[1];
      // Reject slices that would write past the end of this gradient.
      // Assumes X@GRAD[i] has the same shape as X[i] — TODO confirm in the
      // grad op's InferShape.
      PADDLE_ENFORCE_EQ(
          rows >= batch_size, true,
          platform::errors::InvalidArgument(
              "Input %d of partial_sum has %d rows, fewer than the %d rows "
              "of the first input.",
              i, rows, batch_size));
      PADDLE_ENFORCE_EQ(
          start_index + length <= total_len, true,
          platform::errors::InvalidArgument(
              "start_index (%d) + length (%d) of partial_sum exceeds the "
              "width %d of input %d.",
              start_index, length, total_len, i));

      // Zero the whole gradient: columns outside the slice get no gradient.
      T* dx_t = outs[i]->mutable_data<T>(ctx.GetPlace());
      auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
      dxt.device(place) = dxt.constant(static_cast<T>(0));

      // The forward op is a plain sum, so d(Out)/d(X slice) is the identity:
      // copy each Out@GRAD row into the matching slice of this row.
      for (int64_t bs_id = 0; bs_id < batch_size; ++bs_id) {
        std::copy_n(out_grad_t + bs_id * length, length,
                    dx_t + bs_id * total_len + start_index);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle