Commit eaba2e2e authored by Yu Yang

Expose Evaluator API

Parent commit: 1e6c87bd
......@@ -62,14 +62,14 @@ def main():
train_data_generator = input_order_converter(
read_from_mnist(train_file))
for batch_id, data_batch in enumerate(
generator_to_batch(train_data_generator, 256)):
generator_to_batch(train_data_generator, 2048)):
trainRole = updater.startBatch(len(data_batch))
def update_callback(param):
def updater_callback(param):
updater.update(param)
m.forwardBackward(
converter(data_batch), outArgs, trainRole, update_callback)
converter(data_batch), outArgs, trainRole, updater_callback)
cost_vec = outArgs.getSlotValue(0)
cost_vec = cost_vec.copyToNumpyMat()
......
set(API_SOURCES
Arguments.cpp
ConfigParser.cpp
Evaluator.cpp
GradientMachine.cpp
Matrix.cpp
Parameter.cpp
......@@ -63,6 +64,15 @@ install(DIRECTORY ${PROJ_ROOT}/paddle/dist/
add_custom_target(python_api_wheel ALL DEPENDS
${PROJ_ROOT}/paddle/dist/.timestamp)
add_dependencies(python_api_wheel python_swig_sources
paddle_parameter
paddle_math
paddle_utils
paddle_gserver
paddle_pserver
paddle_trainer
paddle_api
paddle_cuda)
if(WITH_TESTING)
add_subdirectory(test)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
// Construct with an empty private impl; the wrapped paddle::Evaluator is
// attached later by GradientMachine::makeEvaluator().
Evaluator::Evaluator() : m(new EvaluatorPrivate()) {}
// Destroying the private impl also deletes the owned paddle::Evaluator
// (see EvaluatorPrivate::~EvaluatorPrivate).
Evaluator::~Evaluator() { delete m; }
// Begin an evaluation pass — forwards to paddle::Evaluator::start().
void Evaluator::start() { m->rawPtr->start(); }
// End an evaluation pass — forwards to paddle::Evaluator::finish().
void Evaluator::finish() { m->rawPtr->finish(); }
/// Render the evaluator's accumulated statistics as a human-readable string
/// (exposed to Python as __repr__ via the SWIG %rename directive).
std::string Evaluator::toString() {
  std::ostringstream buffer;
  m->rawPtr->printStats(buffer);
  return buffer.str();
}
......@@ -162,3 +162,13 @@ SequenceGenerator* GradientMachine::asSequenceGenerator(
r->setBeamSize(beam_size);
return r;
}
/// Create an Evaluator bound to this machine's underlying paddle::Evaluator.
/// The caller owns the returned object (SWIG tags this with %newobject).
Evaluator* GradientMachine::makeEvaluator() {
  Evaluator* result = new Evaluator();
  result->m->rawPtr = m->machine->makeEvaluator();
  return result;
}
// Run the underlying machine's eval() against the wrapped paddle::Evaluator,
// accumulating statistics into `evaluator`.
void GradientMachine::eval(Evaluator* evaluator) {
m->machine->eval(evaluator->m->rawPtr);
}
......@@ -97,6 +97,7 @@ namespace std {
%rename(__setitem__) Vector::set;
%rename(__len__) Vector::getSize;
%rename(__call__) ParameterTraverseCallback::apply;
%rename(__repr__) Evaluator::toString;
%apply (float* INPLACE_ARRAY2, int DIM1, int DIM2) {
(float* data, int dim1, int dim2)
......@@ -167,6 +168,7 @@ namespace std {
%newobject GradientMachine::asSequenceGenerator;
%newobject GradientMachine::getParameter;
%newobject GradientMachine::getLayerOutput;
%newobject GradientMachine::makeEvaluator;
%newobject TrainerConfig::createFromTrainerConfigFile;
%newobject TrainerConfig::getModelConfig;
%newobject TrainerConfig::getOptimizationConfig;
......
......@@ -685,7 +685,7 @@ private:
};
class SequenceGenerator;
class Evaluator;
struct GradientMachinePrivate;
class GradientMachine {
private:
......@@ -770,6 +770,10 @@ public:
size_t max_length = 100UL,
size_t beam_size = -1UL);
Evaluator* makeEvaluator();
void eval(Evaluator* evaluator);
private:
GradientMachinePrivate* m;
......@@ -809,6 +813,27 @@ private:
ParameterUpdaterPrivate* m;
};
struct EvaluatorPrivate;
/// Thin API wrapper around a paddle::Evaluator.
/// Instances can only be created through GradientMachine::makeEvaluator()
/// (the constructor is private and GradientMachine is a friend); copying is
/// forbidden because the wrapper owns its private impl.
class Evaluator {
private:
Evaluator();
DISABLE_COPY_AND_ASSIGN(Evaluator);
public:
~Evaluator();
/// Begin an evaluation pass (delegates to paddle::Evaluator::start).
void start();
/// End an evaluation pass (delegates to paddle::Evaluator::finish).
void finish();
/// Return the accumulated statistics as text; exposed to Python as
/// __repr__ through SWIG.
std::string toString();
private:
EvaluatorPrivate* m;  // owned; released in the destructor
friend class GradientMachine;
};
struct TrainerPrivate;
class Trainer {
private:
......
......@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once
#include <memory>
#include "PaddleAPI.h"
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h"
#include "paddle/parameter/ParameterUpdaterBase.h"
#include "paddle/trainer/TrainerConfigHelper.h"
struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine;
......@@ -88,3 +88,10 @@ struct ParameterPrivate {
}
}
};
/// Private impl behind the API Evaluator: owns the wrapped paddle::Evaluator
/// (set by GradientMachine::makeEvaluator) and deletes it on destruction.
struct EvaluatorPrivate {
paddle::Evaluator* rawPtr;  // owning; nullptr until attached

EvaluatorPrivate() : rawPtr(nullptr) {}

// Owns a raw pointer: the implicitly-generated copy operations would lead
// to a double delete, so forbid copying outright.
EvaluatorPrivate(const EvaluatorPrivate&) = delete;
EvaluatorPrivate& operator=(const EvaluatorPrivate&) = delete;

~EvaluatorPrivate() { delete rawPtr; }
};
......@@ -29,7 +29,7 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
ParameterUpdater::~ParameterUpdater() { delete m; }
void ParameterUpdater::init(const GradientMachine &gm) {
m->updater->init(gm.m->machine->getParameters());
m->updater->init(gm.m->machine->getNonStaticParameters());
}
void ParameterUpdater::startPass() { m->updater->startPass(); }
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.