/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/nce_op.h"

#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class NCEOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

30
  void InferShape(framework::InferShapeContext *ctx) const override {
31 32 33 34 35
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "nce");
    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "nce");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce");

    OP_INOUT_CHECK(ctx->HasOutput("Cost"), "Output", "Cost", "nce");
P
pangyoki 已提交
36 37 38 39 40 41 42
    bool is_test = ctx->Attrs().Get<bool>("is_test");
    if (!is_test) {
      OP_INOUT_CHECK(ctx->HasOutput("SampleLogits"), "Output", "SampleLogits",
                     "nce");
      OP_INOUT_CHECK(ctx->HasOutput("SampleLabels"), "Output", "SampleLabels",
                     "nce");
    }
W
wanghaoshuang 已提交
43

W
wanghaoshuang 已提交
44
    auto x_dims = ctx->GetInputDim("Input");
W
wanghaoshuang 已提交
45
    auto label_dims = ctx->GetInputDim("Label");
46
    if (ctx->IsRuntime() || (x_dims[0] > 0 && label_dims[0] > 0)) {
47 48
      PADDLE_ENFORCE_EQ(
          x_dims[0], label_dims[0],
49 50 51 52 53 54
          platform::errors::InvalidArgument(
              "The first dimension of Input(Input) and Input(Label) should be "
              "equal in runtime. But received: Input(Input)'s shape = [%s] "
              "with 1st dim =  %d, Input(Label)'s shape = [%s] with 1st dim = "
              "%d.",
              x_dims, x_dims[0], label_dims, label_dims[0]));
55
    }
W
wanghaoshuang 已提交
56 57
    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
    if (ctx->HasInput("Bias")) {
58 59
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Weight")[0], ctx->GetInputDim("Bias")[0],
60 61 62 63 64 65 66
          platform::errors::InvalidArgument(
              "The first dimension of Input(Weight) and Input(Bias) "
              "should be equal. But received: Input(Weight)'s shape = [%s] "
              "with 1st dim = %d, and Input(Bias)'s shape = [%s] with 1st dim "
              "= %d.",
              ctx->GetInputDim("Weight"), ctx->GetInputDim("Weight")[0],
              ctx->GetInputDim("Bias"), ctx->GetInputDim("Bias")[0]));
W
wanghaoshuang 已提交
67
    }
W
wanghaoshuang 已提交
68 69
    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
W
wanghaoshuang 已提交
70 71
    std::vector<int> custom_neg_classes =
        ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
72 73
    PADDLE_ENFORCE_EQ(
        num_total_classes, ctx->GetInputDim("Weight")[0],
74 75 76 77 78 79
        platform::errors::InvalidArgument(
            "The number of total classes should be equal to the first "
            "dimension of Input(Weight). But received: Attr(num_total_classes) "
            "= %d, Input(Weight)'s shape = [%s] with 1st dim = %d.",
            num_total_classes, ctx->GetInputDim("Weight"),
            ctx->GetInputDim("Weight")[0]));
W
wanghaoshuang 已提交
80
    if (custom_neg_classes.size() > 0) {
81 82
      PADDLE_ENFORCE_EQ(
          custom_neg_classes.size(), static_cast<size_t>(num_neg_samples),
83 84 85 86 87
          platform::errors::InvalidArgument(
              "The size of Attr(custom_neg_classes) should be equal "
              "to the number of negative samples. But received: "
              "custom_neg_classes.size() = %d, num_neg_samples = %d.",
              custom_neg_classes.size(), num_neg_samples));
W
wanghaoshuang 已提交
88
    }
W
wanghaoshuang 已提交
89
    // set dims of output(Out)
W
wanghaoshuang 已提交
90
    std::vector<int64_t> out_dims;
W
wanghaoshuang 已提交
91
    out_dims.push_back(x_dims[0]);
W
wanghaoshuang 已提交
92
    out_dims.push_back(1);
93
    ctx->SetOutputDim("Cost", pten::make_ddim(out_dims));
W
wanghaoshuang 已提交
94

P
pangyoki 已提交
95 96 97 98 99 100
    if (!is_test) {
      // set dims of output(SampleOut)
      std::vector<int64_t> sample_out_dims;
      sample_out_dims.push_back(x_dims[0]);
      sample_out_dims.push_back(
          (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes));
101 102
      ctx->SetOutputDim("SampleLogits", pten::make_ddim(sample_out_dims));
      ctx->SetOutputDim("SampleLabels", pten::make_ddim(sample_out_dims));
P
pangyoki 已提交
103
    }
