提交 63309941 编写于 作者: T typhoonzero

pull develop and update

上级 28243520
...@@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input of Inference must be initialized."); "Input of Inference must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Inference must be initialized."); "Input of Label must be initialized.");
auto *inference = ctx.Input<framework::Tensor>("Inference"); auto inference_dim = ctx->GetInputDim("Inference");
auto *label = ctx.Input<framework::Tensor>("Label"); auto label_dim = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), PADDLE_ENFORCE_EQ(inference_dim, label_dim,
"inference and label should have same shape"); "inference and label should have same shape");
ctx.Output<framework::LoDTensor>("AUC")->Resize({1}); ctx->SetOutputDim("AUC", {1});
ctx->ShareLoD("Inference", /*->*/ "AUC");
} }
}; };
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <iostream>
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
...@@ -27,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -27,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AucKernel : public framework::OpKernel { class AucKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference"); auto* inference = ctx.Input<Tensor>("Inference");
...@@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel { ...@@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel {
} }
// Create local tensor for storing the curve: TP, FN, TN, FP // Create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): use eigen op to caculate these values.
// TODO(typhoonzero): use op to caculate these values.
Tensor true_positive, false_positive, true_negative, false_negative; Tensor true_positive, false_positive, true_negative, false_negative;
true_positive.Resize({num_thresholds}); true_positive.Resize({num_thresholds});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册