提交 ad6cb60d 编写于 作者: Y Yu Yang

Merge branch 'feature/clean_gradient_machine_start' into feature/mnist_train_api

......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <set>
#include <vector>
#include "paddle/math/Vector.h"
......@@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
std::vector<Segment> labelSegments_;
std::vector<Segment> outputSegments_;
std::set<int> excludedChunkTypes_;
public:
virtual void init(const EvaluatorConfig& config) {
......@@ -105,6 +107,10 @@ public:
}
CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config";
otherChunkType_ = numChunkTypes_ = config.num_chunk_types();
// the chunks of types in excludedChunkTypes_ will not be counted
auto& tmp = config.excluded_chunk_types();
excludedChunkTypes_.insert(tmp.begin(), tmp.end());
}
virtual void start() {
......@@ -156,7 +162,8 @@ public:
getSegments(label, length, labelSegments_);
size_t i = 0, j = 0;
while (i < outputSegments_.size() && j < labelSegments_.size()) {
if (outputSegments_[i] == labelSegments_[j]) {
if (outputSegments_[i] == labelSegments_[j] &&
excludedChunkTypes_.count(outputSegments_[i].type) != 1) {
++numCorrect_;
}
if (outputSegments_[i].end < labelSegments_[j].end) {
......@@ -168,8 +175,12 @@ public:
++j;
}
}
numLabelSegments_ += labelSegments_.size();
numOutputSegments_ += outputSegments_.size();
for (auto& segment : labelSegments_) {
if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_;
}
for (auto& segment : outputSegments_) {
if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_;
}
}
void getSegments(int* label, int length, std::vector<Segment>& segments) {
......
......@@ -212,11 +212,7 @@ public:
* @note This function will only be implemented and used in a
* multithreaded environment.
*/
virtual void start(const TrainerConfig& config,
DataProviderPtr dataProvider) {
(void)config;
(void)dataProvider;
}
virtual void start() {}
/**
* @brief check whether each work-thread has failed/errored/finished,
......
......@@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
// Destructor: invokes stop() on this TrainerThread.
TrainerThread::~TrainerThread() {
  stop();
}
void TrainerThread::start() {
gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr);
gradientMachine_->start();
computeThread_.reset(new std::thread([this]() { computeThread(); }));
......
......@@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
}
}
void MultiNetwork::start(const TrainerConfig& config,
DataProviderPtr dataProvider) {
void MultiNetwork::start() {
for (auto& subNetwork : subNetworks_) {
subNetwork->start(config, dataProvider);
subNetwork->start();
}
}
......
......@@ -54,7 +54,7 @@ public:
return subNetworks_;
}
virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider);
virtual void start();
virtual void finish();
......
......@@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
backward(callback);
}
void ParallelNeuralNetwork::start(const TrainerConfig& config,
DataProviderPtr dataProvider) {
(void)config;
(void)dataProvider;
void ParallelNeuralNetwork::start() {
for (auto& thread : threads_) {
thread->start();
}
......
......@@ -56,7 +56,7 @@ public:
PassType passType,
const UpdateCallback &callback = NULL);
virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider);
virtual void start();
void addComputeThread(int deviceId);
......
......@@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
}
}
gradientMachine->start(trainer.getConfig(), nullptr);
gradientMachine->start();
gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
for (size_t i = 0; i < in.outGrads.size(); i++) {
// If all the layers in the config have no parameters, also
......
......@@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
public:
void startTrain() {
GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
gm.start(this->getConfig(), dataProvider_);
gm.start();
}
void finishTrain() {
......
......@@ -257,7 +257,7 @@ void Tester::test() {
CHECK(testDataProvider_) << "TestData is not specified";
testDataProvider_->setSkipShuffle();
testDataProvider_->reset();
gradientMachine_->start(*config_, testDataProvider_);
gradientMachine_->start();
// For evaluation
std::vector<std::string> modelList;
......
......@@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
}
real Trainer::checkGradient() {
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
trainerInternal_.getGradientMachine()->start();
std::vector<ParameterPtr>& parameters =
trainerInternal_.getGradientMachine()->getNonStaticParameters();
DataBatch dataBatch;
......@@ -390,7 +390,7 @@ void Trainer::startTrain() {
dataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
trainerInternal_.getGradientMachine()->start();
}
// Calls finish() on the gradient machine held by trainerInternal_.
void Trainer::finishTrain() {
  const auto& gradientMachine = trainerInternal_.getGradientMachine();
  gradientMachine->finish();
}
......
......@@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
CHECK(dataBatch.getSize()) << "No data from data provider";
vector<Argument>& inArgs = dataBatch.getStreams();
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
trainer.getGradientMachine()->start();
for (int i = 0; i < 2; ++i) {
trainer.getGradientMachine()->forwardBackward(
inArgs, &Data.outArgs, PASS_TRAIN);
......
......@@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) {
CHECK(dataBatch.getSize()) << "No data from data provider";
vector<Argument>& inArgs = dataBatch.getStreams();
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
trainer.getGradientMachine()->start();
trainer.getGradientMachine()->forwardBackward(
inArgs, &data.outArgs, PASS_TRAIN);
......
......@@ -433,8 +433,10 @@ message EvaluatorConfig {
repeated string input_layers = 3;
// Used by ChunkEvaluator
optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES"
optional int32 num_chunk_types = 5; // number of chunk types other than "other"
// one of "IOB", "IOE", "IOBES"
optional string chunk_scheme = 4;
// number of chunk types other than "other"
optional int32 num_chunk_types = 5;
// Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator
// For multi binary labels: true if output > classification_threshold
......@@ -453,6 +455,10 @@ message EvaluatorConfig {
// whether to delimit the sequence in the seq_text_printer
optional bool delimited = 11 [default = true];
// Used by ChunkEvaluator
// chunks of these types are not counted
repeated int32 excluded_chunk_types = 12;
}
message LinkConfig {
......
......@@ -1240,7 +1240,8 @@ def Evaluator(
dict_file=None,
result_file=None,
num_results=None,
delimited=None, ):
delimited=None,
excluded_chunk_types=None, ):
evaluator = g_config.model_config.evaluators.add()
evaluator.type = type
evaluator.name = MakeLayerNameInSubmodel(name)
......@@ -1269,6 +1270,9 @@ def Evaluator(
if delimited is not None:
evaluator.delimited = delimited
if excluded_chunk_types:
evaluator.excluded_chunk_types.extend(excluded_chunk_types)
class LayerBase(object):
def __init__(
......
......@@ -57,19 +57,21 @@ def evaluator(*attrs):
return impl
def evaluator_base(input,
type,
label=None,
weight=None,
name=None,
chunk_scheme=None,
num_chunk_types=None,
classification_threshold=None,
positive_label=None,
dict_file=None,
result_file=None,
num_results=None,
delimited=None):
def evaluator_base(
input,
type,
label=None,
weight=None,
name=None,
chunk_scheme=None,
num_chunk_types=None,
classification_threshold=None,
positive_label=None,
dict_file=None,
result_file=None,
num_results=None,
delimited=None,
excluded_chunk_types=None, ):
"""
Evaluator will evaluate the network status while training/testing.
......@@ -127,7 +129,8 @@ def evaluator_base(input,
positive_label=positive_label,
dict_file=dict_file,
result_file=result_file,
delimited=delimited)
delimited=delimited,
excluded_chunk_types=excluded_chunk_types, )
@evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
......@@ -330,7 +333,8 @@ def chunk_evaluator(
label,
chunk_scheme,
num_chunk_types,
name=None, ):
name=None,
excluded_chunk_types=None, ):
"""
Chunk evaluator is used to evaluate segment labelling accuracy for a
sequence. It calculates the chunk detection F1 score.
......@@ -376,6 +380,8 @@ def chunk_evaluator(
:param num_chunk_types: number of chunk types other than "other"
:param name: The Evaluator name, it is optional.
:type name: basename|None
:param excluded_chunk_types: chunks of these types are not considered
:type excluded_chunk_types: list of integer|None
"""
evaluator_base(
name=name,
......@@ -383,7 +389,8 @@ def chunk_evaluator(
input=input,
label=label,
chunk_scheme=chunk_scheme,
num_chunk_types=num_chunk_types)
num_chunk_types=num_chunk_types,
excluded_chunk_types=excluded_chunk_types, )
@evaluator(EvaluatorAttribute.FOR_UTILS)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部