/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/nce_op.h"

#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;
// Forward operator for noise-contrastive estimation (NCE).
// Validates the Input/Label/Weight(/Bias) shapes against each other and the
// num_total_classes / num_neg_samples attributes, then infers the dims of
// Cost and (in training mode) SampleLogits / SampleLabels.
class NCEOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "nce");
    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "nce");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce");

    OP_INOUT_CHECK(ctx->HasOutput("Cost"), "Output", "Cost", "nce");
    // The intermediate sample tensors are only produced during training;
    // inference skips them entirely (see the !is_test branch below).
    bool is_test = ctx->Attrs().Get<bool>("is_test");
    if (!is_test) {
      OP_INOUT_CHECK(
          ctx->HasOutput("SampleLogits"), "Output", "SampleLogits", "nce");
      OP_INOUT_CHECK(
          ctx->HasOutput("SampleLabels"), "Output", "SampleLabels", "nce");
    }

    auto x_dims = ctx->GetInputDim("Input");
    auto label_dims = ctx->GetInputDim("Label");
    // At compile time a first dim may still be unknown (<= 0), so the
    // batch-size equality is only enforced when both dims are concrete.
    if (ctx->IsRuntime() || (x_dims[0] > 0 && label_dims[0] > 0)) {
      PADDLE_ENFORCE_EQ(
          x_dims[0],
          label_dims[0],
          platform::errors::InvalidArgument(
              "The first dimension of Input(Input) and Input(Label) should be "
              "equal in runtime. But received: Input(Input)'s shape = [%s] "
              "with 1st dim =  %d, Input(Label)'s shape = [%s] with 1st dim = "
              "%d.",
              x_dims,
              x_dims[0],
              label_dims,
              label_dims[0]));
    }
    // A 1-D Label means one true class per sample.
    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
    if (ctx->HasInput("Bias")) {
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Weight")[0],
          ctx->GetInputDim("Bias")[0],
          platform::errors::InvalidArgument(
              "The first dimension of Input(Weight) and Input(Bias) "
              "should be equal. But received: Input(Weight)'s shape = [%s] "
              "with 1st dim = %d, and Input(Bias)'s shape = [%s] with 1st dim "
              "= %d.",
              ctx->GetInputDim("Weight"),
              ctx->GetInputDim("Weight")[0],
              ctx->GetInputDim("Bias"),
              ctx->GetInputDim("Bias")[0]));
    }
    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
    std::vector<int> custom_neg_classes =
        ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
    PADDLE_ENFORCE_EQ(
        num_total_classes,
        ctx->GetInputDim("Weight")[0],
        platform::errors::InvalidArgument(
            "The number of total classes should be equal to the first "
            "dimension of Input(Weight). But received: Attr(num_total_classes) "
            "= %d, Input(Weight)'s shape = [%s] with 1st dim = %d.",
            num_total_classes,
            ctx->GetInputDim("Weight"),
            ctx->GetInputDim("Weight")[0]));
    // When a fixed negative-class list is supplied (test-only attribute),
    // it must provide exactly one class per negative sample.
    if (custom_neg_classes.size() > 0) {
      PADDLE_ENFORCE_EQ(
          custom_neg_classes.size(),
          static_cast<size_t>(num_neg_samples),
          platform::errors::InvalidArgument(
              "The size of Attr(custom_neg_classes) should be equal "
              "to the number of negative samples. But received: "
              "custom_neg_classes.size() = %d, num_neg_samples = %d.",
              custom_neg_classes.size(),
              num_neg_samples));
    }
    // set dims of output(Out)
    std::vector<int64_t> out_dims;
    out_dims.push_back(x_dims[0]);
    out_dims.push_back(1);
    ctx->SetOutputDim("Cost", phi::make_ddim(out_dims));

    if (!is_test) {
      // set dims of output(SampleOut)
      // Sample tensors hold one column per (negative + true) class; if the
      // number of true classes is unknown (-1), the dim stays unknown too.
      std::vector<int64_t> sample_out_dims;
      sample_out_dims.push_back(x_dims[0]);
      sample_out_dims.push_back(
          (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));
      ctx->SetOutputDim("SampleLogits", phi::make_ddim(sample_out_dims));
      ctx->SetOutputDim("SampleLabels", phi::make_ddim(sample_out_dims));
    }
  }

 protected:
  // Kernel type is keyed on Input's dtype; NCE only has a CPU kernel.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        platform::CPUPlace());
  }
};

