提交 05e8a26b 编写于 作者: C caoying03

add unittest.

上级 44ae44da
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CrossEntropyOverBeam.h"
namespace paddle {
REGISTER_LAYER(cross_entropy_over_beam, CrossEntropyOverBeam);
bool CrossEntropyOverBeam::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
setNeedSequenceInfo(false);
return true;
}
void CrossEntropyOverBeam::forward(PassType passType) {}
void CrossEntropyOverBeam::backward(const UpdateCallback& callback) {}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "CrossEntropyOverBeam.h"
#include "Layer.h"
namespace paddle {
class CrossEntropyOverBeam : public Layer {
public:
explicit CrossEntropyOverBeam(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
} // namespace paddle
...@@ -34,6 +34,12 @@ add_unittest_without_exec(test_CRFLayerGrad ...@@ -34,6 +34,12 @@ add_unittest_without_exec(test_CRFLayerGrad
add_test(NAME test_CRFLayerGrad add_test(NAME test_CRFLayerGrad
COMMAND test_CRFLayerGrad) COMMAND test_CRFLayerGrad)
################ test_CrossEntropyOverBeam ####################
add_unittest_without_exec(test_CrossEntropyOverBeam
test_CrossEntropyOverBeamGrad.cpp
LayerGradUtil.cpp)
add_test(NAME test_CrossEntropyOverBeam
COMMAND test_CrossEntropyOverBeam)
add_unittest_without_exec(test_ActivationGrad add_unittest_without_exec(test_ActivationGrad
test_ActivationGrad.cpp test_ActivationGrad.cpp
......
...@@ -388,14 +388,23 @@ void initDataLayer(TestConfig testConf, ...@@ -388,14 +388,23 @@ void initDataLayer(TestConfig testConf,
data.grad->zeroMem(); data.grad->zeroMem();
break; break;
case INPUT_SELF_DEFINE_DATA: { case INPUT_SELF_DEFINE_DATA: {
size_t height = testConf.inputDefs[i].selfDefinedData->getHeight(); if (testConf.inputDefs[i].ids.size()) {
size_t width = testConf.inputDefs[i].selfDefinedData->getWidth(); data.ids = IVector::create(testConf.inputDefs[i].ids.size(), useGpu);
CHECK_GT(static_cast<int>(height), 0); data.ids->copyFrom(testConf.inputDefs[i].ids.data(),
CHECK_GT(static_cast<int>(width), 0); testConf.inputDefs[i].ids.size());
data.value = Matrix::create(height, width, false, useGpu); } else if (testConf.inputDefs[i].selfDefinedData) {
data.grad = Matrix::create(height, width, false, useGpu); size_t height = testConf.inputDefs[i].selfDefinedData->getHeight();
data.value->copyFrom(*testConf.inputDefs[i].selfDefinedData); size_t width = testConf.inputDefs[i].selfDefinedData->getWidth();
data.grad->zeroMem(); CHECK_GT(static_cast<int>(height), 0);
CHECK_GT(static_cast<int>(width), 0);
data.value = Matrix::create(height, width, false, useGpu);
data.grad = Matrix::create(height, width, false, useGpu);
data.value->copyFrom(*testConf.inputDefs[i].selfDefinedData);
data.grad->zeroMem();
} else {
LOG(FATAL) << "No self-defined data are given.";
return;
}
const std::vector<int>& labelSeqStartPositions = const std::vector<int>& labelSeqStartPositions =
testConf.inputDefs[i].labelSeqStartPositions; testConf.inputDefs[i].labelSeqStartPositions;
......
...@@ -68,6 +68,7 @@ struct InputDef { ...@@ -68,6 +68,7 @@ struct InputDef {
std::vector<int> labelInitValue; std::vector<int> labelInitValue;
std::vector<int> labelSeqStartPositions; std::vector<int> labelSeqStartPositions;
std::vector<int> labelSubSeqStartPositions; std::vector<int> labelSubSeqStartPositions;
std::vector<int> ids;
MatrixPtr selfDefinedData; MatrixPtr selfDefinedData;
InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) {
...@@ -95,6 +96,23 @@ struct InputDef { ...@@ -95,6 +96,23 @@ struct InputDef {
isStatic = false; isStatic = false;
} }
InputDef(InputType type,
string nameIn,
std::vector<int> ids,
std::vector<int> selfDefinedSeqStartPos = {},
std::vector<int> selfDefinedSubSeqStartPos = {})
: labelSeqStartPositions(selfDefinedSeqStartPos),
labelSubSeqStartPositions(selfDefinedSubSeqStartPos),
ids(ids) {
selfDefinedData = nullptr;
inputType = type;
name = nameIn;
dim = 0;
sparse = {""};
paraSize = 0;
isStatic = false;
}
InputDef(InputType type, InputDef(InputType type,
string nameIn, string nameIn,
size_t dimIn, size_t dimIn,
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include <gtest/gtest.h>
#include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h"
#include "paddle/trainer/Trainer.h"
#include "LayerGradUtil.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle; // NOLINT
DECLARE_int32(gpu_id);
DECLARE_bool(thread_local_rand_use_global_seed);
struct SingleBeamExpansion {
vector<int> seqStartPos;
vector<int> subSeqStartPos;
vector<real> candidateScores;
// TODO(caoying): store this into Argument.ids
vector<real> selectedIndices;
vector<int> groundTruth;
};
void genRandomBeamExpansion(size_t expansionCount,
vector<SingleBeamExpansion>& beamExpansions) {
beamExpansions.clear();
}
void testCrossEntropyOverBeam() {
const size_t expansionCount = 3;
vector<SingleBeamExpansion> beams;
genRandomBeamExpansion(expansionCount, beams);
for (size_t i = 0; i < beams.size(); ++i) {
const SingleBeamExpansion& beam = beams[i];
// create scores for all the candidates
MatrixPtr candidateScorePtr =
Matrix::create(beam.candidateScores.size(), 1, false, false);
candidateScorePtr->copyFrom(candidateScores.data(), candidateScores.size());
ostringstream paramName;
paramName << "candidate_scores_" << i;
beam.subSeqStartPos.size()
? config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA,
ostr.str(),
candidateScorePtr,
beam.seqStartPos,
beam.subSeqStartPos})
: config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA,
ostr.str(),
candidateScorePtr,
beam.seqStartPos});
// create indices for the selected candidates
// create the ground truth
}
}
TestConfig config;
config.layerConfig.set_type("cross_entropy_over_beam");
// testLayerGrad(
// config, "cross_entropy_over_beam", seqNum, false, useGpu, false);
}
TEST(Layer, CrossEntropyOverBeam) {
for (bool useGpu : {false, true}) testCrossEntropyOverBeam(useGpu);
}
int main(int argc, char** argv) {
initMain(argc, argv);
hl_start();
hl_init(FLAGS_gpu_id);
FLAGS_thread_local_rand_use_global_seed = true;
srand(1);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册