Unverified commit d1d21004, authored by helinwang, committed by GitHub

Merge pull request #5331 from dzhwinter/feature/evaluator

Feature/evaluator
## Evaluator Design
### The Problem
During training or serving, we provide evaluation functions to measure model performance, e.g., accuracy and precision. In the operator-based framework design, data go through the network pipeline batch by batch, so inside an operator we can only compute the metrics of a single mini-batch. We need a mechanism to compute the metrics over every N passes or batches the user specifies, as sketched below.
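For instance, a running accuracy over a pass cannot be obtained by averaging per-batch accuracies when batch sizes differ; the per-batch statistics (correct and total counts) have to be accumulated instead. A minimal NumPy sketch of this point (illustration only, not the framework API):

```python
# Minimal NumPy sketch (not the framework API) of why per-batch metrics must
# be accumulated: the running accuracy is total_correct / total_samples, not
# the mean of per-batch accuracies, which is biased when batch sizes differ.
import numpy as np

batches = [
    (np.array([1, 1, 0, 1]), np.array([1, 0, 0, 1])),  # 3 of 4 correct
    (np.array([1, 0]), np.array([0, 0])),              # 1 of 2 correct
]

correct = total = 0
for pred, label in batches:
    correct += int((pred == label).sum())
    total += pred.size

print(correct / float(total))        # 4/6 ~= 0.667, the true accuracy
print(np.mean([3 / 4.0, 1 / 2.0]))   # 0.625, biased toward the small batch
```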
### Evaluator Design
Currently, every operation is expressed in the graph, so we divide the evaluation process into three steps:
1. Initialize the metric state and add it into the block.
2. Compute the statistics of the metric state for every mini-batch. A single operator is only responsible for computing the necessary statistics of one mini-batch; for example, one run of the accuracy operator evaluates only one mini-batch of data.
3. Merge the mini-batch statistics to form the evaluation result over multiple mini-batches. For distributed or multi-GPU training, also aggregate the values from the different devices (see the sketch after this list).
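For step 3, merging reduces to summing the per-batch (or per-device) counters before the final division. A hedged sketch with made-up numbers (not the actual multi-GPU aggregation code):

```python
# Hedged sketch of step 3: merge mini-batch / per-device metric states.
# The per-device counts below are made up for illustration.
device_states = [
    {"Correct": 120, "Total": 128},  # e.g. GPU 0
    {"Correct": 115, "Total": 128},  # e.g. GPU 1
]

correct = sum(s["Correct"] for s in device_states)
total = sum(s["Total"] for s in device_states)
print(correct / float(total))  # overall accuracy, not a mean of per-device accuracies
```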
### Implementation
This design is shown in the Python API below.
Each metric operator computes its statistics for one mini-batch and returns these batch-aware states; the Python side is responsible for accumulating the states over each pass.
```python
class Evaluator(object):
    """
    Evaluator base class.
    """

    def __init__(self, name, **kwargs):
        """
        Different evaluators may hold different metric states. E.g., Accuracy
        needs two variables, the total and correct sample counts; AUC needs
        four variables, `true_positives`, `true_negatives`, `false_positives`
        and `false_negatives`. So every evaluator should create the variables
        it needs and append them to the main_program.

        The initialization of an Evaluator is responsible for:
            creating the metric state variables and appending them to the main_program.
        """
        pass

    def _update_ops(self, input, label, **kwargs):
        """
        Add the mini-batch metric operators to the main_program.
        Add increment operators to accumulate the metric states.
        """

    def reset(self, executor, reset_program=None):
        """
        Reset the metric states at the beginning of each pass / user-specified batch number.
        Execute the reset_program to reset the states.
        """

    def eval(self, executor, eval_program=None):
        """
        Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
        Execute the eval_program and return the result.
        """
        return eval_result
```
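In a training script the evaluator is then driven once per pass, roughly as in the updated MNIST test further down in this diff (a usage sketch only; `predict`, `label`, `exe`, `train_reader`, and the feed tensors are assumed to be set up as in that script):

```python
# Usage sketch following the MNIST test updated in this PR; the surrounding
# program/executor setup (predict, label, exe, train_reader, tensor_img,
# tensor_y, avg_cost) is assumed to exist as in that script.
accuracy, acc_out = evaluator.accuracy(
    input=predict, label=label, main_program=main_program)

for pass_id in range(PASS_NUM):
    accuracy.reset(exe)                                 # 1. reset the metric states
    for data in train_reader():
        # prepare tensor_img / tensor_y from `data` as in the test script
        outs = exe.run(main_program,
                       feed={"pixel": tensor_img, "label": tensor_y},
                       fetch_list=[avg_cost, acc_out])  # 2. per-batch statistics
    pass_acc = accuracy.eval(exe)                       # 3. merge over the whole pass
    print "pass id:", pass_id, "accuracy:", pass_acc
```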
```diff
@@ -30,6 +30,10 @@ class AccuracyOp : public framework::OperatorWithKernel {
                    "Input (Label) of accuracy op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
                    "Output (Accuracy) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Correct"),
+                   "Output (Correct) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Total"),
+                   "Output (Total) of AccuracyOp should not be null.");
     auto inference_dim = ctx->GetInputDim("Out");
     auto label_dim = ctx->GetInputDim("Label");
@@ -43,6 +47,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
                       " the same as label.");
     ctx->SetOutputDim("Accuracy", {1});
+    ctx->SetOutputDim("Correct", {1});
+    ctx->SetOutputDim("Total", {1});
     ctx->ShareLoD("Out", /*->*/ "Accuracy");
   }
@@ -66,6 +72,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Label", "Label of the training data");
     // TODO(typhoonzero): AddInput("Weight", ...
     AddOutput("Accuracy", "The accuracy of current batch");
+    AddOutput("Correct", "The correct samples count of current batch");
+    AddOutput("Total", "The samples count of current batch");
     AddComment(R"DOC(
 Accuracy Operator.
```
```diff
@@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS;
 template <int BlockSize>
 __global__ void AccuracyCudaKernel(const int N, const int D,
                                    const int64_t* Xdata,
-                                   const int64_t* labeldata, float* accuracy) {
+                                   const int64_t* labeldata, int* correct_data,
+                                   float* accuracy) {
   int count = 0;
   __shared__ int total[BlockSize];
@@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
   // reduce the count with init value 0, and output accuracy.
   int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
   if (threadIdx.x == 0) {
+    *correct_data = result;
     *accuracy = static_cast<float>(result) / static_cast<float>(N);
   }
 }
@@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
     auto* inference = ctx.Input<Tensor>("Out");
     auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
+    auto* correct = ctx.Output<Tensor>("Correct");
+    auto* total = ctx.Output<Tensor>("Total");
     // FIXME(typhoonzero): only support indices currently
     // if add support for output values, how to detect the data type?
     const int64_t* indices_data = indices->data<int64_t>();
     const int64_t* label_data = label->data<int64_t>();
+    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
+    int* total_data = total->mutable_data<int>(ctx.GetPlace());
     float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
-    size_t num_samples = inference->dims()[0];
+    int num_samples = static_cast<int>(inference->dims()[0]);
     size_t infer_width = inference->dims()[1];
     PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
+    // cudaMemset((void**)&correct_data, 0, sizeof(float));
     if (num_samples == 0) {
       return;
     }
+    cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
     AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
         1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
-        num_samples, infer_width, indices_data, label_data, accuracy_data);
+        num_samples, infer_width, indices_data, label_data, correct_data,
+        accuracy_data);
+
+    int d_num_samples, d_num_correct;
+    float d_accuracy;
+    cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
+               cudaMemcpyDeviceToHost);
+    cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
+    cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
+               cudaMemcpyDeviceToHost);
   }
 };

 }  // namespace operators
 }  // namespace paddle

-// FIXME(typhoonzero): types of T is for infernece data.
-// label data is always int
+// FIXME(typhoonzero): types of T is for inference data.
+// label data is always int64
 REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
                        paddle::operators::AccuracyOpCUDAKernel<double>);
```
```diff
@@ -29,7 +29,11 @@ class AccuracyKernel : public framework::OpKernel<T> {
     auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
+    auto* correct = ctx.Output<Tensor>("Correct");
+    auto* total = ctx.Output<Tensor>("Total");

+    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
+    int* total_data = total->mutable_data<int>(ctx.GetPlace());
     float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());

     const int64_t* indices_data = indices->data<int64_t>();
@@ -55,7 +59,8 @@ class AccuracyKernel : public framework::OpKernel<T> {
       }
     }

-    // FIXME(typhoonzero): we don't accumulate the accuracy for now.
+    *correct_data = num_correct;
+    *total_data = num_samples;
     *accuracy_data =
         static_cast<float>(num_correct) / static_cast<float>(num_samples);
   }
```
```diff
@@ -34,7 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
             elementwise_add_grad, ops::ElementwiseOpGrad);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
-    ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
-    ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int64_t>);
```
```diff
@@ -35,7 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker,
             elementwise_div_grad, ops::ElementwiseOpGrad);
 REGISTER_OP_CPU_KERNEL(
     elementwise_div,
-    ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseDivKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_div_grad,
-    ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int64_t>);
```
```diff
@@ -37,8 +37,12 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul,
     ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
-    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
+    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul_grad,
     ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
+    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int64_t>);
```
```diff
@@ -34,7 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker,
             elementwise_sub_grad, ops::ElementwiseOpGrad);
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub,
-    ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseSubKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub_grad,
-    ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>);
+    ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>,
+    ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, double>,
+    ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int>,
+    ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int64_t>);
```
```diff
-import paddle.v2.fluid.op as op
 import numpy as np
+from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable
 import paddle.v2.fluid.core as core


-def avg_accumulate(accumulated_var, per_eval, num_batches, place):
-    t = np.array(accumulated_var.get_tensor())
-    t[0] += per_eval[0]
-    accumulated_var.get_tensor().set([t[0] / float(num_batches)], place)
+def _clone_var_in_block_(block, var):
+    assert isinstance(var, Variable)
+    return block.create_var(
+        name=var.name,
+        shape=var.shape,
+        dtype=var.data_type,
+        type=var.type,
+        lod_level=var.lod_level,
+        persistable=True)


 class Evaluator(object):
-    def __init__(self,
-                 scope,
-                 operator='accuracy',
-                 input='Inference',
-                 label='Label',
-                 output='Output',
-                 place=core.CPUPlace()):
-        """
-        create an evaluator for evaluating the inference.
-        NOTE: default run on CPUPlace(), running on GPUPlace doesn't improve performance much.
-
-        :param scope: the scope instance contains the input.
-        :type scope: paddle.v2.fluid.core.scope
-        :param operator: operator name for calculating the evaluation for each mini-batch.
-        :type operator: string
-        :param input: output variable name of forward network.
-        :type input: string
-        :param label: variable name of label
-        :type label: string
-        """
-        self.scope = scope
-        self.place = place
-        self.output_name = output
-        self.num_batches = 0
-        # create variable to store accumulated evaluator output
-        eval_name = ''.join([operator, "@Eval"])
-        if scope.find_var(eval_name):
-            raise Exception("evaluator already exist in scope: %s" % eval_name)
-        self.accumulated_var = scope.var(eval_name)
-        t = self.accumulated_var.get_tensor()
-        t.set_dims((1, ))
-        t.set([0.0], place)
-        # self.accumulated_var = block.create_var(block, name=eval_name, shape=(1,))
-        # self.accumulated_var.get_tensor().set([0.0])
-        # create operator of evaluation
-        var_map = dict()  # var name -> variable
-        var_map[input] = [input]
-        var_map[label] = [label]
-        var_map[output] = [output]
-        self.op = op.Operator(operator, **var_map)
-
-    def evaluate(self, ctx, accumulator=avg_accumulate):
-        self.op.run(self.scope, ctx)
-        per_eval = np.array(self.scope.find_var(self.output_name).get_tensor())
-        self.num_batches += 1
-        accumulator(self.accumulated_var, per_eval, self.num_batches,
-                    self.place)
+    """
+    Evaluator base class.
+
+    create metric states
+    add mini-batch evaluator calculate operators
+    add increment operators to accumulate the metric states
+    """
+
+    def __init__(self, name, **kwargs):
+        """
+        init the global states
+        """
+        self._states = {}
+        if kwargs.has_key("main_program"):
+            self._main_program = kwargs.get("main_program")
+        else:
+            self._main_program = g_main_program
+
+    def _update_ops(self, *args, **kwargs):
+        """
+        append update ops to the global states
+        """
+        raise NotImplementedError()
+
+    def reset(self, executor, reset_program=None):
+        """
+        Clear the metric states at the beginning of each pass / user-specified batch number
+        """
+        if reset_program is None:
+            reset_program = Program()
+        block = reset_program.global_block()
+        for k, var in self._states.iteritems():
+            g_var = _clone_var_in_block_(block, var)
+            zeros = block.create_var(dtype="float32", persistable=True)
+            block.append_op(
+                type="fill_constant",
+                outputs={"Out": [zeros]},
+                attrs={
+                    "shape": g_var.shape,
+                    "value": .0,
+                    "data_type": 5,
+                })
+            block.append_op(
+                type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
+        executor.run(reset_program, fetch_list=self._states.values())
+
+    def eval(self, executor, eval_program=None):
+        """
+        Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
+        """
+        raise NotImplementedError()
+
+
+class Accuracy(Evaluator):
+    """
+    Accuracy needs two state variables: Total and Correct.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(Accuracy, self).__init__("accuracy", **kwargs)
+        block = self._main_program.global_block()
+        g_total = block.create_var(
+            name=unique_name("Total"),
+            persistable=True,
+            dtype="int64",
+            shape=[1])
+        g_correct = block.create_var(
+            name=unique_name("Correct"),
+            persistable=True,
+            dtype="int64",
+            shape=[1])
+        self._states["Total"] = g_total
+        self._states["Correct"] = g_correct
+
+    def _update_ops(self, input, label, k=1, **kwargs):
+        block = self._main_program.global_block()
+        topk_out = block.create_var(dtype=input.data_type)
+        topk_indices = block.create_var(dtype="int64")
+        block.append_op(
+            type="top_k",
+            inputs={"X": [input]},
+            outputs={"Out": [topk_out],
+                     "Indices": [topk_indices]},
+            attrs={"k": k})
+        acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32"))
+        correct = block.create_var(dtype="int64", persistable=True)
+        total = block.create_var(dtype="int64", persistable=True)
+        block.append_op(
+            type="accuracy",
+            inputs={
+                "Out": [topk_out],
+                "Indices": [topk_indices],
+                "Label": [label]
+            },
+            outputs={
+                "Accuracy": [acc_out],
+                "Correct": [correct],
+                "Total": [total],
+            })
+
+        block.append_op(
+            type="cast",
+            inputs={"X": [self._states["Total"]]},
+            outputs={"Out": [self._states["Total"]]},
+            attrs={
+                "in_data_type": 5,  # float32
+                "out_data_type": 2,  # int32
+            })
+        block.append_op(
+            type="cast",
+            inputs={"X": [self._states["Correct"]]},
+            outputs={"Out": [self._states["Correct"]]},
+            attrs={
+                "in_data_type": 5,
+                "out_data_type": 2,
+            })
+
+        block.append_op(
+            type="elementwise_add",
+            inputs={"X": [self._states["Total"]],
+                    "Y": [total]},
+            outputs={"Out": [self._states["Total"]]})
+        block.append_op(
+            type="elementwise_add",
+            inputs={"X": [self._states["Correct"]],
+                    "Y": [correct]},
+            outputs={"Out": [self._states["Correct"]]})
+
+        return acc_out
+
+    def eval(self, executor, eval_program=None):
+        if eval_program is None:
+            eval_program = Program()
+        block = eval_program.global_block()
+        eval_out = block.create_var(dtype=self._states["Total"].data_type)
+        e_total = _clone_var_in_block_(block, self._states["Total"])
+        e_correct = _clone_var_in_block_(block, self._states["Correct"])
+        block.append_op(
+            type="cast",
+            inputs={"X": [e_total]},
+            outputs={"Out": [e_total]},
+            attrs={
+                "in_data_type": 2,  # int32
+                "out_data_type": 5,  # float32
+            })
+        block.append_op(
+            type="cast",
+            inputs={"X": [e_correct]},
+            outputs={"Out": [e_correct]},
+            attrs={
+                "in_data_type": 2,
+                "out_data_type": 5,
+            })
+        block.append_op(
+            type="elementwise_div",
+            inputs={"X": e_correct,
+                    "Y": e_total},
+            outputs={"Out": eval_out})
+        out = executor.run(eval_program, fetch_list=[eval_out])
+        return np.array(out[0])
+
+
+def accuracy(*args, **kwargs):
+    cls = Accuracy(*args, **kwargs)
+    out = cls._update_ops(*args, **kwargs)
+    return cls, out
```
```diff
@@ -574,7 +574,9 @@ def accuracy(input, label, k=1, **kwargs):
                  "Indices": [topk_indices]},
         attrs={"k": k})
     acc_out_dtype = kwargs.get("out_dtype", "float32")
-    acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
+    acc_out = helper.create_tmp_variable(dtype="float32")
+    correct = helper.create_tmp_variable(dtype="int64")
+    total = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
         type="accuracy",
         inputs={
@@ -582,7 +584,11 @@ def accuracy(input, label, k=1, **kwargs):
             "Indices": [topk_indices],
             "Label": [label]
         },
-        outputs={"Accuracy": [acc_out]})
+        outputs={
+            "Accuracy": [acc_out],
+            "Correct": [correct],
+            "Total": [total],
+        })
     return acc_out
```
```diff
@@ -3,6 +3,7 @@ import paddle.v2.fluid.layers as layers
 import paddle.v2.fluid.nets as nets
 import paddle.v2.fluid.core as core
 import paddle.v2.fluid.optimizer as optimizer
+import paddle.v2.fluid.evaluator as evaluator
 from paddle.v2.fluid.framework import Program
 from paddle.v2.fluid.executor import Executor
@@ -54,17 +55,15 @@ cost = layers.cross_entropy(
     main_program=main_program,
     startup_program=startup_program)
 avg_cost = layers.mean(x=cost, main_program=main_program)
-accuracy = layers.accuracy(
+optimizer = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999)
+opts = optimizer.minimize(avg_cost, startup_program)
+
+accuracy, acc_out = evaluator.accuracy(
     input=predict,
     label=label,
     main_program=main_program,
     startup_program=startup_program)
-# optimizer = optimizer.MomentumOptimizer(learning_rate=0.1 / 128.0,
-#                                         momentum=0.9)
-optimizer = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999)
-opts = optimizer.minimize(avg_cost, startup_program)

 BATCH_SIZE = 50
 PASS_NUM = 3
 train_reader = paddle.batch(
@@ -79,6 +78,7 @@ exe.run(startup_program, feed={}, fetch_list=[])

 for pass_id in range(PASS_NUM):
     count = 0
+    accuracy.reset(exe)
     for data in train_reader():
         img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]),
                                 data)).astype("float32")
@@ -93,11 +93,17 @@ for pass_id in range(PASS_NUM):
         outs = exe.run(main_program,
                        feed={"pixel": tensor_img,
                              "label": tensor_y},
-                       fetch_list=[avg_cost, accuracy])
+                       fetch_list=[avg_cost, acc_out])
         loss = np.array(outs[0])
         acc = np.array(outs[1])
+        pass_acc = accuracy.eval(exe)
+        print "pass id : ", pass_id, pass_acc
+        # print loss, acc
         if loss < 10.0 and acc > 0.9:
             # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
             exit(0)
+    pass_acc = accuracy.eval(exe)
+    print "pass id : ", pass_id, pass_acc
 exit(1)
```
```diff
@@ -18,7 +18,9 @@ class TestAccuracyOp(OpTest):
                     num_correct += 1
                     break
         self.outputs = {
-            'Accuracy': np.array([num_correct / float(n)]).astype("float32")
+            'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
+            'Correct': np.array([num_correct]).astype("int32"),
+            'Total': np.array([n]).astype("int32")
         }

     def test_check_output(self):
```
```python
from paddle.v2.fluid.evaluator import Evaluator
from paddle.v2.fluid.op import Operator
import paddle.v2.fluid.core as core
import unittest
import op_test
import numpy as np


class TestEvaluator(unittest.TestCase):
    def setup(self, scope, inputs, outputs):
        def __create_var__(var_name, arr):
            np_arr = np.array(arr)
            scope.var(var_name)
            # tensor = var.get_tensor()
            # tensor.set_dims(np_arr.shape)

        for var_name, arr in inputs.iteritems():
            __create_var__(var_name, arr)
        for var_name, arr in outputs.iteritems():
            __create_var__(var_name, arr)

    def test_evaluator(self):
        inputs = {
            'Inference': np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 1]]).T,
            'Label': np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
        }
        outputs = {'Accuracy': np.array([0.9])}
        out_name = 'Accuracy'

        places = [core.CPUPlace()]
        if core.is_compile_gpu():
            places.append(core.GPUPlace(0))

        for place in places:
            scope = core.Scope()
            self.setup(scope, inputs, outputs)

            evaluator = Evaluator(
                scope,
                operator='accuracy',
                input='Inference',
                label='Label',
                output=out_name,
                place=place)
            op_test.set_input(scope, evaluator.op, inputs, place)
            ctx = core.DeviceContext.create(place)

            for i in range(10):  # simulate 10 mini-batches
                evaluator.evaluate(ctx)

            actual = np.array(scope.find_var(out_name).get_tensor())
            print actual

            self.assertTrue(
                np.allclose(
                    actual, outputs[out_name], atol=1e-5),
                "output name: " + out_name + " has diff.")


if __name__ == '__main__':
    exit(0)
    unittest.main()
```
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator