Commit ad6cb60d authored by Yu Yang

Merge branch 'feature/clean_gradient_machine_start' into feature/mnist_train_api

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <set>
 #include <vector>
 #include "paddle/math/Vector.h"
@@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
   std::vector<Segment> labelSegments_;
   std::vector<Segment> outputSegments_;
+  std::set<int> excludedChunkTypes_;

 public:
   virtual void init(const EvaluatorConfig& config) {
@@ -105,6 +107,10 @@ public:
     }
     CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config";
     otherChunkType_ = numChunkTypes_ = config.num_chunk_types();
+    // chunks of the types in excludedChunkTypes_ will not be counted
+    auto& tmp = config.excluded_chunk_types();
+    excludedChunkTypes_.insert(tmp.begin(), tmp.end());
   }

   virtual void start() {
@@ -156,7 +162,8 @@ public:
     getSegments(label, length, labelSegments_);
     size_t i = 0, j = 0;
     while (i < outputSegments_.size() && j < labelSegments_.size()) {
-      if (outputSegments_[i] == labelSegments_[j]) {
+      if (outputSegments_[i] == labelSegments_[j] &&
+          excludedChunkTypes_.count(outputSegments_[i].type) != 1) {
         ++numCorrect_;
       }
       if (outputSegments_[i].end < labelSegments_[j].end) {
@@ -168,8 +175,12 @@ public:
         ++j;
       }
     }
-    numLabelSegments_ += labelSegments_.size();
-    numOutputSegments_ += outputSegments_.size();
+    for (auto& segment : labelSegments_) {
+      if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_;
+    }
+    for (auto& segment : outputSegments_) {
+      if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_;
+    }
   }

   void getSegments(int* label, int length, std::vector<Segment>& segments) {
......
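Taken together, these hunks make ChunkEvaluator skip every chunk whose type is listed in excluded_chunk_types: excluded chunks are never counted as correct, and they do not contribute to the label or output totals either. A minimal Python sketch of that counting rule (illustrative only, not the C++ code: segments are modeled as (begin, end, type) tuples mirroring Segment, the function name is made up, and the advance logic elided between the two hunks is assumed to step past whichever segment ends first):

def count_chunks(output_segments, label_segments, excluded_types):
    # Walk both sorted segment lists in step, as in the C++ while loop.
    num_correct = i = j = 0
    while i < len(output_segments) and j < len(label_segments):
        if (output_segments[i] == label_segments[j]
                and output_segments[i][2] not in excluded_types):
            num_correct += 1
        # Advance whichever segment ends first; on a tie, advance both.
        if output_segments[i][1] < label_segments[j][1]:
            i += 1
        elif output_segments[i][1] > label_segments[j][1]:
            j += 1
        else:
            i += 1
            j += 1
    num_label = sum(s[2] not in excluded_types for s in label_segments)
    num_output = sum(s[2] not in excluded_types for s in output_segments)
    return num_correct, num_label, num_output

From these three counters the evaluator derives precision (num_correct / num_output), recall (num_correct / num_label), and the reported F1 as their harmonic mean, as usual for chunking evaluation.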
@@ -212,11 +212,7 @@ public:
    * @note This function will only been implemented and used in a
    * multithreaded environment.
    */
-  virtual void start(const TrainerConfig& config,
-                     DataProviderPtr dataProvider) {
-    (void)config;
-    (void)dataProvider;
-  }
+  virtual void start() {}

   /**
    * @brief check each work-thread whether is failed/error/finish,
......
@@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
 TrainerThread::~TrainerThread() { stop(); }

 void TrainerThread::start() {
-  gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr);
+  gradientMachine_->start();
   computeThread_.reset(new std::thread([this]() { computeThread(); }));
......
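Note what the deleted call was doing: casting a null pointer to a TrainerConfig& reference is undefined behavior even if the callee never reads it. With the parameters gone, start() can be called by code that has no trainer config or data provider at all, which is what the target branch (a standalone Python training API) needs. A hedged sketch of such a caller, assuming the py_paddle swig wrapper; none of these Python names are part of this diff:

from py_paddle import swig_paddle  # assumed wrapper module

swig_paddle.initPaddle("--use_gpu=false")
# `model_config` is a ModelConfig proto parsed elsewhere (assumption).
m = swig_paddle.GradientMachine.createFromConfigProto(model_config)
m.start()   # no TrainerConfig / DataProviderPtr needed any more
# ... run forward/backward passes ...
m.finish()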
@@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
   }
 }

-void MultiNetwork::start(const TrainerConfig& config,
-                         DataProviderPtr dataProvider) {
+void MultiNetwork::start() {
   for (auto& subNetwork : subNetworks_) {
-    subNetwork->start(config, dataProvider);
+    subNetwork->start();
   }
 }
......
@@ -54,7 +54,7 @@ public:
     return subNetworks_;
   }

-  virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider);
+  virtual void start();

   virtual void finish();
......
@@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
   backward(callback);
 }

-void ParallelNeuralNetwork::start(const TrainerConfig& config,
-                                  DataProviderPtr dataProvider) {
-  (void)config;
-  (void)dataProvider;
+void ParallelNeuralNetwork::start() {
   for (auto& thread : threads_) {
     thread->start();
   }
......
@@ -56,7 +56,7 @@ public:
                              PassType passType,
                              const UpdateCallback &callback = NULL);

-  virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider);
+  virtual void start();

   void addComputeThread(int deviceId);
......
@@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
       parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
     }
   }
-  gradientMachine->start(trainer.getConfig(), nullptr);
+  gradientMachine->start();
   gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
   for (size_t i = 0; i < in.outGrads.size(); i++) {
     // If the all the layers in the config have no parameters, also
......
@@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
 public:
   void startTrain() {
     GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
-    gm.start(this->getConfig(), dataProvider_);
+    gm.start();
   }

   void finishTrain() {
......
@@ -257,7 +257,7 @@ void Tester::test() {
   CHECK(testDataProvider_) << "TestData is not specified";
   testDataProvider_->setSkipShuffle();
   testDataProvider_->reset();
-  gradientMachine_->start(*config_, testDataProvider_);
+  gradientMachine_->start();

   // For evaluation
   std::vector<std::string> modelList;
......
@@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
 }

 real Trainer::checkGradient() {
-  trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
+  trainerInternal_.getGradientMachine()->start();
   std::vector<ParameterPtr>& parameters =
       trainerInternal_.getGradientMachine()->getNonStaticParameters();
   DataBatch dataBatch;
@@ -390,7 +390,7 @@ void Trainer::startTrain() {
     dataProvider_->reset();
   }

-  trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
+  trainerInternal_.getGradientMachine()->start();
 }

 void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }
......
@@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
   trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
   CHECK(dataBatch.getSize()) << "No data from data provider";
   vector<Argument>& inArgs = dataBatch.getStreams();
-  trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
+  trainer.getGradientMachine()->start();
   for (int i = 0; i < 2; ++i) {
     trainer.getGradientMachine()->forwardBackward(
         inArgs, &Data.outArgs, PASS_TRAIN);
......
@@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) {
   CHECK(dataBatch.getSize()) << "No data from data provider";
   vector<Argument>& inArgs = dataBatch.getStreams();
-  trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
+  trainer.getGradientMachine()->start();
   trainer.getGradientMachine()->forwardBackward(
       inArgs, &data.outArgs, PASS_TRAIN);
......
@@ -433,8 +433,10 @@ message EvaluatorConfig {
   repeated string input_layers = 3;

   // Used by ChunkEvaluator
-  optional string chunk_scheme = 4;  // one of "IOB", "IOE", "IOBES"
-  optional int32 num_chunk_types = 5;  // number of chunk types other than "other"
+  // one of "IOB", "IOE", "IOBES"
+  optional string chunk_scheme = 4;
+  // number of chunk types other than "other"
+  optional int32 num_chunk_types = 5;

   // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator
   // For multi binary labels: true if output > classification_threshold
@@ -453,6 +455,10 @@ message EvaluatorConfig {
   // whether to delimit the sequence in the seq_text_printer
   optional bool delimited = 11 [default = true];

+  // Used by ChunkEvaluator
+  // chunks of these types are not counted
+  repeated int32 excluded_chunk_types = 12;
 }

 message LinkConfig {
......
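For a sense of how the new field is consumed, here is an illustrative snippet using the generated Python protobuf bindings; the module path and the field values are assumptions, while the field names come from the message above:

from paddle.proto.ModelConfig_pb2 import EvaluatorConfig  # assumed path

cfg = EvaluatorConfig(name="chunk_f1", type="chunk")
cfg.chunk_scheme = "IOB"
cfg.num_chunk_types = 3
cfg.excluded_chunk_types.extend([0, 2])  # chunks of types 0 and 2 are ignored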
@@ -1240,7 +1240,8 @@ def Evaluator(
         dict_file=None,
         result_file=None,
         num_results=None,
-        delimited=None, ):
+        delimited=None,
+        excluded_chunk_types=None, ):
     evaluator = g_config.model_config.evaluators.add()
     evaluator.type = type
     evaluator.name = MakeLayerNameInSubmodel(name)
@@ -1269,6 +1270,9 @@ def Evaluator(
     if delimited is not None:
         evaluator.delimited = delimited

+    if excluded_chunk_types:
+        evaluator.excluded_chunk_types.extend(excluded_chunk_types)
+

 class LayerBase(object):
     def __init__(
......
@@ -57,7 +57,8 @@ def evaluator(*attrs):
     return impl

-def evaluator_base(input,
+def evaluator_base(
+        input,
         type,
         label=None,
         weight=None,
@@ -69,7 +70,8 @@ def evaluator_base(input,
         dict_file=None,
         result_file=None,
         num_results=None,
-        delimited=None):
+        delimited=None,
+        excluded_chunk_types=None, ):
     """
     Evaluator will evaluate the network status while training/testing.
@@ -127,7 +129,8 @@ def evaluator_base(input,
         positive_label=positive_label,
         dict_file=dict_file,
         result_file=result_file,
-        delimited=delimited)
+        delimited=delimited,
+        excluded_chunk_types=excluded_chunk_types, )

 @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
@@ -330,7 +333,8 @@ def chunk_evaluator(
         label,
         chunk_scheme,
         num_chunk_types,
-        name=None, ):
+        name=None,
+        excluded_chunk_types=None, ):
     """
     Chunk evaluator is used to evaluate segment labelling accuracy for a
     sequence. It calculates the chunk detection F1 score.
@@ -376,6 +380,8 @@ def chunk_evaluator(
     :param num_chunk_types: number of chunk types other than "other"
     :param name: The Evaluator name, it is optional.
     :type name: basename|None
+    :param excluded_chunk_types: chunks of these types are not considered
+    :type excluded_chunk_types: list of integer|None
     """
     evaluator_base(
         name=name,
@@ -383,7 +389,8 @@ def chunk_evaluator(
         input=input,
         label=label,
         chunk_scheme=chunk_scheme,
-        num_chunk_types=num_chunk_types)
+        num_chunk_types=num_chunk_types,
+        excluded_chunk_types=excluded_chunk_types, )

 @evaluator(EvaluatorAttribute.FOR_UTILS)
......
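At the user-facing level the new argument threads straight through evaluator_base into the proto field added above. A hedged usage sketch with made-up layer names:

chunk_evaluator(
    input=crf_decoding_layer,   # hypothetical output layer
    label=chunk_label_layer,    # hypothetical label layer
    chunk_scheme="IOB",
    num_chunk_types=3,
    excluded_chunk_types=[2], )  # type-2 chunks do not count toward F1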