cos_sim_op.h 5.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
X
Xinghai Sun 已提交
14 15

#pragma once
Y
Yi Wang 已提交
16 17 18 19
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cos_sim_functor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
X
Xinghai Sun 已提交
20 21 22 23 24 25

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

Q
QI JUN 已提交
26
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
27
class CosSimKernel : public framework::OpKernel<T> {
X
Xinghai Sun 已提交
28 29
 public:
  void Compute(const framework::ExecutionContext& context) const override {
30
    // get Tensor
L
luotao1 已提交
31
    auto* in_x = context.Input<framework::LoDTensor>("X");
32
    auto* in_y = context.Input<Tensor>("Y");
L
luotao1 已提交
33
    auto* out_z = context.Output<framework::LoDTensor>("Out");
34 35
    auto* out_x_norm = context.Output<Tensor>("XNorm");
    auto* out_y_norm = context.Output<Tensor>("YNorm");
X
Xinghai Sun 已提交
36

37 38
    int rows_x = in_x->dims()[0];
    int rows_y = in_y->dims()[0];
L
luotao1 已提交
39 40 41 42 43 44 45
    out_z->Resize({rows_x, 1});
    out_x_norm->Resize({rows_x, 1});
    out_y_norm->Resize({rows_y, 1});
    out_z->mutable_data<T>(context.GetPlace());
    out_x_norm->mutable_data<T>(context.GetPlace());
    out_y_norm->mutable_data<T>(context.GetPlace());
    out_z->set_lod(in_x->lod());
C
chengduoZH 已提交
46 47

    int cols = framework::product(in_x->dims()) / rows_x;
C
chengduoZH 已提交
48 49

    if (rows_x == rows_y) {
C
chengduoZH 已提交
50
      math::CosSimFunctor<T, true> functor(
C
chengduoZH 已提交
51 52
          in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
          out_y_norm->data<T>(), out_z->data<T>(), cols);
53 54 55
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(context.device_context()), rows_x);
      for_range(functor);
C
chengduoZH 已提交
56
    } else {
C
chengduoZH 已提交
57
      math::CosSimFunctor<T, false> functor(
C
chengduoZH 已提交
58 59
          in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
          out_y_norm->data<T>(), out_z->data<T>(), cols);
60 61 62
      platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(context.device_context()), rows_x);
      for_range(functor);
C
chengduoZH 已提交
63
    }
X
Xinghai Sun 已提交
64 65 66
  }
};

Q
QI JUN 已提交
67
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
68
class CosSimGradKernel : public framework::OpKernel<T> {
X
Xinghai Sun 已提交
69 70
 public:
  void Compute(const framework::ExecutionContext& context) const override {
71 72 73 74 75 76 77 78 79
    // get Tensor
    auto* in_x = context.Input<Tensor>("X");
    auto* in_y = context.Input<Tensor>("Y");
    auto* in_z = context.Input<Tensor>("Out");
    auto* in_x_norm = context.Input<Tensor>("XNorm");
    auto* in_y_norm = context.Input<Tensor>("YNorm");
    auto* out_grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
    auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
X
Xinghai Sun 已提交
80

81
    // compute gradident
82 83 84
    int rows_x = in_x->dims()[0];
    int rows_y = in_y->dims()[0];
    int cols = framework::product(in_x->dims()) / rows_x;
C
chengduoZH 已提交
85

C
chengduoZH 已提交
86 87
    if (rows_x == rows_y) {
      if (out_grad_x) {
L
luotao1 已提交
88
        out_grad_x->Resize(in_x->dims());
C
chengduoZH 已提交
89
        math::CosSimGradFunctor<T> functor(
C
chengduoZH 已提交
90 91 92
            in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
            in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
            out_grad_x->mutable_data<T>(context.GetPlace()), cols);
93 94 95 96
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(context.device_context()),
            rows_x);
        for_range(functor);
C
chengduoZH 已提交
97 98
      }
      if (out_grad_y) {
L
luotao1 已提交
99
        out_grad_y->Resize(in_y->dims());
C
chengduoZH 已提交
100
        math::CosSimGradFunctor<T> functor(
C
chengduoZH 已提交
101 102 103
            in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
            in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
            out_grad_y->mutable_data<T>(context.GetPlace()), cols);
104 105 106 107
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(context.device_context()),
            rows_x);
        for_range(functor);
C
chengduoZH 已提交
108 109 110
      }
    } else {
      if (out_grad_x) {
L
luotao1 已提交
111
        out_grad_x->Resize(in_x->dims());
C
chengduoZH 已提交
112
        math::CosSimDxFunctor<T> functor(
C
chengduoZH 已提交
113 114
            in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
            in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
C
refine  
chengduoZH 已提交
115
            out_grad_x->mutable_data<T>(context.GetPlace()), cols);
116 117 118 119
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(context.device_context()),
            rows_x);
        for_range(functor);
C
chengduoZH 已提交
120 121
      }
      if (out_grad_y) {
L
luotao1 已提交
122
        out_grad_y->Resize(in_y->dims());
C
refine  
chengduoZH 已提交
123 124 125 126 127
        out_grad_y->mutable_data<T>(context.GetPlace());
        math::SetConstant<DeviceContext, T> set_zero;
        auto& dev_ctx = context.template device_context<DeviceContext>();
        set_zero(dev_ctx, out_grad_y, static_cast<T>(0));

C
chengduoZH 已提交
128
        math::CosSimDyFunctor<DeviceContext, T> functor;
C
chengduoZH 已提交
129 130 131 132
        functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
                in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
                in_grad_z->data<T>(), static_cast<size_t>(rows_x),
                static_cast<size_t>(cols), out_grad_y->data<T>());
C
chengduoZH 已提交
133
      }
134
    }
X
Xinghai Sun 已提交
135 136 137 138 139
  }
};

}  // namespace operators
}  // namespace paddle