提交 7db6b07b 编写于 作者: S scxfjiang 提交者: Jinhui Yuan

Dev add accuracy module (#1008)

* naive version of accuracy module

* add top_k_ prefix to accuracy print

* remove magic number
上级 9d0bae73
#include "oneflow/core/actor/accuracy_accumulate_compute_actor.h"
namespace oneflow {
// Associates task type kAccuracyAcc with the AccuracyAccCompActor class.
REGISTER_ACTOR(TaskType::kAccuracyAcc, AccuracyAccCompActor);
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_ACCURACY_ACCUMULATE_COMPUTE_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_ACCURACY_ACCUMULATE_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/accumulate_compute_actor.h"
namespace oneflow {
// Actor for accuracy accumulation. All behaviour is inherited from
// AccumulateCompActor; this class only supplies the initialization arguments.
class AccuracyAccCompActor final : public AccumulateCompActor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyAccCompActor);
  AccuracyAccCompActor() = default;
  ~AccuracyAccCompActor() = default;

  // Initializes the base actor with the job's PieceNumOfPrintAccuracy() value
  // and ascending column-id order.
  void VirtualCompActorInit(const TaskProto& proto) override {
    const auto piece_num = Global<JobDesc>::Get()->PieceNumOfPrintAccuracy();
    AccumulateCompActor::Init(proto, piece_num, ColIdOrder::kAscending);
  }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACCURACY_ACCUMULATE_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/accuracy_print_compute_actor.h"
namespace oneflow {
// Associates task type kAccuracyPrint with the AccuracyPrintCompActor class.
REGISTER_ACTOR(TaskType::kAccuracyPrint, AccuracyPrintCompActor);
} // namespace oneflow
\ No newline at end of file
#ifndef ONEFLOW_CORE_ACTOR_ACCURACY_PRINT_COMPUTE_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_ACCURACY_PRINT_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/sink_compute_actor.h"
namespace oneflow {
// Actor for accuracy printing; adds nothing on top of SinkCompActor.
class AccuracyPrintCompActor final : public SinkCompActor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyPrintCompActor);

  AccuracyPrintCompActor() = default;
  ~AccuracyPrintCompActor() = default;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACCURACY_PRINT_COMPUTE_ACTOR_H_
......@@ -201,5 +201,6 @@ void NormalForwardCompActor::SendConstBufInitMsgToBwActor() {
// kNormalForward, kLoss and kAccuracy tasks are all registered with the same
// actor class, NormalForwardCompActor.
REGISTER_ACTOR(TaskType::kNormalForward, NormalForwardCompActor);
REGISTER_ACTOR(TaskType::kLoss, NormalForwardCompActor);
REGISTER_ACTOR(TaskType::kAccuracy, NormalForwardCompActor);
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_ACCURACY_ACCUMULATE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_ACCURACY_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/accumulate_compute_task_node.h"
namespace oneflow {
// Task node for accuracy accumulation; only the reported task type differs
// from the AccCompTaskNode base.
class AccuracyAccCompTaskNode final : public AccCompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyAccCompTaskNode);

  AccuracyAccCompTaskNode() = default;
  ~AccuracyAccCompTaskNode() = default;

  // Identifies this node as a kAccuracyAcc task.
  TaskType GetTaskType() const override {
    return TaskType::kAccuracyAcc;
  }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_ACCURACY_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/accuracy_compute_task_node.h"
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/graph/logical_node.h"
namespace oneflow {
void AccuracyCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("accuracy", false);
for (TaskEdge* edge : out_edges()) { BindEdgeWithProducedRegst(edge, "accuracy"); }
}
void AccuracyCompTaskNode::ConsumeAllRegsts() {
for (TaskEdge* edge : in_edges()) { ConsumeRegst("in", edge->GetSoleRegst()); }
}
void AccuracyCompTaskNode::BuildExecGphAndRegst() {
const auto& op_vec = logical_node()->op_vec();
CHECK_EQ(op_vec.size(), 1);
std::shared_ptr<const Operator> accuracy_op = op_vec[0];
ExecNode* accuracy_node = mut_exec_gph().NewNode();
accuracy_node->mut_op() = accuracy_op;
for (const std::string& ibn : accuracy_op->input_bns()) {
accuracy_node->BindBnWithOneOfTheRegsts(ibn, GetConsumedRegst("in"));
}
std::shared_ptr<RegstDesc> accuracy_regst = GetProducedRegst("accuracy");
for (const std::string& obn : accuracy_op->output_bns()) {
accuracy_regst->AddLbi(accuracy_op->BnInOp2Lbi(obn));
accuracy_node->BindBnWithRegst(obn, accuracy_regst);
}
accuracy_node->InferBlobDescs(parallel_ctx());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_ACCURACY_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_ACCURACY_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
// Compute task node for the accuracy op: consumes "in" registers and produces
// a single "accuracy" register (see the out-of-line definitions).
class AccuracyCompTaskNode final : public CompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyCompTaskNode);
  AccuracyCompTaskNode() = default;
  ~AccuracyCompTaskNode() = default;

  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() override;
  void BuildExecGphAndRegst() override;

  // Identifies this node as a kAccuracy task.
  TaskType GetTaskType() const override { return TaskType::kAccuracy; }
};
}  // namespace oneflow
#endif  // ONEFLOW_CORE_GRAPH_ACCURACY_COMPUTE_TASK_NODE_H_
#ifndef ONEFLOW_CORE_GRAPH_ACCURACY_PRINT_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_ACCURACY_PRINT_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/sink_compute_task_node.h"
namespace oneflow {
// Sink task node for accuracy printing.
class AccuracyPrintCompTaskNode final : public SinkCompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyPrintCompTaskNode);

  AccuracyPrintCompTaskNode() = default;
  ~AccuracyPrintCompTaskNode() = default;

  // Identifies this node as a kAccuracyPrint task.
  TaskType GetTaskType() const override {
    return TaskType::kAccuracyPrint;
  }

  // This node always reports itself as a persistence task.
  bool IsPersistence() const override {
    return true;
  }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_ACCURACY_PRINT_COMPUTE_TASK_NODE_H_
\ No newline at end of file
......@@ -2,6 +2,7 @@
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/job/keyword.h"
namespace oneflow {
......@@ -12,7 +13,10 @@ LogicalGraph::LogicalGraph(bool is_train) {
if (is_train) { BuildBwStruct(&edge2ibn); }
MergeEdge();
SetNodeDataLbi();
if (is_train) { BuildLossPrintStruct(); }
if (is_train) {
BuildLossPrintStruct();
BuildAccuracyPrintStruct();
}
BuildModelStruct(is_train);
if (is_train) { ConnectFwToBw(); }
ToDotWithAutoFilePath();
......@@ -363,7 +367,7 @@ void LogicalGraph::BuildLossPrintStruct() {
Connect<LogicalNode>(loss_logical, NewEdge(), loss_acc_logical);
// Loss Print Logical
OperatorConf loss_print_op_conf;
loss_print_op_conf.set_name("loss_print_" + loss_op->op_name());
loss_print_op_conf.set_name(LossPrintPrefix + loss_op->op_name());
loss_print_op_conf.set_device_type(DeviceType::kCPU);
auto loss_print_conf = loss_print_op_conf.mutable_loss_print_conf();
......@@ -386,6 +390,40 @@ void LogicalGraph::BuildLossPrintStruct() {
});
}
void LogicalGraph::BuildAccuracyPrintStruct() {
ForEachLogicalNode<AccuracyLogicalNode>([&](AccuracyLogicalNode* accuracy_logical) {
std::shared_ptr<const Operator> accuracy_op = accuracy_logical->SoleOp();
// Accuracy Accumulate Logical
OperatorConf accuracy_acc_op_conf;
accuracy_acc_op_conf.set_name("accuracy_acc_" + accuracy_op->op_name());
accuracy_acc_op_conf.set_device_type(accuracy_op->device_type());
accuracy_acc_op_conf.mutable_accumulate_conf();
std::shared_ptr<Operator> accuracy_acc_op = ConstructOp(accuracy_acc_op_conf);
AccuracyAccLogicalNode* accuracy_acc_logical = NewNode<AccuracyAccLogicalNode>();
accuracy_acc_logical->mut_op_vec() = {accuracy_acc_op};
accuracy_acc_logical->mut_parallel_desc() = accuracy_logical->parallel_desc();
Connect<LogicalNode>(accuracy_logical, NewEdge(), accuracy_acc_logical);
// Accuracy Print Logical
OperatorConf accuracy_print_op_conf;
accuracy_print_op_conf.set_name(AccuracyPrintPrefix + accuracy_op->op_name());
accuracy_print_op_conf.set_device_type(DeviceType::kCPU);
auto accuracy_print_conf = accuracy_print_op_conf.mutable_accuracy_print_conf();
*(accuracy_print_conf->mutable_accuracy_lbi()) = accuracy_op->BnInOp2Lbi("accuracy");
accuracy_print_conf->set_top_k_print(accuracy_op->op_conf().accuracy_conf().top_k());
std::shared_ptr<Operator> accuracy_print_op = ConstructOp(accuracy_print_op_conf);
ParallelConf accuracy_print_pr_conf;
accuracy_print_pr_conf.set_policy(kDataParallel);
accuracy_print_pr_conf.add_device_name(Global<JobDesc>::Get()->MachineName4MachineId(0)
+ ":cpu:1");
AccuracyPrintLogicalNode* accuracy_print_logical = NewNode<AccuracyPrintLogicalNode>();
accuracy_print_logical->mut_op_vec() = {accuracy_print_op};
accuracy_print_logical->mut_parallel_desc().reset(new ParallelDesc(accuracy_print_pr_conf));
Connect<LogicalNode>(accuracy_acc_logical, NewEdge(), accuracy_print_logical);
});
}
void LogicalGraph::BuildModelStruct(bool is_train) {
HashMap<const LogicalNode*, NormalMdUpdtLogicalNode*> first_shared2mdupdt;
ForEachLogicalNode<ForwardLogicalNode>([&](ForwardLogicalNode* fw_logical) {
......
......@@ -56,6 +56,7 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
void MergeEdge();
void SetNodeDataLbi();
void BuildLossPrintStruct();
void BuildAccuracyPrintStruct();
void BuildModelStruct(bool is_train);
void BuildReduceStruct(LogicalNode* src, LogicalNode* dst);
void SetupNormalMdUpdtOp();
......
......@@ -14,6 +14,9 @@
#include "oneflow/core/graph/reduce_local_add_compute_task_node.h"
#include "oneflow/core/graph/reduce_global_add_compute_task_node.h"
#include "oneflow/core/graph/reduce_gather_compute_task_node.h"
#include "oneflow/core/graph/accuracy_compute_task_node.h"
#include "oneflow/core/graph/accuracy_accumulate_compute_task_node.h"
#include "oneflow/core/graph/accuracy_print_compute_task_node.h"
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
......@@ -124,6 +127,9 @@ std::vector<LogicalBlobId> ReturnPackedLbi(const LogicalNode* src, const Logical
REGISTER_FUNC_FOR_FIND_LBIS("LossAcc"
"LossPrint",
ReturnPackedLbi);
// Edges AccuracyAcc -> AccuracyPrint use ReturnPackedLbi. The two adjacent
// string literals concatenate into the single key "AccuracyAccAccuracyPrint".
REGISTER_FUNC_FOR_FIND_LBIS("AccuracyAcc"
"AccuracyPrint",
ReturnPackedLbi);
REGISTER_FUNC_FOR_FIND_LBIS("MdDiffAcc"
"NormalMdUpdt",
ReturnPackedLbi);
......@@ -315,6 +321,9 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("RecordLoad"
REGISTER_BLD_SUB_TSK_GPH_MTHD("Loss"
"LossAcc",
&TaskGraph::BldSubTskGphByOneToOne);
// Edges Accuracy -> AccuracyAcc build their sub task graph with
// TaskGraph::BldSubTskGphByOneToOne.
REGISTER_BLD_SUB_TSK_GPH_MTHD("Accuracy"
"AccuracyAcc",
&TaskGraph::BldSubTskGphByOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("MdDiffAcc"
"NormalMdUpdt",
BldSubTskGphToNormalMdUpdt);
......@@ -356,6 +365,12 @@ REGISTER_BLD_BOXING_OP_CONF_MTHD("Loss"
REGISTER_BLD_BOXING_OP_CONF_MTHD("LossAcc"
"LossPrint",
&BoxingTaskNode::BldBoxingOpConfWithAddAndClone);
// Boxing op conf builders for accuracy-related edges. Note the two keys are
// distinct after literal concatenation: "AccuracyAccAccuracyPrint" (accumulate
// -> accuracy print) and "AccuracyPrint" (Accuracy -> Print).
REGISTER_BLD_BOXING_OP_CONF_MTHD("AccuracyAcc"
"AccuracyPrint",
&BoxingTaskNode::BldBoxingOpConfWithAddAndClone);
REGISTER_BLD_BOXING_OP_CONF_MTHD("Accuracy"
"Print",
&BoxingTaskNode::BldBoxingOpConfWithAddAndClone);
REGISTER_BLD_BOXING_OP_CONF_MTHD("MdDiffAcc"
"NormalMdUpdt",
&BoxingTaskNode::BldBoxingOpConfWithAddAndClone);
......@@ -378,7 +393,10 @@ REGISTER_BLD_BOXING_OP_CONF_MTHD("NormalBackward"
OF_PP_MAKE_TUPLE_SEQ(ReduceScatter, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceLocalAdd, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceGlobalAdd, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(ReduceGather, kMdUpdtArea)
OF_PP_MAKE_TUPLE_SEQ(ReduceGather, kMdUpdtArea) \
OF_PP_MAKE_TUPLE_SEQ(Accuracy, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(AccuracyAcc, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(AccuracyPrint, kPrintArea)
#define DEFINE_VIRTUAL_METHOD(x, area_type) \
std::string x##LogicalNode::TypeName() const { return #x; } \
......
......@@ -177,10 +177,13 @@ class NormalBackwardLogicalNode final : public BackwardLogicalNode {
DECLARE_NAIVE_LOGICAL_NODE(RecordLoadLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(DecodeLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(LossLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(PrintLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(LossLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(LossAccLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(LossPrintLogicalNode);
// Logical node types of the accuracy pipeline (compute, accumulate, print).
DECLARE_NAIVE_LOGICAL_NODE(AccuracyLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(AccuracyAccLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(AccuracyPrintLogicalNode);
class NormalMdUpdtLogicalNode final : public LogicalNode {
public:
......
......@@ -342,7 +342,6 @@ std::map<TaskType, std::string> task_type2color = {
{kMdSave, "1"}, {kMdDiffAcc, "7"}, {kCopyHd, "8"},
{kCopyCommNet, "9"}, {kBoxing, "10"}, {kPrint, "1"},
{kReduceScatter, "2"}, {kReduceLocalAdd, "2"}, {kReduceGlobalAdd, "2"},
{kReduceGather, "2"},
};
{kReduceGather, "2"}, {kAccuracy, "4"}, {kAccuracyPrint, "1"},
{kAccuracyAcc, "5"}};
} // namespace oneflow
......@@ -19,6 +19,7 @@ message TrainConf {
optional float l2 = 102 [default = 0];
optional int32 staleness = 103 [default = 0]; // -1 means ASP, 0 means BSP, > 0 means SSP
optional int64 piece_num_of_print_loss = 104 [default = -1];
optional int64 piece_num_of_print_accuracy = 105 [default = -1];
}
message PredictConf {
......
......@@ -67,6 +67,10 @@ int32_t JobDesc::PieceNumOfPrintLoss() const {
CHECK(IsTrain());
return job_conf_.other().train_conf().piece_num_of_print_loss();
}
// Returns train_conf.piece_num_of_print_accuracy; only valid for train jobs.
// NOTE(review): the proto field is int64 while the return type is int32_t, so
// the value is narrowed here.
int32_t JobDesc::PieceNumOfPrintAccuracy() const {
  CHECK(IsTrain());
  const auto& train_conf = job_conf_.other().train_conf();
  return train_conf.piece_num_of_print_accuracy();
}
int64_t JobDesc::BatchSize() const {
CHECK(IsTrain());
return job_conf_.other().train_conf().batch_size();
......@@ -108,6 +112,9 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) {
if (train_conf->piece_num_of_print_loss() == -1) {
train_conf->set_piece_num_of_print_loss(NumOfPiecesInBatch());
}
if (train_conf->piece_num_of_print_accuracy() == -1) {
train_conf->set_piece_num_of_print_accuracy(NumOfPiecesInBatch());
}
if (piece_exp == -1) { piece_exp = 19 * NumOfPiecesInBatch(); }
piece_exp = std::max(piece_exp, NumOfPiecesInBatch());
piece_exp = std::max(piece_exp, train_conf->piece_num_of_print_loss());
......
......@@ -58,6 +58,7 @@ class JobDesc final {
int64_t TotalBatchNum() const;
const InitializerConf* DefaultInitializerConf() const;
int32_t PieceNumOfPrintLoss() const;
int32_t PieceNumOfPrintAccuracy() const;
int64_t BatchSize() const;
int64_t NumOfPiecesInBatch() const;
float L1() const;
......
......@@ -6,5 +6,7 @@ namespace oneflow {
const char* kPackedBlobName = ONEFLOW_INTERNAL_PREFIX "PackedBlobName";
const char* kNullDataId = ONEFLOW_INTERNAL_PREFIX "NullDataId";
// Prefixes prepended to the source op's name to form the loss-print and
// accuracy-print op names; the print kernels skip this prefix when logging.
const std::string LossPrintPrefix = "loss_print_";
const std::string AccuracyPrintPrefix = "accuracy_print_";
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_KEYWORD_H_
#define ONEFLOW_CORE_JOB_KEYWORD_H_
#include <string>
namespace oneflow {
// Internal keyword strings; the definitions live in the matching source file.
extern const char* kPackedBlobName;
extern const char* kNullDataId;
// Name prefixes of the generated loss-print / accuracy-print ops.
extern const std::string LossPrintPrefix;
extern const std::string AccuracyPrintPrefix;
} // namespace oneflow
......
......@@ -27,6 +27,9 @@ enum TaskType {
kReduceGlobalAdd = 18;
kReduceGather = 19;
kReduceScatter = 20;
kAccuracy = 21;
kAccuracyAcc = 22;
kAccuracyPrint = 23;
};
enum AreaType {
......
#include "oneflow/core/kernel/accuracy_kernel.h"
namespace oneflow {
// Computes the number of correctly predicted samples of this piece.
//
// Reads the "prediction" blob (N rows of D scores) and the rank-1 "label" blob
// (N labels), and writes the count of rows whose label ranks within top_k into
// the "accuracy" blob via AccuracyKernelUtil::Forward.
template<DeviceType device_type, typename PredType, typename LabelType>
void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent(
    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  const Blob* X = BnInOp2Blob("prediction");
  const Blob* label = BnInOp2Blob("label");
  Blob* accuracy = BnInOp2Blob("accuracy");
  // Read top_k straight from the kernel conf instead of copying the whole
  // KernelConf message on every forward call.
  const int32_t top_k = this->kernel_conf().op_attribute().op_conf().accuracy_conf().top_k();
  const int32_t N = X->shape().At(0);
  const int32_t D = X->shape().Count(1);
  CHECK_EQ(label->shape().NumAxes(), 1);
  CHECK_EQ(X->blob_desc().shape().At(0), N);
  // The label blob must hold exactly one label per prediction row; otherwise
  // the util below would read past the end of the label buffer.
  CHECK_EQ(label->shape().At(0), N);
  AccuracyKernelUtil<device_type, PredType, LabelType>::Forward(
      ctx.device_ctx, N, D, top_k, X->dptr<PredType>(), label->dptr<LabelType>(),
      accuracy->mut_dptr<PredType>());
}
template<typename PredType, typename LabelType>
struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> {
static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
const PredType* XData, const LabelType* labelData, PredType* accuracyData) {
int correct = 0;
for (int i = 0; i < N; ++i) {
auto label_i = labelData[i];
auto label_pred = XData[i * D + label_i];
int cnt = 1;
for (int j = 0; j < D; ++j) {
auto pred = XData[i * D + j];
if (pred > label_pred) {
if (++cnt > top_k) { break; }
}
}
if (cnt <= top_k) { ++correct; }
}
CHECK_LE(correct, N);
*accuracyData = static_cast<PredType>(correct);
}
};
namespace {
// Creates the AccuracyKernel instantiation matching kernel_conf.
//
// A function-local static table maps GetHashKey(device type, prediction data
// type, label data type) to a factory lambda, with one entry generated for
// each combination of DEVICE_TYPE_SEQ x FLOATING_DATA_TYPE_SEQ x
// INT_DATA_TYPE_SEQ. The returned kernel is allocated with new.
Kernel* CreateAccuracyKernel(const KernelConf& kernel_conf) {
static const HashMap<std::string, std::function<Kernel*()>> creators = {
#define ACCURACY_KERNEL_ENTRY(device_type, pred_type_pair, label_type_pair) \
{GetHashKey(device_type, OF_PP_PAIR_SECOND(pred_type_pair), OF_PP_PAIR_SECOND(label_type_pair)), \
[]() { \
return new AccuracyKernel<device_type, OF_PP_PAIR_FIRST(pred_type_pair), \
OF_PP_PAIR_FIRST(label_type_pair)>(); \
}},
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(ACCURACY_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)};
// Looked up with at(), so the (device, prediction type, label type) triple
// from kernel_conf must have an entry in the table.
return creators.at(GetHashKey(kernel_conf.op_attribute().op_conf().device_type(),
kernel_conf.accuracy_conf().prediction_type(),
kernel_conf.accuracy_conf().label_type()))();
}
} // namespace
// Use CreateAccuracyKernel to build kernels for ops carrying an accuracy conf.
REGISTER_KERNEL_CREATOR(OperatorConf::kAccuracyConf, CreateAccuracyKernel);
// Explicitly instantiate the CPU AccuracyKernelUtil for every
// (floating prediction type, integer label type) pair.
#define MAKE_ENTRY(data_type_pair, label_type_pair) \
template struct AccuracyKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(label_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)
} // namespace oneflow
\ No newline at end of file
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/kernel/kernel_util.cuh"
#include "oneflow/core/kernel/accuracy_kernel.h"
#include <cub/cub.cuh>
namespace oneflow {
namespace {
// Resets the value pointed to by accuracy to zero.
template<typename PredType>
__global__ void AccuracySetZeroKernel(PredType* accuracy) {
  accuracy[0] = static_cast<PredType>(0);
}
// Adds to *accuracy the number of rows (out of N, each with D scores) whose
// label ranks within the top_k scores of that row.
//
// Each block walks rows blockIdx.x, blockIdx.x + gridDim.x, ...; within a row
// the block's threads split the D columns, count the scores ranked at or ahead
// of the label, and combine their counts with a cub::BlockReduce sum.
// *accuracy is only accumulated into, so it must be zeroed before launch.
template<typename PredType, typename LabelType>
__global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const int32_t top_k,
const PredType* Xdata, const LabelType* labelData,
PredType* accuracy) {
typedef cub::BlockReduce<int32_t, kCudaThreadsNumPerBlock> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// Per-thread tally of correct rows; only thread 0's value is published below.
int32_t correct = 0;
for (int32_t row = blockIdx.x; row < N; row += gridDim.x) {
const LabelType label = labelData[row];
const PredType label_pred = Xdata[row * D + label];
int32_t ngt = 0;
for (int32_t col = threadIdx.x; col < D; col += blockDim.x) {
const PredType pred = Xdata[row * D + col];
// Count strictly larger scores, plus equal scores at column index <= label.
// The label's own column satisfies the second clause, so the row total is
// at least 1.
if (pred > label_pred || (pred == label_pred && col <= label)) { ++ngt; }
}
// NOTE(review): per the CUB documentation the reduced sum is only defined in
// thread 0; other threads may update `correct` from an unspecified value,
// which is harmless because only thread 0 contributes it — confirm.
ngt = BlockReduce(temp_storage).Sum(ngt);
if (ngt <= top_k) { ++correct; }
// Presumably guards temp_storage before the next iteration reuses it.
__syncthreads();
}
if (threadIdx.x == 0) { gpu_atomic_add(accuracy, static_cast<PredType>(correct)); }
}
} // namespace
// GPU implementation of the top-k accuracy count.
template<typename PredType, typename LabelType>
struct AccuracyKernelUtil<DeviceType::kGPU, PredType, LabelType> {
  // Zeroes *accuracyData, then launches AccuracyComputeKernel to accumulate
  // the count of correct rows into it; both launches use ctx's CUDA stream.
  static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
                      const PredType* XData, const LabelType* labelData, PredType* accuracyData) {
    AccuracySetZeroKernel<<<1, 1, 0, ctx->cuda_stream()>>>(accuracyData);

    const auto grid_size = BlocksNum4ThreadsNum(N);
    AccuracyComputeKernel<<<grid_size, kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
        N, D, top_k, XData, labelData, accuracyData);
  }
};
// Explicitly instantiate the GPU AccuracyKernelUtil for every
// (floating prediction type, integer label type) pair.
#define MAKE_ENTRY(data_type_pair, label_type_pair) \
template struct AccuracyKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(label_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)
} // namespace oneflow
\ No newline at end of file
#ifndef ONEFLOW_CORE_KERNEL_ACCURACY_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_ACCURACY_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
namespace oneflow {
// Kernel of the accuracy op, parameterized on device, prediction element type
// and label element type.
template<DeviceType device_type, typename PredType, typename LabelType>
class AccuracyKernel final : public KernelIf<device_type> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyKernel);

  AccuracyKernel() = default;
  ~AccuracyKernel() = default;

 private:
  // Forward pass over the blobs named "prediction", "label" and "accuracy".
  void ForwardDataContent(const KernelCtx& ctx,
                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
// Device-specific implementation hook of the accuracy kernel; specialized per
// device type.
template<DeviceType device_type, typename PredType, typename LabelType>
struct AccuracyKernelUtil {
  // Computes, from N rows of D prediction scores and N labels, the number of
  // rows whose label is within the top_k scores, and stores it in
  // *accuracyData.
  static void Forward(DeviceCtx* ctx,
                      const int32_t N,
                      const int32_t D,
                      int32_t top_k,
                      const PredType* XData,
                      const LabelType* labelData,
                      PredType* accuracyData);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_ACCURACY_KERNEL_H_
\ No newline at end of file
#include "oneflow/core/kernel/opkernel_test_case.h"
#include "oneflow/core/common/switch_func.h"
namespace oneflow {
namespace test {
// Test fixture for the accuracy kernel, parameterized on device and prediction
// type; the label type is selected at runtime through SwitchTest.
template<DeviceType device_type, typename PredType>
struct AccuracyTestUtil final {
#define ACCURACY_TEST_UTIL_ENTRY(func_name, T) \
  AccuracyTestUtil<device_type, PredType>::template func_name<T>
  DEFINE_STATIC_SWITCH_FUNC(void, Test, ACCURACY_TEST_UTIL_ENTRY,
                            MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));

  // Feeds a 10x5 prediction matrix and 10 labels to a top-3 accuracy op and
  // checks the "accuracy" output blob.
  template<typename LabelType>
  static void Test(OpKernelTestCase* test_case, const std::string& job_type,
                   const std::string& fw_or_bw) {
    test_case->set_is_train(job_type == "train");
    test_case->set_is_forward(fw_or_bw == "forward");
    AccuracyOpConf* accuracy_conf = test_case->mut_op_conf()->mutable_accuracy_conf();
    accuracy_conf->set_prediction("test/prediction");
    accuracy_conf->set_label("test/label");
    accuracy_conf->set_accuracy("test/accuracy");
    accuracy_conf->set_top_k(3);
    BlobDesc* prediction_blob_desc =
        new BlobDesc(Shape({10, 5}), GetDataType<PredType>::value, false, false, 1);
    BlobDesc* label_blob_desc =
        new BlobDesc(Shape({10}), GetDataType<LabelType>::value, false, false, 1);
    BlobDesc* accuracy_blob_desc =
        new BlobDesc(Shape({1}), GetDataType<PredType>::value, false, false, 1);
    test_case->template InitBlob<PredType>(
        "prediction", prediction_blob_desc,
        {4.26386421, 9.95010348, 9.91810292, 0.48375106, 6.64594865, 8.05952355, 6.10698666,
         3.00538932, 0.85184578, 2.07455643, 3.83561549, 3.09892793, 8.03172383, 6.31505591,
         8.27174327, 6.85749046, 9.17082087, 2.75073689, 2.75332767, 9.59847227, 1.73445035,
         0.08238581, 3.83698503, 7.04001947, 6.93058367, 2.21650175, 4.43790294, 1.03987194,
         2.50459141, 4.63530169, 3.91737537, 9.57451706, 6.42044601, 6.69970151, 8.11969361,
         8.47881892, 8.08534761, 0.22607914, 3.28111424, 8.59098739, 1.83841795, 3.76625112,
         9.40150949, 8.43572707, 0.56068475, 8.30401856, 9.45218381, 8.35593787, 0.17762226,
         0.17188453});
    // The first and last labels are outside the top 3 of their rows; the other
    // eight are inside.
    test_case->template InitBlob<LabelType>("label", label_blob_desc,
                                            {3, 1, 3, 4, 3, 4, 1, 4, 1, 4});
    // The kernel writes the COUNT of correct samples (8 of 10), not the ratio;
    // the division by the sample count happens later in the print kernel.
    test_case->template ForwardCheckBlob<PredType>("accuracy", accuracy_blob_desc, {8.0});
  }
};
// Dispatches to AccuracyTestUtil::Test<LabelType> for the label type named by
// the label_type string.
template<DeviceType device_type, typename PredType>
void AccuracyKernelTestCase(OpKernelTestCase* test_case, const std::string& label_type,
                            const std::string& job_type, const std::string& fw_or_bw) {
  using TestUtil = AccuracyTestUtil<device_type, PredType>;
  TestUtil::SwitchTest(SwitchCase(label_type), test_case, job_type, fw_or_bw);
}
// Instantiates AccuracyKernelTestCase over the floating prediction types and
// integer label types, with job type "predict" and direction "forward".
TEST_CPU_AND_GPU_OPKERNEL(AccuracyKernelTestCase, FLOATING_DATA_TYPE_SEQ,
OF_PP_SEQ_MAP(OF_PP_PAIR_FIRST, INT_DATA_TYPE_SEQ), (predict), (forward));
} // namespace test
} // namespace oneflow
\ No newline at end of file
#include "oneflow/core/kernel/accuracy_print_kernel.h"
#include "oneflow/core/job/keyword.h"
namespace oneflow {
// Logs the accumulated top-k accuracy as "top_<k>_<accuracy op name>:<ratio>".
//
// The "accuracy_acc" blob holds the number of correct samples summed over
// PieceNumOfPrintAccuracy() pieces, so the ratio is taken against
// PieceSize() * PieceNumOfPrintAccuracy() samples. This equals BatchSize()
// only when piece_num_of_print_accuracy is left at its default.
template<typename T>
void AccuracyPrintKernel<T>::Forward(const KernelCtx& kernel_ctx,
                                     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  const Blob* accuracy_acc_blob = BnInOp2Blob("accuracy_acc");
  const T accuracy_num = accuracy_acc_blob->dptr<T>()[0];
  auto* job_desc = Global<JobDesc>::Get();
  const int64_t total_num =
      static_cast<int64_t>(job_desc->PieceSize()) * job_desc->PieceNumOfPrintAccuracy();
  CHECK_GT(total_num, 0);
  const float accuracy = accuracy_num / total_num;
  // The print op is named AccuracyPrintPrefix + <accuracy op name>; skip the
  // prefix to recover the accuracy op's own name.
  const char* accuracy_op_name = op_conf().name().c_str() + AccuracyPrintPrefix.length();
  const int32_t top_k_print =
      this->kernel_conf().op_attribute().op_conf().accuracy_print_conf().top_k_print();
  LOG(INFO) << "top_" << top_k_print << "_" << accuracy_op_name << ":" << accuracy;
}
// Registers AccuracyPrintKernel as the CPU kernel for ops carrying an
// accuracy print conf, over the floating data types.
ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kAccuracyPrintConf, AccuracyPrintKernel,
FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_ACCURACY_PRINT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_ACCURACY_PRINT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
// CPU-only kernel that logs the accumulated accuracy; T is the element type of
// the accumulated accuracy blob.
template<typename T>
class AccuracyPrintKernel final : public KernelIf<DeviceType::kCPU> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyPrintKernel);

  AccuracyPrintKernel() = default;
  ~AccuracyPrintKernel() = default;

  // Reads the "accuracy_acc" input blob and writes one log line.
  void Forward(const KernelCtx& kernel_ctx,
               std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_ACCURACY_PRINT_KERNEL_H_
\ No newline at end of file
......@@ -101,6 +101,11 @@ message DecodeOFRecordKernelConf {
required uint32 random_seed = 1;
}
// Kernel-side configuration of the accuracy op: the element types of its two
// input blobs, used to select the kernel instantiation.
message AccuracyKernelConf {
// Data type of the "prediction" input blob.
required DataType prediction_type = 1;
// Data type of the "label" input blob.
required DataType label_type = 2;
}
message OpAttribute {
required OperatorConf op_conf = 1;
map<string, LogicalBlobId> bn_in_op2lbi = 2;
......@@ -141,5 +146,6 @@ message KernelConf {
MaxPoolingKernelConf max_pooling_conf = 205;
NormalizationKernelConf normalization_conf = 250;
LocalResponseNormalizationKernelConf local_response_normalization_conf = 300;
AccuracyKernelConf accuracy_conf = 401;
}
}
#include "oneflow/core/kernel/loss_print_kernel.h"
#include "oneflow/core/job/keyword.h"
namespace oneflow {
......@@ -18,7 +19,7 @@ void LossPrintKernel<T>::Forward(const KernelCtx& kernel_ctx,
Global<JobDesc>::Get()->PieceSize() * Global<JobDesc>::Get()->PieceNumOfPrintLoss());
}
loss_reduced /= reduction_coefficient;
const char* loss_op_name = op_conf().name().c_str() + 11;
const char* loss_op_name = op_conf().name().c_str() + LossPrintPrefix.length();
LOG(INFO) << loss_op_name << ":" << loss_reduced;
}
......
#include "oneflow/core/operator/accuracy_op.h"
namespace oneflow {
void AccuracyOp::InitFromOpConf() {
EnrollInputBn("prediction", false);
EnrollInputBn("label", false);
EnrollOutputBn("accuracy", false);
}
// Returns the op-specific part of the operator conf (the accuracy conf).
const PbMessage& AccuracyOp::GetCustomizedConf() const {
  const auto& customized_conf = op_conf().accuracy_conf();
  return customized_conf;
}
// Records the data types of the "prediction" and "label" blobs in the kernel
// conf so that the matching kernel instantiation can be chosen.
void AccuracyOp::VirtualGenKernelConf(
    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
  const auto pred_type = GetBlobDesc4BnInOp("prediction")->data_type();
  const auto label_type = GetBlobDesc4BnInOp("label")->data_type();

  AccuracyKernelConf* accuracy_kernel_conf = kernel_conf->mutable_accuracy_conf();
  accuracy_kernel_conf->set_prediction_type(pred_type);
  accuracy_kernel_conf->set_label_type(label_type);
}
// Validates the input blob descs and fills in the output blob desc.
//
// Requirements: prediction and label agree on having a data-id field, the
// label type is integral, prediction has at least 2 axes, and label has shape
// {prediction.shape[0]}. The "accuracy" output gets shape {1} and inherits
// the prediction's data type and data-id setting.
void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                                const ParallelContext* parallel_ctx, size_t* buf_size,
                                std::function<void(OpContext*)>) const {
  BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
  BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");

  // Input validation.
  CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field());
  CHECK(IsIntegralDataType(label_blob_desc->data_type()));
  CHECK_GE(pred_blob_desc->shape().NumAxes(), 2);
  CHECK_EQ(label_blob_desc->shape(), Shape({pred_blob_desc->shape().At(0)}));

  // Output description.
  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("accuracy");
  out_blob_desc->mut_shape() = Shape({1});
  out_blob_desc->set_data_type(pred_blob_desc->data_type());
  out_blob_desc->set_has_data_id_field(pred_blob_desc->has_data_id_field());
}
// Associates operator confs of kind kAccuracyConf with AccuracyOp.
REGISTER_OP(OperatorConf::kAccuracyConf, AccuracyOp);
} // namespace oneflow
\ No newline at end of file
#ifndef ONEFLOW_CORE_OPERATOR_ACCURACY_OP_H_
#define ONEFLOW_CORE_OPERATOR_ACCURACY_OP_H_
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/logical_node.h"
namespace oneflow {
class AccuracyOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(AccuracyOp);
AccuracyOp() = default;
virtual ~AccuracyOp() = default;
void InitFromOpConf() override;
LogicalNode* NewProperLogicalNode() override { return new AccuracyLogicalNode; }
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, size_t* buf_size,
std::function<void(OpContext*)> EnrollOpCtx) const override;
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_ACCURACY_OP_H_
\ No newline at end of file
#include "oneflow/core/operator/accuracy_print_op.h"
namespace oneflow {
// Verifies the op was built from an accuracy print conf and enrolls its only
// input blob name, "accuracy_acc".
void AccuracyPrintOp::InitFromOpConf() {
CHECK(op_conf().has_accuracy_print_conf());
EnrollInputBn("accuracy_acc", false);
}
// Returns the op-specific part of the operator conf (the accuracy print conf).
const PbMessage& AccuracyPrintOp::GetCustomizedConf() const {
  const auto& customized_conf = op_conf().accuracy_print_conf();
  return customized_conf;
}
// Associates operator confs of kind kAccuracyPrintConf with AccuracyPrintOp.
REGISTER_OP(OperatorConf::kAccuracyPrintConf, AccuracyPrintOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_ACCURACY_PRINT_OP_H_
#define ONEFLOW_CORE_OPERATOR_ACCURACY_PRINT_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
// Operator with a single input, "accuracy_acc", and no enrolled outputs.
class AccuracyPrintOp final : public Operator {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AccuracyPrintOp);

  AccuracyPrintOp() = default;
  ~AccuracyPrintOp() = default;

  void InitFromOpConf() override;
  const PbMessage& GetCustomizedConf() const override;

 private:
  // Maps the input blob name "accuracy_acc" to the lbi stored in the conf;
  // any other name is reported through UNIMPLEMENTED().
  LogicalBlobId ibn2lbi(const std::string& input_bn) const override {
    if (input_bn != "accuracy_acc") { UNIMPLEMENTED(); }
    return op_conf().accuracy_print_conf().accuracy_lbi();
  }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_ACCURACY_PRINT_OP_H_
......@@ -393,6 +393,11 @@ message LossPrintOpConf {
optional LossReductionType reduction_type = 4 [default = kSumOverN];
}
// Configuration of the accuracy print op.
message AccuracyPrintOpConf {
// Logical blob id bound to the op's "accuracy_acc" input.
required LogicalBlobId accuracy_lbi = 1;
// The k shown in the log label ("top_<k>_<op name>").
optional int32 top_k_print = 3 [default = 1];
}
message ReduceSumOpConf {
oneof in_conf {
string in = 1; // For User
......@@ -575,6 +580,13 @@ message ReduceGatherOpConf {
required int32 in_num = 1;
};
// Configuration of the accuracy op.
message AccuracyOpConf {
// Name of the prediction (score) input blob.
required string prediction = 1;
// Name of the label input blob.
required string label = 2;
// A sample counts as correct when its label's score is among the top_k
// scores of its row.
optional int32 top_k = 3 [default = 1];
// Name of the accuracy output blob.
required string accuracy = 4;
}
message OperatorConf {
required string name = 1;
optional string model_load_dir = 2;
......@@ -627,6 +639,8 @@ message OperatorConf {
ReduceGlobalAddOpConf reduce_global_add_conf = 402;
ReduceGatherOpConf reduce_gather_conf = 403;
RecordLoadOpConf record_load_conf = 404;
AccuracyOpConf accuracy_conf=405;
AccuracyPrintOpConf accuracy_print_conf = 406;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册