/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {
template <typename T, typename WT>
22
struct TargetAssignFunctor {
23
  const T* in_;
24 25
  const int* match_indices_;
  const size_t* lod_;
26 27 28 29 30 31 32 33 34 35 36 37 38 39
  const int mismatch_value_;
  const int64_t N_;
  const int64_t M_;
  const int64_t P_;
  const int64_t K_;

  T* out_;
  WT* out_wt_;

  TargetAssignFunctor(const T* input, const int* match_indices,
                      const size_t* lod, const int mismatch_value,
                      const int64_t N, const int64_t M, const int64_t P,
                      const int64_t K, T* out, WT* out_wt)
      : in_(input),
40 41
        match_indices_(match_indices),
        lod_(lod),
42 43 44 45 46 47 48
        mismatch_value_(mismatch_value),
        N_(N),
        M_(M),
        P_(P),
        K_(K),
        out_(out),
        out_wt_(out_wt) {}
49 50

  HOSTDEVICE void operator()(size_t i) const {
51 52
    int h = i / M_;
    int w = i - h * M_;
53

54 55
    size_t off = lod_[h];
    int id = match_indices_[i];
56

57 58
    T* out = out_ + i * K_;
    WT* out_wt = out_wt_ + i;
59 60

    if (id > -1) {
61 62 63 64 65 66
      int w_off = w % P_;
      const T* in = in_ + ((off + id) * P_ + w_off) * K_;
      for (int64_t k = 0; k < K_; ++k) {
        out[k] = in[k];
      }
      out_wt[0] = static_cast<WT>(1.);
67
    } else {
68 69 70 71
      for (int64_t k = 0; k < K_; ++k) {
        out[k] = static_cast<T>(mismatch_value_);
      }
      out_wt[0] = static_cast<WT>(0.);
72 73 74 75
    }
  }
};

template <typename DeviceContext, typename T, typename WT>
D
dangqingqing 已提交
77
struct NegTargetAssignFunctor {
78
  void operator()(const platform::DeviceContext& ctx, const int* neg_indices,
79 80
                  const size_t* lod, const int N, const int M, const int K,
                  const int mismatch_value, T* out, WT* out_wt) const;
81 82
};

template <typename DeviceContext, typename T, typename WT>
84 85 86
class TargetAssignKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
87
    auto* x = ctx.Input<framework::LoDTensor>("X");
88 89
    auto* match_indices = ctx.Input<framework::Tensor>("MatchIndices");

90 91
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* out_wt = ctx.Output<framework::Tensor>("OutWeight");
92

93 94 95
    PADDLE_ENFORCE_EQ(x->lod().size(), 1UL,
                      platform::errors::InvalidArgument(
                          "TargetAssignOp input(X) needs 1 level of LoD"));
96
    int mismatch_value = ctx.Attr<int>("mismatch_value");
97

98
    const T* x_data = x->data<T>();
99 100
    const int* match_idx_data = match_indices->data<int>();

101 102
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    WT* out_wt_data = out_wt->mutable_data<WT>(ctx.GetPlace());
103

104 105 106 107
    int64_t n = match_indices->dims()[0];
    int64_t m = match_indices->dims()[1];
    int64_t p = x->dims()[1];
    int64_t k = x->dims()[2];
108

109
    auto x_lod = x->lod().back();
110
#if defined(PADDLE_WITH_CUDA)
111
    size_t* x_lod_data = x_lod.MutableData(ctx.GetPlace());
112 113 114
#else
    size_t* x_lod_data = x_lod.data();
#endif
115

116 117 118
    TargetAssignFunctor<T, WT> functor(x_data, match_idx_data, x_lod_data,
                                       mismatch_value, n, m, p, k, out_data,
                                       out_wt_data);
119 120

    auto& device_ctx = ctx.template device_context<DeviceContext>();
121
    platform::ForRange<DeviceContext> for_range(device_ctx, n * m);
122 123
    for_range(functor);

124 125
    auto* neg_indices = ctx.Input<framework::LoDTensor>("NegIndices");
    if (neg_indices) {
126 127 128 129
      PADDLE_ENFORCE_EQ(
          neg_indices->lod().size(), 1UL,
          platform::errors::InvalidArgument(
              "TargetAssignOp input(NegIndices) needs 1 level of LoD"));
130 131
      const int* neg_idx_data = neg_indices->data<int>();
      auto neg_lod = neg_indices->lod().back();
132
#if defined(PADDLE_WITH_CUDA)
133
      size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
134 135 136
#else
      size_t* neg_lod_data = neg_lod.data();
#endif
137 138 139 140
      NegTargetAssignFunctor<DeviceContext, T, WT> neg_trg_functor;
      neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, n, m, k,
                      mismatch_value, out_data, out_wt_data);
    }
141 142 143 144 145
  }
};

}  // namespace operators
}  // namespace paddle