From 69c613fb5af7ff5b289fad7525d01a2dce37f749 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 20 Sep 2017 14:08:24 +0800 Subject: [PATCH] refine Layer.cpp for some CostLayer --- paddle/gserver/layers/Layer.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index f11875b17..e95f42c86 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -14,7 +14,8 @@ limitations under the License. */ #include "paddle/utils/Util.h" -#include "Layer.h" +#include "CostLayer.h" +#include "ValidationLayer.h" #include "paddle/math/SparseMatrix.h" #include "paddle/utils/Error.h" #include "paddle/utils/Logging.h" @@ -93,6 +94,20 @@ ClassRegistrar Layer::registrar_; LayerPtr Layer::create(const LayerConfig& config) { std::string type = config.type(); + + // NOTE: As following types have illegal character '-', + // they can not use REGISTER_LAYER to registrar. + // Besides, to fit with old training models, + // they can not use '_' instead. + if (type == "multi-class-cross-entropy") + return LayerPtr(new MultiClassCrossEntropy(config)); + else if (type == "rank-cost") + return LayerPtr(new RankingCost(config)); + else if (type == "auc-validation") + return LayerPtr(new AucValidation(config)); + else if (type == "pnpair-validation") + return LayerPtr(new PnpairValidation(config)); + return LayerPtr(registrar_.createByType(config.type(), config)); } -- GitLab