From 8013328ed840ab65afbb2bff4eb1e27bc264eea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Tue, 31 Oct 2017 15:37:23 +0800 Subject: [PATCH] Refine evaluator op types (#5208) * refine evaluator op types * update * follow comments * update * fix v2 mnist case * fix v2 mnist case * update * update --- paddle/operators/accuracy_op.cc | 39 +++++++++++++------ paddle/operators/accuracy_op.cu | 24 +++++++----- paddle/operators/accuracy_op.h | 9 +++-- paddle/operators/auc_op.cc | 38 ++++++++++++------ paddle/operators/auc_op.h | 37 ++++++++---------- python/paddle/v2/framework/layers.py | 7 +++- .../v2/framework/tests/test_accuracy_op.py | 11 +++--- .../paddle/v2/framework/tests/test_auc_op.py | 16 ++++---- 8 files changed, 108 insertions(+), 73 deletions(-) diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 88958e1634c..2a2a1e9cfd6 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -22,23 +22,35 @@ class AccuracyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Inference"), - "Input(Inference) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), + "Input (Out) of accuracy op should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input (Indices) of accuracy op should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), - "Input(Label) of AccuracyOp should not be null."); + "Input (Label) of accuracy op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Accuracy"), - "Output(Accuracy) of AccuracyOp should not be null."); + "Output (Accuracy) of AccuracyOp should not be null."); - auto inference_dim = ctx->GetInputDim("Inference"); + auto inference_dim = ctx->GetInputDim("Out"); auto label_dim = ctx->GetInputDim("Label"); + // Assume indices has same shape with infernece, because + // it's the output of topk. PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2."); PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1"); PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0], - "inference size must be the same as label size"); + "the inference tensor's num_rows must be" + " the same as label."); ctx->SetOutputDim("Accuracy", {1}); - ctx->ShareLoD("Inference", /*->*/ "Accuracy"); + ctx->ShareLoD("Out", /*->*/ "Accuracy"); + } + + protected: + // IndicateDataType + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Out")->type()); } }; @@ -48,7 +60,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { // TODO(typhoonzero): support both inference value and indices. - AddInput("Inference", "topk(indices) the network output"); + AddInput("Out", "topk (inferences) the network output"); + AddInput("Indices", "topk (indices) the network output"); AddInput("Label", "Label of the training data"); // TODO(typhoonzero): AddInput("Weight", ... AddOutput("Accuracy", "The accuracy of current batch"); @@ -59,7 +72,7 @@ The accuracy is: .. math:: accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples}) -Both the input `Inference` and `Label` can carry the LoD (Level of Details) +Both the input `Out` and `Label` can carry the LoD (Level of Details) information, or not. But the output only shares the LoD with input `Inference`. )DOC"); } @@ -71,6 +84,8 @@ information, or not. But the output only shares the LoD with input `Inference`. namespace ops = paddle::operators; REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - accuracy, ops::AccuracyKernel, - ops::AccuracyKernel); +// FIXME(typhoonzero): types of T is for infernece data. +// label data is always int. +REGISTER_OP_CPU_KERNEL(accuracy, + ops::AccuracyKernel, + ops::AccuracyKernel); diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index be58dfbd030..a0483f367e1 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -21,9 +21,10 @@ namespace paddle { namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; -template -__global__ void AccuracyCudaKernel(const int N, const int D, const T* Xdata, - const T* labeldata, float* accuracy) { +template +__global__ void AccuracyCudaKernel(const int N, const int D, + const int64_t* Xdata, + const int64_t* labeldata, float* accuracy) { int count = 0; __shared__ int total[BlockSize]; @@ -52,13 +53,14 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use GPUPlace."); - auto* inference = ctx.Input("Inference"); + auto* inference = ctx.Input("Out"); + auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); auto* accuracy = ctx.Output("Accuracy"); // FIXME(typhoonzero): only support indices currently // if add support for output values, how to detect the data type? - const T* inference_data = inference->data(); - const T* label_data = label->data(); + const int64_t* indices_data = indices->data(); + const int64_t* label_data = label->data(); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); size_t num_samples = inference->dims()[0]; @@ -69,11 +71,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { return; } - AccuracyCudaKernel<<< + AccuracyCudaKernel<<< 1, PADDLE_CUDA_NUM_THREADS, 0, reinterpret_cast( ctx.device_context()) - .stream()>>>(num_samples, infer_width, inference_data, label_data, + .stream()>>>(num_samples, infer_width, indices_data, label_data, accuracy_data); } }; @@ -81,5 +83,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel, - paddle::operators::AccuracyOpCUDAKernel); +// FIXME(typhoonzero): types of T is for infernece data. +// label data is always int +REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel, + paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h index 12c6b9aac88..1968b53d19a 100644 --- a/paddle/operators/accuracy_op.h +++ b/paddle/operators/accuracy_op.h @@ -38,14 +38,15 @@ template class AccuracyKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Inference"); + auto* inference = ctx.Input("Out"); + auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); auto* accuracy = ctx.Output("Accuracy"); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - const T* inference_data = inference->data(); - const T* label_data = label->data(); + const int64_t* indices_data = indices->data(); + const int64_t* label_data = label->data(); size_t num_samples = inference->dims()[0]; size_t class_dim = inference->dims()[1]; @@ -60,7 +61,7 @@ class AccuracyKernel : public framework::OpKernel { for (size_t i = 0; i < num_samples; ++i) { PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0"); for (size_t j = 0; j < class_dim; ++j) { - if (inference_data[i * class_dim + j] == label_data[i]) { + if (indices_data[i * class_dim + j] == label_data[i]) { ++num_correct; break; } diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index cf3dbc5d10c..f5784922af1 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -23,18 +23,26 @@ class AucOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Inference"), - "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input of Indices must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input of Label must be initialized."); - auto inference_dim = ctx->GetInputDim("Inference"); - auto label_dim = ctx->GetInputDim("Label"); + auto inference_height = ctx->GetInputDim("Out")[0]; + auto label_height = ctx->GetInputDim("Label")[0]; - PADDLE_ENFORCE_EQ(inference_dim, label_dim, - "inference and label should have same shape"); + PADDLE_ENFORCE_EQ(inference_height, label_height, + "Out and Label should have same height."); ctx->SetOutputDim("AUC", {1}); - ctx->ShareLoD("Inference", /*->*/ "AUC"); + ctx->ShareLoD("Out", /*->*/ "AUC"); + } + + protected: + // IndicateDataType + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Out")->type()); } }; @@ -42,12 +50,18 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { public: AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Inference", - "A floating point tensor of arbitrary shape and whose values" - "are in the range [0, 1]."); + AddInput("Out", + "A floating point 2D tensor, values are in the range [0, 1]." + "Each row is descend sorted. This input should be the" + "output of topk." + "Typically, this tensor indicates the probability of each label"); + AddInput("Indices", + "An int 2D tensor, indicating the indices of original" + "tensor before sort. Typically, this tensor indicates which label" + "the probability stands for."); AddInput("Label", - "A tensor whose shape matches " - "Inference. Will be cast to bool."); + "A 2D int tensor indicating the label of the training data." + "The height is batch size and width is always 1."); // TODO(typhoonzero): support weight input AddOutput("AUC", "A scalar representing the " diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index be6ef29d5f6..e5ac57b038a 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -29,7 +29,7 @@ template class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Inference"); + auto* inference = ctx.Input("Out"); auto* label = ctx.Input("Label"); auto* auc = ctx.Output("AUC"); @@ -46,18 +46,11 @@ class AucKernel : public framework::OpKernel { thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; - size_t num_samples = inference->numel(); + size_t batch_size = inference->dims()[0]; + size_t inference_width = inference->dims()[1]; const T* inference_data = inference->data(); - Tensor label_casted; - label_casted.Resize(label->dims()); - bool* label_casted_data = label_casted.mutable_data(ctx.GetPlace()); - - const int* label_data = label->data(); - // cast label_data to bool - for (size_t i = 0; i < num_samples; i++) { - label_casted_data[i] = static_cast(label_data[i]); - } + const int64_t* label_data = label->data(); // Create local tensor for storing the curve: TP, FN, TN, FP // TODO(typhoonzero): use eigen op to caculate these values. @@ -68,23 +61,27 @@ class AucKernel : public framework::OpKernel { true_negative.Resize({num_thresholds}); false_positive.Resize({num_thresholds}); - int* tp_data = true_positive.mutable_data(ctx.GetPlace()); - int* fn_data = false_negative.mutable_data(ctx.GetPlace()); - int* tn_data = true_negative.mutable_data(ctx.GetPlace()); - int* fp_data = false_positive.mutable_data(ctx.GetPlace()); + int64_t* tp_data = true_positive.mutable_data(ctx.GetPlace()); + int64_t* fn_data = false_negative.mutable_data(ctx.GetPlace()); + int64_t* tn_data = true_negative.mutable_data(ctx.GetPlace()); + int64_t* fp_data = false_positive.mutable_data(ctx.GetPlace()); for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { // caculate TP, FN, TN, FP for current thresh - int tp = 0, fn = 0, tn = 0, fp = 0; - for (size_t i = 0; i < num_samples; i++) { - if (label_casted_data[i]) { - if (inference_data[i] >= (thresholds_list[idx_thresh])) { + int64_t tp = 0, fn = 0, tn = 0, fp = 0; + for (size_t i = 0; i < batch_size; i++) { + // NOTE: label_data used as bool, labels >0 will be treated as true. + if (label_data[i]) { + // use first(max) data in each row + if (inference_data[i * inference_width] >= + (thresholds_list[idx_thresh])) { tp++; } else { fn++; } } else { - if (inference_data[i] >= (thresholds_list[idx_thresh])) { + if (inference_data[i * inference_width] >= + (thresholds_list[idx_thresh])) { fp++; } else { tn++; diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 4727d139a28..6451d11e2b6 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -243,8 +243,11 @@ def accuracy(input, label, k=1, **kwargs): acc_out = helper.create_tmp_variable(dtype=acc_out_dtype) helper.append_op( type="accuracy", - inputs={"Inference": [topk_indices], - "Label": [label]}, + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, outputs={"Accuracy": [acc_out]}) return acc_out diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py index f17edd44aef..6536c297e8e 100644 --- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -7,13 +7,14 @@ class TestAccuracyOp(OpTest): def setUp(self): self.op_type = "accuracy" n = 8192 - infer = np.random.randint(0, 2, (n, 1)).astype("int") - label = np.random.randint(0, 2, (n, 1)).astype("int") - self.inputs = {'Inference': infer, "Label": label} + infer = np.random.random((n, 1)).astype("float32") + indices = np.random.randint(0, 2, (n, 1)) + label = np.random.randint(0, 2, (n, 1)) + self.inputs = {'Out': infer, 'Indices': indices, "Label": label} num_correct = 0 for rowid in xrange(n): - for ele in infer[rowid]: - if ele == label[rowid][0]: + for ele in indices[rowid]: + if ele == label[rowid]: num_correct += 1 break self.outputs = { diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/framework/tests/test_auc_op.py index 65f679cfccc..26ea905d880 100644 --- a/python/paddle/v2/framework/tests/test_auc_op.py +++ b/python/paddle/v2/framework/tests/test_auc_op.py @@ -6,10 +6,11 @@ from op_test import OpTest class TestAucOp(OpTest): def setUp(self): self.op_type = "auc" - pred = np.random.random((128)).astype("float32") - labels = np.random.randint(0, 2, (128, )) + pred = np.random.random((128, 2)).astype("float32") + indices = np.random.randint(0, 2, (128, 2)) + labels = np.random.randint(0, 2, (128, 1)) num_thresholds = 200 - self.inputs = {'Inference': pred, 'Label': labels} + self.inputs = {'Out': pred, 'Indices': indices, 'Label': labels} self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} # NOTE: sklearn use a different way to generate thresholds # which will cause the result differs slightly: @@ -31,12 +32,12 @@ class TestAucOp(OpTest): tp, fn, tn, fp = 0, 0, 0, 0 for i, lbl in enumerate(labels): if lbl: - if pred[i] >= thresh: + if pred[i, 0] >= thresh: tp += 1 else: fn += 1 else: - if pred[i] >= thresh: + if pred[i, 0] >= thresh: fp += 1 else: tn += 1 @@ -62,6 +63,5 @@ class TestAucOp(OpTest): self.check_output() -# TODO(typhoonzero): add this back till we fix it -#if __name__ == "__main__": -# unittest.main() +if __name__ == "__main__": + unittest.main() -- GitLab