center_loss_op.cc 6.1 KB
Newer Older
H
HaoRen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/center_loss_op.h"
#include <memory>
#include <string>

namespace paddle {
namespace operators {
// Forward operator of center_loss. InferShape validates that every required
// input/output is wired up and derives the shapes of SampleCenterDiff,
// CentersOut and Loss from X and Centers.
class CenterLossOp : public framework::OperatorWithKernel {
 public:
  CenterLossOp(const std::string &type,
               const framework::VariableNameMap &inputs,
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of CenterLoss should not be null.");
    auto x_dims = ctx->GetInputDim("X");

    PADDLE_ENFORCE(ctx->HasInput("CenterUpdateRate"),
                   "Input(CenterUpdateRate) of CenterLoss should not be null.");

    PADDLE_ENFORCE(ctx->HasInput("Label"),
                   "Input(Label) of CenterLoss should not be null.");

    PADDLE_ENFORCE(ctx->HasInput("Centers"),
                   "Input(Centers) of CenterLoss should not be null.");

    PADDLE_ENFORCE(
        ctx->HasOutput("SampleCenterDiff"),
        "Output(SampleCenterDiff) of CenterLoss should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of CenterLoss should not be null.");

    // CentersOut shares data with Centers; the message now states the actual
    // failure (a missing output) like the other checks above.
    PADDLE_ENFORCE(ctx->HasOutput("CentersOut"),
                   "Output(CentersOut) of CenterLoss should not be null.");

    // SampleCenterDiff keeps the leading (batch) dim of X and folds all
    // remaining dims into one.
    // NOTE(review): this divides by x_dims[0]; a zero-sized leading dim would
    // divide by zero — confirm it cannot occur here.
    ctx->SetOutputDim("SampleCenterDiff",
                      {x_dims[0], product(x_dims) / x_dims[0]});
    // CentersOut has the same shape as Centers.
    ctx->SetOutputDim("CentersOut", ctx->GetInputDim("Centers"));
    // One loss value per sample.
    ctx->SetOutputDim("Loss", {x_dims[0], 1});
    ctx->ShareLoD("X", /*->*/ "Loss");
  }

 protected:
  // The kernel is selected by the data type of input X.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
  }
};

// Declares the inputs, outputs, attributes and user-facing documentation of
// the center_loss operator.
class CenterLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) Input tensor of center_loss operator: the deep "
             "features of the samples.");
    AddInput("Label", "(Tensor) The class label of each sample in X.");
    AddInput("Centers", "(Tensor) The center of each class.");
    AddInput("CenterUpdateRate",
             "(Tensor) The update rate used when updating Centers.");

    AddOutput("CentersOut",
              "(Tensor) The updated class centers; shares data with Centers.");
    AddOutput("SampleCenterDiff",
              "(Tensor) The difference between each sample in X and the "
              "center of its class.");
    AddOutput("Loss",
              "(Tensor) The center loss of each sample, with shape "
              "[batch_size, 1].");

    AddAttr<int>("cluster_num",
                 "The number of clusters (classes) of the center_loss "
                 "operator.");
    AddAttr<bool>("need_update", "Whether to update the center info.");
    AddComment(R"DOC(
**CenterLoss operator**
Implementation of the center loss function from the paper "A Discriminative
Feature Learning Approach for Deep Face Recognition". For each sample the
loss is:

loss = 1/2 * ||x - c_y||^2

where x (X) is the deep feature of the sample (the output of the last hidden
layer), y (Label) is its class label, and c_y (from Centers) is the center
of class y.
)DOC");
  }
};

// Gradient operator of center_loss. Its inputs are SampleCenterDiff (saved
// by the forward op), Loss@GRAD and X; its output is X@GRAD.
class CenterLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("SampleCenterDiff"),
                   "Input(SampleCenterDiff) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@GRAD) should not be null");
    // Only the dims of X are read below, but X itself must still be wired in.
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");

    auto x_grad_name = framework::GradVarName("X");
    PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
                   "Output(X@GRAD) should not be null");

    // X@GRAD is guaranteed present by the check above and has X's shape.
    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
  }

 protected:
  // The kernel is selected by the data type of SampleCenterDiff.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"),
        ctx.device_context());
  }
};

H
hong 已提交
126 127
template <typename T>
class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
H
HaoRen 已提交
128
 public:
H
hong 已提交
129
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
H
HaoRen 已提交
130 131

 protected:
132
  void Apply(GradOpPtr<T> retv) const override {
H
HaoRen 已提交
133
    retv->SetType("center_loss_grad");
H
hong 已提交
134 135 136 137
    retv->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    retv->SetInput("SampleCenterDiff", this->Output("SampleCenterDiff"));
    retv->SetInput("X", this->Input("X"));
    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
H
HaoRen 已提交
138

H
hong 已提交
139
    retv->SetAttrMap(this->Attrs());
H
HaoRen 已提交
140 141
  }
};
142

143
DECLARE_NO_NEED_BUFFER_VARS_INFERER(CenterLossGradNoNeedBufVarsInferer, "X");
144

H
HaoRen 已提交
145 146 147 148 149 150 151
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;

// Forward op, with grad makers for both static-graph (OpDesc) and
// imperative (OpBase) modes.
REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
                  ops::CenterLossOpGradMaker<paddle::framework::OpDesc>,
                  ops::CenterLossOpGradMaker<paddle::imperative::OpBase>);

// Grad op; X is declared as not needing its data buffer.
REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp,
                  ops::CenterLossGradNoNeedBufVarsInferer);

// CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>,
                       ops::CenterLossKernel<CPUCtx, double>);

REGISTER_OP_CPU_KERNEL(center_loss_grad,
                       ops::CenterLossGradKernel<CPUCtx, float>,
                       ops::CenterLossGradKernel<CPUCtx, double>);