W
wanghaoshuang 已提交
104
  }
W
wanghaoshuang 已提交
105 106

 protected:
107
  framework::OpKernelType GetExpectedKernelType(
108
      const framework::ExecutionContext &ctx) const override {
109 110 111
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        platform::CPUPlace());
W
wanghaoshuang 已提交
112
  }
W
wanghaoshuang 已提交
113 114 115 116
};

// Declares the inputs, outputs, attributes and user-facing documentation of
// the "nce" operator.
class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
    AddInput(
        "Label",
        "(Tensor) A tensor of shape [batch_size, num_true_class]. "
        "'num_true_class' is the number of target classes in each sample. "
        "The number of target classes per sample should be same. "
        "If you have a variable number of target classes, "
        "you can pad them out to a constant number by either repeating them "
        "or by padding with an otherwise unused class.");
    AddInput("Weight",
             "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
             "total number of class.");
    AddInput(
        "Bias",
        "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
        "number of class. It is a dispensable input.")
        .AsDispensable();
    AddInput("SampleWeight",
             "(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
             "each sample. And it is a dispensable input. The default weight "
             "of each sample is 1.")
        .AsDispensable();

    // The three CustomDist* inputs together describe the sampling table used
    // by the custom-distribution sampler (sampler == 2).
    AddInput(
        "CustomDistProbs",
        "(Tensor) It is used in 'CustomDist' sampler. "
        "It is a tensor with shape [num_total_classes]. "
        "The i-th element is the probability of the i-th class being sampled.")
        .AsDispensable();
    AddInput(
        "CustomDistAlias",
        "(Tensor) It is used in 'CustomDist' sampler. "
        "It is a tensor with shape [num_total_classes]. "
        "The i-th element is the alias class paired with the i-th class in "
        "the alias sampling table.")
        .AsDispensable();
    AddInput(
        "CustomDistAliasProbs",
        "(Tensor) It is used in 'CustomDist' sampler. "
        "It is a tensor with shape [num_total_classes]. "
        "The i-th element is the probability of keeping the i-th class "
        "instead of its alias in the alias sampling table.")
        .AsDispensable();

    AddOutput("Cost",
              "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
    AddOutput("SampleLogits",
              "An intermediate tensor of shape [batch_size, num_neg_samples + "
              "num_pos_samples]. "
              "This tensor is output of forward kernel and used in backward "
              "kernel to compute grads. "
              "Given X is the dot product of input tensor and sampled labels' "
              "weights, "
              "then 'SampleLogits' is sigmoid(X).")
        .AsIntermediate()
        .AsExtra();
    AddOutput("SampleLabels",
              "An intermediate tensor of shape [batch_size, num_neg_samples + "
              "num_pos_samples]. "
              "This tensor is output of forward kernel and used in backward "
              "kernel to compute grads.")
        .AsIntermediate()
        .AsExtra();

    AddAttr<int>("num_total_classes",
                 "Total number of classes in all samples.");
    AddAttr<int>("num_neg_samples",
                 "The number of negative classes. The default value is 10.")
        .SetDefault(10);
    AddAttr<int>("sampler",
                 "(int) Which sampler to be used to sample negative class. "
                 "0: Uniform; 1: LogUniform; 2: CustomDist.")
        .SetDefault(0);
    AddAttr<int>("seed",
                 "(int) The seed used in sampler. If it is 0, "
                 "the sampler will generate a seed randomly.")
        .SetDefault(0);
    AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
        .SetDefault(false);

    // for parameter prefetch
    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
        .SetDefault(0)
        .AsExtra();
    AddAttr<std::vector<int64_t>>("height_sections",
                                  "Height for each output SelectedRows.")
        .SetDefault(std::vector<int64_t>({}))
        .AsExtra();
    AddAttr<std::vector<std::string>>(
        "epmap",
        "(string vector, default {}) "
        "Server endpoints in the order of input variables for mapping")
        .SetDefault({})
        .AsExtra();
    AddAttr<std::vector<std::string>>(
        "table_names",
        "(string vector, the split table names that will be fetched from "
        "parameter server) "
        "in the order of input variables for mapping")
        .SetDefault({})
        .AsExtra();

    AddAttr<std::vector<int>>("custom_neg_classes",
                              "This attribute is only used in unit tests. "
                              "Classes in this list will be used as negative "
                              "classes for every sample. Under normal "
                              "conditions, user should avoid setting this "
                              "attribute.")
        .SetDefault({})
        .AsExtra();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference "
                  "only, false for training.")
        .SetDefault(false);
    AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss. See
