提交 63309941 编写于 作者: T typhoonzero

pull develop and update

上级 28243520
......@@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"),
"Input of Inference must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input of Inference must be initialized.");
auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *label = ctx.Input<framework::Tensor>("Label");
PADDLE_ENFORCE_EQ(inference->dims(), label->dims(),
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input of Inference must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label must be initialized.");
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(inference_dim, label_dim,
"inference and label should have same shape");
ctx.Output<framework::LoDTensor>("AUC")->Resize({1});
ctx->SetOutputDim("AUC", {1});
ctx->ShareLoD("Inference", /*->*/ "AUC");
}
};
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
......@@ -27,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class AucKernel : public framework::OpKernel {
class AucKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
......@@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel {
}
// Create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope
// TODO(typhoonzero): use op to caculate these values.
// TODO(typhoonzero): use eigen op to caculate these values.
Tensor true_positive, false_positive, true_negative, false_negative;
true_positive.Resize({num_thresholds});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册