/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/nce_op.h"

#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;
class NCEOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

30
  void InferShape(framework::InferShapeContext *ctx) const override {
31 32 33 34 35
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "nce");
    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "nce");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce");

    OP_INOUT_CHECK(ctx->HasOutput("Cost"), "Output", "Cost", "nce");
P
pangyoki 已提交
36 37
    bool is_test = ctx->Attrs().Get<bool>("is_test");
    if (!is_test) {
38 39 40 41
      OP_INOUT_CHECK(
          ctx->HasOutput("SampleLogits"), "Output", "SampleLogits", "nce");
      OP_INOUT_CHECK(
          ctx->HasOutput("SampleLabels"), "Output", "SampleLabels", "nce");
P
pangyoki 已提交
42
    }
W
wanghaoshuang 已提交
43

W
wanghaoshuang 已提交
44
    auto x_dims = ctx->GetInputDim("Input");
W
wanghaoshuang 已提交
45
    auto label_dims = ctx->GetInputDim("Label");
46
    if (ctx->IsRuntime() || (x_dims[0] > 0 && label_dims[0] > 0)) {
47
      PADDLE_ENFORCE_EQ(
48 49
          x_dims[0],
          label_dims[0],
50 51 52 53 54
          platform::errors::InvalidArgument(
              "The first dimension of Input(Input) and Input(Label) should be "
              "equal in runtime. But received: Input(Input)'s shape = [%s] "
              "with 1st dim =  %d, Input(Label)'s shape = [%s] with 1st dim = "
              "%d.",
55 56 57 58
              x_dims,
              x_dims[0],
              label_dims,
              label_dims[0]));
59
    }
W
wanghaoshuang 已提交
60 61
    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
    if (ctx->HasInput("Bias")) {
62
      PADDLE_ENFORCE_EQ(
63 64
          ctx->GetInputDim("Weight")[0],
          ctx->GetInputDim("Bias")[0],
65 66 67 68 69
          platform::errors::InvalidArgument(
              "The first dimension of Input(Weight) and Input(Bias) "
              "should be equal. But received: Input(Weight)'s shape = [%s] "
              "with 1st dim = %d, and Input(Bias)'s shape = [%s] with 1st dim "
              "= %d.",
70 71 72 73
              ctx->GetInputDim("Weight"),
              ctx->GetInputDim("Weight")[0],
              ctx->GetInputDim("Bias"),
              ctx->GetInputDim("Bias")[0]));
W
wanghaoshuang 已提交
74
    }
W
wanghaoshuang 已提交
75 76
    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
W
wanghaoshuang 已提交
77 78
    std::vector<int> custom_neg_classes =
        ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
79
    PADDLE_ENFORCE_EQ(
80 81
        num_total_classes,
        ctx->GetInputDim("Weight")[0],
82 83 84 85
        platform::errors::InvalidArgument(
            "The number of total classes should be equal to the first "
            "dimension of Input(Weight). But received: Attr(num_total_classes) "
            "= %d, Input(Weight)'s shape = [%s] with 1st dim = %d.",
86 87
            num_total_classes,
            ctx->GetInputDim("Weight"),
88
            ctx->GetInputDim("Weight")[0]));
W
wanghaoshuang 已提交
89
    if (custom_neg_classes.size() > 0) {
90
      PADDLE_ENFORCE_EQ(
91 92
          custom_neg_classes.size(),
          static_cast<size_t>(num_neg_samples),
93 94 95 96
          platform::errors::InvalidArgument(
              "The size of Attr(custom_neg_classes) should be equal "
              "to the number of negative samples. But received: "
              "custom_neg_classes.size() = %d, num_neg_samples = %d.",
97 98
              custom_neg_classes.size(),
              num_neg_samples));
W
wanghaoshuang 已提交
99
    }
W
wanghaoshuang 已提交
100
    // set dims of output(Out)
W
wanghaoshuang 已提交
101
    std::vector<int64_t> out_dims;
W
wanghaoshuang 已提交
102
    out_dims.push_back(x_dims[0]);
W
wanghaoshuang 已提交
103
    out_dims.push_back(1);
104
    ctx->SetOutputDim("Cost", phi::make_ddim(out_dims));
W
wanghaoshuang 已提交
105

