提交 ab0e4857 编写于 作者: W willzhang4a58

mlloss layer desc

上级 e38f9d9d
......@@ -56,6 +56,12 @@ message SoftmaxLayerConf {
string out = 2;
}
message MultinomialLogisticLossLayerConf {
string data = 1;
string label = 2;
string loss = 3;
};
message LayerConf {
string name = 1;
oneof conf_for_specified_layer {
......@@ -66,5 +72,6 @@ message LayerConf {
PoolingLayerConf pooling_layer_conf = 103;
ReluLayerConf relu_layer_conf = 104;
SoftmaxLayerConf softmax_layer_conf = 105;
MultinomialLogisticLossLayerConf multinomial_logistic_loss_layer_conf = 106;
}
}
#include "layer/multinomial_logistic_loss_layer.h"
namespace oneflow {
void MLLossDataBlobDescSet::Init(const std::string& layer_name) {
DataBlobDescSet::Init();
RegisterInputBlobPptr(layer_name + ".data", &data_);
RegisterInputDiffBlobPptr(layer_name + ".data_diff", &data_diff_);
RegisterInputBlobPptr(layer_name + ".label", &label_);
RegisterInputDiffBlobPptr(layer_name + ".label_diff", &label_diff_);
RegisterOutputBlobPptr(layer_name + ".loss", &loss_);
RegisterDataTmpBlobPptr(layer_name + ".loss_buffer", &loss_buffer_);
}
void MultinomialLogisticLossLayer::Init(const LayerConf& layer_conf) {
mutable_layer_name() = layer_conf.name();
CHECK(layer_conf.has_multinomial_logistic_loss_layer_conf());
layer_conf_ = layer_conf.multinomial_logistic_loss_layer_conf();
auto data_ptr = new MLLossDataBlobDescSet();
data_ptr->Init(layer_name());
mutable_data_blob_desc_set().reset(data_ptr);
auto model_ptr = new MLLossModelBlobDescSet();
model_ptr->Init(layer_name());
mutable_model_blob_desc_set().reset(model_ptr);
}
} // namespace oneflow
#ifndef LAYER_MULTINOMIAL_LOGISTIC_LOSS_LAYER_H_
#define LAYER_MULTINOMIAL_LOGISTIC_LOSS_LAYER_H_
#include "layer/base_layer_desc.h"
namespace oneflow {
// MLLoss = MultinomialLogisticLoss
class MLLossDataBlobDescSet final : public DataBlobDescSet {
public:
DISALLOW_COPY_AND_MOVE(MLLossDataBlobDescSet);
MLLossDataBlobDescSet() = default;
~MLLossDataBlobDescSet() = default;
void Init(const std::string& layer_name);
private:
BlobDescriptor* data_;
BlobDescriptor* data_diff_;
BlobDescriptor* label_;
BlobDescriptor* label_diff_;
BlobDescriptor* loss_;
BlobDescriptor* loss_buffer_;
};
class MLLossModelBlobDescSet final : public ModelBlobDescSet {
public:
DISALLOW_COPY_AND_MOVE(MLLossModelBlobDescSet);
MLLossModelBlobDescSet() = default;
~MLLossModelBlobDescSet() = default;
void Init(const std::string& layer_name) {
ModelBlobDescSet::Init();
}
private:
};
class MultinomialLogisticLossLayer : public BaseLayerDesc {
public:
DISALLOW_COPY_AND_MOVE(MultinomialLogisticLossLayer);
MultinomialLogisticLossLayer() = default;
~MultinomialLogisticLossLayer() = default;
void Init(const LayerConf& layer_conf) override;
private:
MultinomialLogisticLossLayerConf layer_conf_;
};
} // namespace oneflow
#endif // LAYER_MULTINOMIAL_LOGISTIC_LOSS_LAYER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册