#include "oneflow/core/operator/softmax_loss_op.h"
#include "oneflow/core/common/data_type.h"

namespace oneflow {

W
Will Zhang 已提交
6 7
void SoftmaxLossOp::InitFromOpConf() {
  CHECK(op_conf().has_softmax_loss_conf());
8
  EnrollInputBn("prediction");
C
chengtbf 已提交
9
  EnrollInputBn("label", false);
C
chengtbf 已提交
10 11 12 13 14 15 16 17 18
  EnrollDataTmpBn("prob");
  EnrollDataTmpBn("tmp_1D");
  EnrollOutputBn("loss", false);
}

const PbMessage& SoftmaxLossOp::GetSpecialConf() const {
  return op_conf().softmax_loss_conf();
}

W
Will Zhang 已提交
19
void SoftmaxLossOp::InferBlobDescs(
W
willzhang4a58 已提交
20
    std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
W
Will Zhang 已提交
21
    const ParallelContext* parallel_ctx) {
W
Will Zhang 已提交
22 23 24 25 26 27 28 29 30
  const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
  const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");
  CHECK_EQ(pred_blob_desc->has_data_id(), label_blob_desc->has_data_id());
  CHECK(IsIntegral(label_blob_desc->data_type()));
  CHECK_EQ(pred_blob_desc->shape().NumAxes(), 2);
  CHECK_EQ(label_blob_desc->shape(), Shape({pred_blob_desc->shape().At(0)}));
  // loss
  BlobDesc* loss_blob_desc = GetBlobDesc4BnInOp("loss");
  loss_blob_desc->mut_shape() = Shape({1});
W
willzhang4a58 已提交
31
  loss_blob_desc->set_data_type(pred_blob_desc->data_type());
W
Will Zhang 已提交
32 33 34 35
  loss_blob_desc->set_has_data_id(false);
  // tmp_1D
  BlobDesc* tmp_1D_blob_desc = GetBlobDesc4BnInOp("tmp_1D");
  tmp_1D_blob_desc->mut_shape() = Shape({pred_blob_desc->shape().At(0)});
W
willzhang4a58 已提交
36
  tmp_1D_blob_desc->set_data_type(pred_blob_desc->data_type());
W
Will Zhang 已提交
37 38 39 40
  tmp_1D_blob_desc->set_has_data_id(false);
  // prob
  BlobDesc* prob_blob_desc = GetBlobDesc4BnInOp("prob");
  prob_blob_desc->mut_shape() = Shape(pred_blob_desc->shape());
W
willzhang4a58 已提交
41
  prob_blob_desc->set_data_type(pred_blob_desc->data_type());
W
Will Zhang 已提交
42
  prob_blob_desc->set_has_data_id(false);
C
chengtbf 已提交
43 44 45 46 47
}

REGISTER_OP(OperatorConf::kSoftmaxLossConf, SoftmaxLossOp);

}  // namespace oneflow