// Proto maker for the nce op: declares inputs/outputs/attributes and their
// user-facing documentation.
// Fixes over the previous version (documentation strings only; no keys,
// defaults, or behavior changed):
//  - CustomDistAlias / CustomDistAliasProbs had descriptions copy-pasted from
//    CustomDistProbs; they now describe the alias-method tables they hold.
//  - Typos corrected: "CostumDist" -> "CustomDist", "wiil" -> "will",
//    "unitest" -> "unittest", "for every samples" -> "for every sample",
//    stray ")" removed from the Label description.
class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
    AddInput(
        "Label",
        "(Tensor) A tensor of shape [batch_size, num_true_class]. "
        "'num_true_class' is the number of target classes in each sample."
        "The number of target classes per sample should be same. "
        "If you have a variable number of target classes, "
        "you can pad them out to a constant number by either repeating them"
        " or by padding with an otherwise unused class.");
    AddInput("Weight",
             "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
             "total number of class.");
    AddInput(
        "Bias",
        "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
        "number of class. It is a dispensable input.")
        .AsDispensable();
    AddInput("SampleWeight",
             "(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
             "each sample. And it is a dispensable input. The default value of "
             "sample is 1.")
        .AsDispensable();

    // The three CustomDist* inputs together form Walker's alias-method tables
    // for sampling negative classes from an arbitrary distribution.
    AddInput(
        "CustomDistProbs",
        "(Tensor) It is used in 'CustomDist' sampler. "
        "It is a tensor with shape [num_total_classes]."
        "The i-th element is the probability of the i-th class being sampled.")
        .AsDispensable();
    AddInput(
        "CustomDistAlias",
        "(Tensor) It is used in 'CustomDist' sampler. "
        "It is a tensor with shape [num_total_classes]."
        "The i-th element is the alias class of the i-th class in the alias "
        "method's alias table.")
        .AsDispensable();
    AddInput(
        "CustomDistAliasProbs",
        "(Tensor) It is used in 'CustomDist' sampler. "
        "It is a tensor with shape [num_total_classes]."
        "The i-th element is the probability of keeping the i-th class rather "
        "than taking its alias when the i-th bucket is drawn.")
        .AsDispensable();

    AddOutput("Cost",
              "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
    AddOutput("SampleLogits",
              "An intermediate tensor of shape[batch_size, num_neg_samples + "
              "num_pos_samples]."
              "This tensor is output of forward kernel and used in backward "
              "kernel to compute grads."
              "Given X is  the dot product of input tensor and sampled labels' "
              "weights."
              "Then 'SampleLogits' is sigmoid(X).")
        .AsIntermediate()
        .AsExtra();
    AddOutput("SampleLabels",
              "An intermediate tensor of shape[batch_size, num_neg_samples + "
              "num_pos_samples]."
              "This tensor is output of forward kernel and used in backward "
              "kernel to compute grads."
              "")
        .AsIntermediate()
        .AsExtra();

    AddAttr<int>("num_total_classes",
                 "Total number of classes in all samples.");
    AddAttr<int>("num_neg_samples",
                 "The number of negative classes. The default value is 10.")
        .SetDefault(10);
    AddAttr<int>("sampler",
                 "(int) Which sampler to be used to sample negative class."
                 "0: Uniform; 1: LogUniform; 2: CustomDist.")
        .SetDefault(0);
    AddAttr<int>("seed",
                 "(int) The seed used in sampler. If it is 0, "
                 "the sampler will generate a seed randomly.")
        .SetDefault(0);
    AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
        .SetDefault(false);

    // for parameter prefetch
    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
        .SetDefault(0)
        .AsExtra();
    AddAttr<std::vector<int64_t>>("height_sections",
                                  "Height for each output SelectedRows.")
        .SetDefault(std::vector<int64_t>({}))
        .AsExtra();
    AddAttr<std::vector<std::string>>(
        "epmap",
        "(string vector, default 127.0.0.1:6164)"
        "Server endpoints in the order of input variables for mapping")
        .SetDefault({})
        .AsExtra();
    AddAttr<std::vector<std::string>>(
        "table_names",
        "(string vector, the split table names that will be fetched from "
        "parameter server)"
        "in the order of input variables for mapping")
        .SetDefault({})
        .AsExtra();

    AddAttr<std::vector<int>>("custom_neg_classes",
                              "This attribute only be used in unittest. "
                              "Classes in this list will be used as negative "
                              "classes for every sample. Under normal "
                              "conditions, "
                              "user should avoid setting this attribute.")
        .SetDefault({})
        .AsExtra();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference "
                  "only, false for training.")
        .SetDefault(false);
    AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss. See
`Noise-contrastive estimation: A new estimation principle for unnormalized
statistical models
 <http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf>`_.
By default this operator uses a uniform distribution for sampling.
)DOC");
  }
};