P
pangyoki 已提交
106 107 108 109 110 111
    if (!is_test) {
      // set dims of output(SampleOut)
      std::vector<int64_t> sample_out_dims;
      sample_out_dims.push_back(x_dims[0]);
      sample_out_dims.push_back(
          (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));
112 113
      ctx->SetOutputDim("SampleLogits", phi::make_ddim(sample_out_dims));
      ctx->SetOutputDim("SampleLabels", phi::make_ddim(sample_out_dims));
P
pangyoki 已提交
114
    }
W
wanghaoshuang 已提交
115
  }
W
wanghaoshuang 已提交
116 117

 protected:
118
  framework::OpKernelType GetExpectedKernelType(
119
      const framework::ExecutionContext &ctx) const override {
120 121 122
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        platform::CPUPlace());
W
wanghaoshuang 已提交
123
  }
W
wanghaoshuang 已提交
124 125 126 127
};

// Declares the proto of the nce op: its inputs, outputs and attributes.
// NOTE: the order of the Add* calls and every description string are part
// of the op definition and are preserved verbatim.
class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // --- Required inputs ---------------------------------------------------
    AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
    AddInput(
        "Label",
        "(Tensor) A tensor of shape [batch_size, num_true_class]. "
        "'num_true_class' is the number of target classes in each sample."
        "The number of target classes per sample should be same. "
        "If you have a variable number of target classes, "
        "you can pad them out to a constant number by either repeating them"
        " or by padding with an otherwise unused class.)");
    AddInput("Weight",
             "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
             "total number of class.");
    // --- Optional (dispensable) inputs -------------------------------------
    AddInput(
        "Bias",
        "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
        "number of class. It is a dispensable input.")
        .AsDispensable();
    AddInput("SampleWeight",
             "(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
             "each sample. And it is a dispensable input. The default value of "
             "sample is 1.")
        .AsDispensable();

    // Inputs consumed only by the 'CostumDist' (custom distribution) sampler.
    AddInput(
        "CustomDistProbs",
        "(Tensor) It is used in 'CostumDist' sampler. "
        "It is a tensor with shape [num_total_classes]."
        "The i-th element is the probability of the i-th class being sampled.")
        .AsDispensable();
    AddInput(
        "CustomDistAlias",
        "(Tensor) It is used in 'CostumDist' sampler. "
        "It is a tensor with shape [num_total_classes]."
        "The i-th element is the probability of the i-th class being sampled.")
        .AsDispensable();
    AddInput(
        "CustomDistAliasProbs",
        "(Tensor) It is used in 'CostumDist' sampler. "
        "It is a tensor with shape [num_total_classes]."
        "The i-th element is the probability of the i-th class being sampled.")
        .AsDispensable();

    // --- Outputs -----------------------------------------------------------
    AddOutput("Cost",
              "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
    AddOutput("SampleLogits",
              "An intermediate tensor of shape[batch_size, num_neg_samples + "
              "num_pos_samples]."
              "This tensor is output of forward kernel and used in backward "
              "kernel to compute grads."
              "Given X is  the dot product of input tensor and sampled labels' "
              "weights."
              "Then 'SampleLogits' is sigmoid(X).")
        .AsIntermediate()
        .AsExtra();
    AddOutput("SampleLabels",
              "An intermediate tensor of shape[batch_size, num_neg_samples + "
              "num_pos_samples]."
              "This tensor is output of forward kernel and used in backward "
              "kernel to compute grads."
              "")
        .AsIntermediate()
        .AsExtra();

    // --- Attributes --------------------------------------------------------
    AddAttr<int>("num_total_classes",
                 "Total number of classes in all samples.");
    AddAttr<int>("num_neg_samples",
                 "The number of negative classes. The default value is 10.")
        .SetDefault(10);
    AddAttr<int>("sampler",
                 "(int) Which sampler to be used to sample negative class."
                 "0: Uniform; 1: LogUniform; 2: CostumDist.")
        .SetDefault(0);
    AddAttr<int>("seed",
                 "(int) The seed used in sampler. If it is 0, "
                 "the sampler will generate a seed randomly.")
        .SetDefault(0);
    AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
        .SetDefault(false);

    // for parameter prefetch
    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference "
                  "only, false for training.")
        .SetDefault(false);
    AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss. See
`Noise-contrastive estimation: A new estimation principle for unnormalized
statistical models
 <http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf>`_.
By default this operator uses a uniform distribution for sampling.
)DOC");
  }
};

224 225 226 227
// Builds the nce_grad op description from the forward op: forwards every
// forward input plus the sampling intermediates and the cost gradient,
// and declares the gradients of Input, Bias and Weight as outputs.
template <typename T>
class NCEGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad) const override {
    grad->SetType(this->ForwardOpType() + "_grad");
    // Forward inputs needed to recompute gradients.
    grad->SetInput("Input", this->Input("Input"));
    grad->SetInput("Label", this->Input("Label"));
    grad->SetInput("Bias", this->Input("Bias"));
    grad->SetInput("Weight", this->Input("Weight"));
    grad->SetInput("SampleWeight", this->Input("SampleWeight"));
    grad->SetInput("CustomDistProbs", this->Input("CustomDistProbs"));
    grad->SetInput("CustomDistAlias", this->Input("CustomDistAlias"));
    grad->SetInput("CustomDistAliasProbs", this->Input("CustomDistAliasProbs"));
    // Intermediates saved by the forward pass.
    grad->SetInput("SampleLogits", this->Output("SampleLogits"));
    grad->SetInput("SampleLabels", this->Output("SampleLabels"));
    // Incoming gradient of the cost.
    grad->SetInput(framework::GradVarName("Cost"), this->OutputGrad("Cost"));
    // Gradients produced by nce_grad.
    grad->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    grad->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    grad->SetOutput(framework::GradVarName("Weight"),
                    this->InputGrad("Weight"));
    grad->SetAttrMap(this->Attrs());
  }
};

W
wanghaoshuang 已提交
248 249 250 251
class NCEOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

252
  void InferShape(framework::InferShapeContext *ctx) const override {
253 254
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce_grad");
255 256 257 258 259 260 261
    OP_INOUT_CHECK(
        ctx->HasInput("SampleLogits"), "Input", "SampleLogits", "nce_grad");
    OP_INOUT_CHECK(
        ctx->HasInput("SampleLabels"), "Input", "SampleLabels", "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Cost")),
                   "Input",
                   framework::GradVarName("Cost"),
262
                   "nce_grad");
W
wanghaoshuang 已提交
263

W
wanghaoshuang 已提交
264 265
    auto x_dims = ctx->GetInputDim("Input");
    auto x_grad_name = framework::GradVarName("Input");
W
wanghaoshuang 已提交
266 267 268 269
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }

W
wanghaoshuang 已提交
270 271
    auto w_dims = ctx->GetInputDim("Weight");
    auto w_grad_name = framework::GradVarName("Weight");
W
wanghaoshuang 已提交
272 273 274 275
    if (ctx->HasOutput(w_grad_name)) {
      ctx->SetOutputDim(w_grad_name, w_dims);
    }

W
wanghaoshuang 已提交
276
    auto bias_grad_name = framework::GradVarName("Bias");
W
wanghaoshuang 已提交
277
    if (ctx->HasOutput(bias_grad_name)) {
W
wanghaoshuang 已提交
278
      auto bias_dims = ctx->GetInputDim("Bias");
W
wanghaoshuang 已提交
279 280 281
      ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
  }
W
wanghaoshuang 已提交
282 283

 protected:
284
  framework::OpKernelType GetExpectedKernelType(
285
      const framework::ExecutionContext &ctx) const override {
286 287 288
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        platform::CPUPlace());
W
wanghaoshuang 已提交
289
  }
W
wanghaoshuang 已提交
290 291
};

292 293
// Chooses the variable type of Weight@GRAD: SELECTED_ROWS when the
// "is_sparse" attribute is set (sparse row-wise update), LOD_TENSOR
// otherwise. The gradient's data type always follows the forward Input.
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto weight_grad = framework::GradVarName("Weight");

    auto attr = ctx->GetAttr("is_sparse");
    bool is_sparse = PADDLE_GET(bool, attr);
    if (is_sparse) {
      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to SelectedRows";
      ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
    } else {
      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to LoDTensor";
      ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
    }
    ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
  }
};

312
// Bias is only read for its shape in the backward pass, so its buffer need
// not be kept alive for nce_grad.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInferer, "Bias");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the forward op together with its grad-op makers (static graph
// and imperative mode).
REGISTER_OPERATOR(nce,
                  ops::NCEOp,
                  ops::NCEOpMaker,
                  ops::NCEGradOpMaker<paddle::framework::OpDesc>,
                  ops::NCEGradOpMaker<paddle::imperative::OpBase>);

// Register the backward op with its var-type inference and the
// no-need-buffer hint for Bias.
REGISTER_OPERATOR(nce_grad,
                  ops::NCEOpGrad,
                  ops::NCEOpGradVarTypeInference,
                  ops::NCEGradOpNoNeedBufferVarInferer);

// CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(nce,
                       ops::NCEKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, double>);