From 12858baa6c31f646500d9dab26053f5a340cfd0e Mon Sep 17 00:00:00 2001
From: Dong Zhihong
Date: Tue, 14 Nov 2017 00:26:43 -0800
Subject: [PATCH] "relauch ci"

---
 paddle/operators/accuracy_op.cu         | 29 ++++++++++++++++++++-----
 python/paddle/v2/framework/evaluator.py |  8 +++----
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index 1776f331053..b575c682f0d 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS;
 template <int BlockSize>
 __global__ void AccuracyCudaKernel(const int N, const int D,
                                    const int64_t* Xdata,
-                                   const int64_t* labeldata, float* accuracy) {
+                                   const int64_t* labeldata, int* correct_data,
+                                   float* accuracy) {
   int count = 0;
   __shared__ int total[BlockSize];
 
@@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
   // reduce the count with init value 0, and output accuracy.
   int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
   if (threadIdx.x == 0) {
+    *correct_data = result;
     *accuracy = static_cast<float>(result) / static_cast<float>(N);
   }
 }
@@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
     auto* inference = ctx.Input<Tensor>("Out");
     auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
+
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
+    auto* correct = ctx.Output<Tensor>("Correct");
+    auto* total = ctx.Output<Tensor>("Total");
     // FIXME(typhoonzero): only support indices currently
     // if add support for output values, how to detect the data type?
     const int64_t* indices_data = indices->data<int64_t>();
     const int64_t* label_data = label->data<int64_t>();
+
+    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
+    int* total_data = total->mutable_data<int>(ctx.GetPlace());
     float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
 
-    size_t num_samples = inference->dims()[0];
+    int num_samples = static_cast<int>(inference->dims()[0]);
     size_t infer_width = inference->dims()[1];
     PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
+    // cudaMemset((void**)&correct_data, 0, sizeof(float));
 
     if (num_samples == 0) {
       return;
     }
+    cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
 
     AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
         1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
-        num_samples, infer_width, indices_data, label_data, accuracy_data);
+        num_samples, infer_width, indices_data, label_data, correct_data,
+        accuracy_data);
+
+    int d_num_samples, d_num_correct;
+    float d_accuracy;
+    cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
+               cudaMemcpyDeviceToHost);
+    cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
+    cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
+               cudaMemcpyDeviceToHost);
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
-// FIXME(typhoonzero): types of T is for infernece data.
-// label data is always int
+// FIXME(typhoonzero): types of T is for inference data.
+// label data is always int64
 REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
                        paddle::operators::AccuracyOpCUDAKernel<double>);
diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py
index 89290abb830..ffff25b3461 100644
--- a/python/paddle/v2/framework/evaluator.py
+++ b/python/paddle/v2/framework/evaluator.py
@@ -43,7 +43,7 @@ class Evaluator(object):
         """
         Clear metric states at the begin of each pass/user specified batch
         """
-        if program == None:
+        if reset_program == None:
             reset_program = Program()
         else:
             reset_program = program
@@ -147,9 +147,9 @@ class Accuracy(Evaluator):
 
         return acc_out
 
-    def eval(self, executor, program=None):
-        if program != None:
-            eval_program = program
+    def eval(self, executor, eval_program=None):
+        if eval_program != None:
+            eval_program = eval_program
         else:
             eval_program = Program()
         block = eval_program.global_block()
--
GitLab
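
For reference, a minimal self-contained sketch (not part of the Paddle sources) of the pattern this patch follows: a single-block CUDA kernel counts per-thread top-D hits, reduces the partial counts with thrust::reduce, and thread 0 writes both the raw correct count and the accuracy ratio, which the host then copies back with cudaMemcpy much like the code added to AccuracyOpCUDAKernel::Compute. The names (TopKAccuracyKernel, kBlockSize) and the sample data are illustrative assumptions; depending on the CUDA/Thrust version, calling thrust::reduce from device code may also require building with nvcc -rdc=true.

// Standalone sketch only -- not the Paddle operator. Assumes a single-block
// launch, exactly like the patched AccuracyCudaKernel.
#include <cstdint>
#include <cstdio>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

constexpr int kBlockSize = 256;

template <int BlockSize>
__global__ void TopKAccuracyKernel(const int N, const int D,
                                   const int64_t* indices,
                                   const int64_t* labels, int* correct,
                                   float* accuracy) {
  __shared__ int partial[BlockSize];
  int count = 0;
  // Each thread scans a strided slice of the batch; a sample is a hit if its
  // label appears anywhere among that sample's D predicted indices.
  for (int i = threadIdx.x; i < N; i += BlockSize) {
    for (int j = 0; j < D; ++j) {
      if (indices[i * D + j] == labels[i]) {
        ++count;
        break;
      }
    }
  }
  partial[threadIdx.x] = count;
  __syncthreads();
  // Block-wide reduction of the per-thread counts (same call the patch uses).
  int result = thrust::reduce(thrust::device, partial, partial + BlockSize, 0);
  if (threadIdx.x == 0) {
    *correct = result;                           // raw number of hits
    *accuracy = static_cast<float>(result) / N;  // ratio in [0, 1]
  }
}

int main() {
  const int N = 4, D = 2;
  // Top-2 predicted indices per sample and the ground-truth labels.
  int64_t h_indices[N * D] = {3, 1, 0, 2, 5, 4, 7, 6};
  int64_t h_labels[N] = {1, 2, 9, 7};  // samples 0, 1, 3 are hits
  int64_t *d_indices, *d_labels;
  int* d_correct;
  float* d_accuracy;
  cudaMalloc(&d_indices, sizeof(h_indices));
  cudaMalloc(&d_labels, sizeof(h_labels));
  cudaMalloc(&d_correct, sizeof(int));
  cudaMalloc(&d_accuracy, sizeof(float));
  cudaMemcpy(d_indices, h_indices, sizeof(h_indices), cudaMemcpyHostToDevice);
  cudaMemcpy(d_labels, h_labels, sizeof(h_labels), cudaMemcpyHostToDevice);

  TopKAccuracyKernel<kBlockSize><<<1, kBlockSize>>>(N, D, d_indices, d_labels,
                                                    d_correct, d_accuracy);

  // Copy the results back to the host, mirroring the cudaMemcpy calls the
  // patch adds at the end of Compute().
  int h_correct;
  float h_accuracy;
  cudaMemcpy(&h_correct, d_correct, sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_accuracy, d_accuracy, sizeof(float), cudaMemcpyDeviceToHost);
  printf("correct = %d, accuracy = %.2f\n", h_correct, h_accuracy);  // 3, 0.75

  cudaFree(d_indices);
  cudaFree(d_labels);
  cudaFree(d_correct);
  cudaFree(d_accuracy);
  return 0;
}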