// Grad-op maker: builds the nce_grad op description from the forward op.
// The backward kernel needs nearly all forward inputs plus the intermediate
// SampleLogits/SampleLabels outputs; Bias is forwarded as an input even
// though only its shape is used (it is declared no-need-buffer below).
template <typename T>
class NCEGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Label", this->Input("Label"));
    op->SetInput("Bias", this->Input("Bias"));
    op->SetInput("Weight", this->Input("Weight"));
    // Forward intermediates, consumed by the backward kernel.
    op->SetInput("SampleLogits", this->Output("SampleLogits"));
    op->SetInput("SampleLabels", this->Output("SampleLabels"));
    op->SetInput("SampleWeight", this->Input("SampleWeight"));
    op->SetInput("CustomDistProbs", this->Input("CustomDistProbs"));
    op->SetInput("CustomDistAlias", this->Input("CustomDistAlias"));
    op->SetInput("CustomDistAliasProbs", this->Input("CustomDistAliasProbs"));
    op->SetInput(framework::GradVarName("Cost"), this->OutputGrad("Cost"));
    // Gradients are produced for Input, Bias and Weight (not for Label).
    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
    op->SetAttrMap(this->Attrs());
  }
};

W
wanghaoshuang 已提交
276 277 278 279
class NCEOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

280
  void InferShape(framework::InferShapeContext *ctx) const override {
281 282
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce_grad");
283 284 285 286 287 288 289
    OP_INOUT_CHECK(
        ctx->HasInput("SampleLogits"), "Input", "SampleLogits", "nce_grad");
    OP_INOUT_CHECK(
        ctx->HasInput("SampleLabels"), "Input", "SampleLabels", "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Cost")),
                   "Input",
                   framework::GradVarName("Cost"),
290
                   "nce_grad");
W
wanghaoshuang 已提交
291

W
wanghaoshuang 已提交
292 293
    auto x_dims = ctx->GetInputDim("Input");
    auto x_grad_name = framework::GradVarName("Input");
W
wanghaoshuang 已提交
294 295 296 297
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }

W
wanghaoshuang 已提交
298 299
    auto w_dims = ctx->GetInputDim("Weight");
    auto w_grad_name = framework::GradVarName("Weight");
W
wanghaoshuang 已提交
300 301 302 303
    if (ctx->HasOutput(w_grad_name)) {
      ctx->SetOutputDim(w_grad_name, w_dims);
    }

W
wanghaoshuang 已提交
304
    auto bias_grad_name = framework::GradVarName("Bias");
W
wanghaoshuang 已提交
305
    if (ctx->HasOutput(bias_grad_name)) {
W
wanghaoshuang 已提交
306
      auto bias_dims = ctx->GetInputDim("Bias");
W
wanghaoshuang 已提交
307 308 309
      ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
  }
W
wanghaoshuang 已提交
310 311

 protected:
312
  framework::OpKernelType GetExpectedKernelType(
313
      const framework::ExecutionContext &ctx) const override {
314 315 316
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        platform::CPUPlace());
W
wanghaoshuang 已提交
317
  }
W
wanghaoshuang 已提交
318 319
};

// Decides the variable type of Weight@GRAD: SelectedRows for sparse updates
// (is_sparse == true), dense LoDTensor otherwise. The gradient's dtype always
// follows the forward Input.
// Fix: both VLOG messages previously contained a dangling `<< " and " <<`
// with nothing between the two pieces, producing logs like
// "nce_op_grad op Weight@GRAD and  is set to ..."; the stray " and " is
// removed.
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto weight_grad = framework::GradVarName("Weight");

    auto attr = ctx->GetAttr("is_sparse");
    bool is_sparse = BOOST_GET(bool, attr);
    if (is_sparse) {
      VLOG(3) << "nce_op_grad op " << weight_grad
              << " is set to SelectedRows";
      ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
    } else {
      VLOG(3) << "nce_op_grad op " << weight_grad << " is set to LoDTensor";
      ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
    }
    ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
  }
};

// Bias is only forwarded to nce_grad for its shape, so its buffer need not be
// kept alive for the backward pass.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInferer, "Bias");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the forward op with its proto maker and grad-op makers (static
// graph OpDesc and imperative OpBase variants).
REGISTER_OPERATOR(nce,
                  ops::NCEOp,
                  ops::NCEOpMaker,
                  ops::NCEGradOpMaker<paddle::framework::OpDesc>,
                  ops::NCEGradOpMaker<paddle::imperative::OpBase>);
// Register the backward op together with its var-type inference (sparse vs
// dense Weight@GRAD) and the no-need-buffer declaration for Bias.
REGISTER_OPERATOR(nce_grad,
                  ops::NCEOpGrad,
                  ops::NCEOpGradVarTypeInference,
                  ops::NCEGradOpNoNeedBufferVarInferer);
// CPU-only kernels, float and double.
REGISTER_OP_CPU_KERNEL(nce,
                       ops::NCEKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, double>);