`Noise-contrastive estimation: A new estimation principle for unnormalized
statistical models
 <http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf>`_.
By default this operator uses a uniform distribution for sampling.
)DOC");
  }
};

241 242 243 244
template <typename T>
class NCEGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
245
  void Apply(GradOpPtr<T> op) const override {
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Label", this->Input("Label"));
    op->SetInput("Bias", this->Input("Bias"));
    op->SetInput("Weight", this->Input("Weight"));
    op->SetInput("SampleLogits", this->Output("SampleLogits"));
    op->SetInput("SampleLabels", this->Output("SampleLabels"));
    op->SetInput("SampleWeight", this->Input("SampleWeight"));
    op->SetInput("CustomDistProbs", this->Input("CustomDistProbs"));
    op->SetInput("CustomDistAlias", this->Input("CustomDistAlias"));
    op->SetInput("CustomDistAliasProbs", this->Input("CustomDistAliasProbs"));
    op->SetInput(framework::GradVarName("Cost"), this->OutputGrad("Cost"));
    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
    op->SetAttrMap(this->Attrs());
  }
};

W
wanghaoshuang 已提交
265 266 267 268
class NCEOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

269
  void InferShape(framework::InferShapeContext *ctx) const override {
270 271 272 273 274 275 276 277
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput("SampleLogits"), "Input", "SampleLogits",
                   "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput("SampleLabels"), "Input", "SampleLabels",
                   "nce_grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Cost")), "Input",
                   framework::GradVarName("Cost"), "nce_grad");
W
wanghaoshuang 已提交
278

W
wanghaoshuang 已提交
279 280
    auto x_dims = ctx->GetInputDim("Input");
    auto x_grad_name = framework::GradVarName("Input");
W
wanghaoshuang 已提交
281 282 283 284
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }

W
wanghaoshuang 已提交
285 286
    auto w_dims = ctx->GetInputDim("Weight");
    auto w_grad_name = framework::GradVarName("Weight");
W
wanghaoshuang 已提交
287 288 289 290
    if (ctx->HasOutput(w_grad_name)) {
      ctx->SetOutputDim(w_grad_name, w_dims);
    }

W
wanghaoshuang 已提交
291
    auto bias_grad_name = framework::GradVarName("Bias");
W
wanghaoshuang 已提交
292
    if (ctx->HasOutput(bias_grad_name)) {
W
wanghaoshuang 已提交
293
      auto bias_dims = ctx->GetInputDim("Bias");
W
wanghaoshuang 已提交
294 295 296
      ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
  }
W
wanghaoshuang 已提交
297 298

 protected:
299
  framework::OpKernelType GetExpectedKernelType(
300
      const framework::ExecutionContext &ctx) const override {
301 302 303
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        platform::CPUPlace());
W
wanghaoshuang 已提交
304
  }
W
wanghaoshuang 已提交
305 306
};

307 308
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
 public:
M
minqiyang 已提交
309
  void operator()(framework::InferVarTypeContext *ctx) const override {
310
    auto weight_grad = framework::GradVarName("Weight");
311

M
minqiyang 已提交
312
    auto attr = ctx->GetAttr("is_sparse");
313
    bool is_sparse = BOOST_GET(bool, attr);
314
    if (is_sparse) {
315
      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
M
minqiyang 已提交
316
              << " is set to SelectedRows";
317
      ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
318
    } else {
319
      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
M
minqiyang 已提交
320
              << " is set to LoDTensor";
321
      ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
322
    }
323
    ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
324 325 326
  }
};

327
// Declares NCEGradOpNoNeedBufferVarInferer, naming "Bias" as a no-need-buffer
// variable; the inferer is attached to nce_grad at operator registration.
// NOTE(review): presumably the grad op only needs Bias' dims (as read in
// NCEOpGrad::InferShape), not its data — confirm against the grad kernel.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInferer, "Bias");
328

W
wanghaoshuang 已提交
329 330 331 332
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
333 334 335
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
                  ops::NCEGradOpMaker<paddle::framework::OpDesc>,
                  ops::NCEGradOpMaker<paddle::imperative::OpBase>);
336
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference,
337
                  ops::NCEGradOpNoNeedBufferVarInferer);
W
wanghaoshuang 已提交
338 339
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEKernel<paddle::platform::CPUPlace, double>);
W
wanghaoshuang 已提交
340
REGISTER_OP_CPU_KERNEL(nce_grad,
W
wanghaoshuang 已提交
341 342
                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, double>);