Commit 3624ff44 authored by Juncheng, committed by Li Xinqi

Dev bert accuracy with weight (#1632)

* accuracy

* accuracy_task_node add fw_buf

* fw_buf=>data_tmp


Former-commit-id: ab900061ce8f1050ec6d2ee39057addcfcd44f57
Parent commit: ebbf15dc
......@@ -6,6 +6,7 @@ namespace oneflow {
void AccuracyCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("accuracy", false);
ProduceRegst("data_tmp", true);
for (TaskEdge* edge : out_edges()) { BindEdgeWithProducedRegst(edge, "accuracy"); }
}
......@@ -27,6 +28,7 @@ void AccuracyCompTaskNode::BuildExecGphAndRegst() {
accuracy_regst->AddLbi(accuracy_op->BnInOp2Lbi(obn));
accuracy_node->BindBnWithRegst(obn, accuracy_regst);
}
accuracy_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp"));
accuracy_node->InferBlobDescs(parallel_ctx());
}
......
#include "oneflow/core/kernel/accuracy_kernel.h"
#include "oneflow/core/ndarray/ndarray_util.h"
namespace oneflow {
template<DeviceType device_type, typename PredType, typename LabelType>
void AccuracyKernel<device_type, PredType, LabelType>::SetAccuracyInstanceNumBlob(
    const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  // Writes the (scalar) "accuracy_instance_num" blob: the effective number of
  // instances the accuracy was accumulated over.
  //
  // Fix: the old unconditional implementation (sum of dim0 valid nums + Set)
  // was left in place ahead of the weight-aware branch, so it executed twice
  // when "weight" was absent and clobbered nothing useful when it was present.
  // Only the conditional version below is kept.
  const Blob* weight = BnInOp2Blob("weight");
  Blob* accuracy_instance_num = BnInOp2Blob("accuracy_instance_num");
  if (weight == nullptr) {
    // Unweighted: instance num is the total count of valid dim0 entries of the
    // inputs (prediction/label), which must agree with each other.
    CHECK_GE(this->op_attribute().input_bns().size(), 2);
    this->CheckSameDim0ValidNum(this->op_attribute().input_bns(), BnInOp2Blob);
    int64_t dim0_valid_num_sum =
        BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum();
    KernelUtil<device_type, PredType>::Set(ctx.device_ctx,
                                           static_cast<PredType>(dim0_valid_num_sum),
                                           accuracy_instance_num->mut_dptr<PredType>());
  } else {
    // Weighted: instance num is the sum of all weights, reduced on-device into
    // the "weight_reduce_tmp" scratch blob.
    Blob* weight_reduce_tmp = BnInOp2Blob("weight_reduce_tmp");
    CHECK_LE(weight->shape().elem_cnt(), weight_reduce_tmp->shape().elem_cnt());
    const int64_t num_instance = weight->shape().elem_cnt();
    NdarrayUtil<device_type, PredType>::ReduceSum(
        ctx.device_ctx, XpuVarNdarray<PredType>({1}, accuracy_instance_num->mut_dptr<PredType>()),
        XpuVarNdarray<const PredType>({num_instance}, weight->dptr<PredType>()),
        XpuVarNdarray<PredType>({num_instance}, weight_reduce_tmp->mut_dptr<PredType>()));
  }
}
template<DeviceType device_type, typename PredType, typename LabelType>
......@@ -19,6 +32,8 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* X = BnInOp2Blob("prediction");
const Blob* label = BnInOp2Blob("label");
const Blob* weight = BnInOp2Blob("weight");
if (weight != nullptr) { CHECK_EQ(label->shape().elem_cnt(), weight->shape().elem_cnt()); }
Blob* accuracy = BnInOp2Blob("accuracy");
auto kernel_conf = this->kernel_conf();
const int32_t top_k = kernel_conf.op_attribute().op_conf().accuracy_conf().top_k();
......@@ -29,7 +44,7 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent(
AccuracyKernelUtil<device_type, PredType, LabelType>::Forward(
ctx.device_ctx, N, D, top_k, X->dptr<PredType>(), label->dptr<LabelType>(),
accuracy->mut_dptr<PredType>());
weight ? weight->dptr<PredType>() : nullptr, accuracy->mut_dptr<PredType>());
SetAccuracyInstanceNumBlob(ctx, BnInOp2Blob);
}
......@@ -48,8 +63,9 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardRecordIdInDevicePi
template<typename PredType, typename LabelType>
struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> {
static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
const PredType* XData, const LabelType* labelData, PredType* accuracyData) {
int correct = 0;
const PredType* XData, const LabelType* labelData, const PredType* weight,
PredType* accuracyData) {
PredType correct = 0;
for (int i = 0; i < N; ++i) {
auto label_i = labelData[i];
auto label_pred = XData[i * D + label_i];
......@@ -60,10 +76,9 @@ struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> {
if (++cnt > top_k) { break; }
}
}
if (cnt <= top_k) { ++correct; }
if (cnt <= top_k) { correct += weight ? weight[i] : OneVal<PredType>::value; }
}
CHECK_LE(correct, N);
*accuracyData = static_cast<PredType>(correct);
*accuracyData = correct;
}
};
......
......@@ -17,10 +17,10 @@ __global__ void AccuracySetZeroKernel(PredType* accuracy) {
template<typename PredType, typename LabelType>
__global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const int32_t top_k,
const PredType* Xdata, const LabelType* labelData,
PredType* accuracy) {
const PredType* weight, PredType* accuracy) {
typedef cub::BlockReduce<int32_t, kCudaThreadsNumPerBlock> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int32_t correct = 0;
PredType correct = 0;
for (int32_t row = blockIdx.x; row < N; row += gridDim.x) {
const LabelType label = labelData[row];
const PredType label_pred = Xdata[row * D + label];
......@@ -30,20 +30,22 @@ __global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const in
if (pred > label_pred || (pred == label_pred && col <= label)) { ++ngt; }
}
ngt = BlockReduce(temp_storage).Sum(ngt);
if (ngt <= top_k) { ++correct; }
if (ngt <= top_k) { correct += weight ? weight[row] : OneVal<PredType>::value; }
__syncthreads();
}
if (threadIdx.x == 0) { gpu_atomic_add(accuracy, static_cast<PredType>(correct)); }
if (threadIdx.x == 0) { gpu_atomic_add(accuracy, correct); }
}
} // namespace
template<typename PredType, typename LabelType>
struct AccuracyKernelUtil<DeviceType::kGPU, PredType, LabelType> {
  // GPU top-k accuracy: zero the scalar accumulator, then launch one
  // block-per-rows pass over the N instances. `weight` may be nullptr
  // (unweighted count), otherwise it is a length-N per-instance weight vector.
  //
  // Fix: the kernel-launch argument list had the pre-weight line merged in
  // with the weighted one (two conflicting argument lists), and the member
  // function body carried a stray trailing semicolon.
  static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
                      const PredType* XData, const LabelType* labelData, const PredType* weight,
                      PredType* accuracyData) {
    AccuracySetZeroKernel<<<1, 1, 0, ctx->cuda_stream()>>>(accuracyData);
    AccuracyComputeKernel<<<BlocksNum4ThreadsNum(N), kCudaThreadsNumPerBlock, 0,
                            ctx->cuda_stream()>>>(N, D, top_k, XData, labelData, weight,
                                                  accuracyData);
  }
};
#define MAKE_ENTRY(data_type_pair, label_type_pair) \
......
......@@ -27,7 +27,8 @@ class AccuracyKernel final : public KernelIf<device_type> {
template<DeviceType device_type, typename PredType, typename LabelType>
struct AccuracyKernelUtil {
  // Device-specialized top-k accuracy primitive.
  //   N, D    : number of instances and number of classes per instance.
  //   top_k   : a prediction counts as correct if the label ranks within top_k.
  //   weight  : optional length-N per-instance weights; nullptr means each
  //             correct instance contributes 1 (unweighted count).
  // The (weighted) correct count is written to accuracyData.
  //
  // Fix: the declaration had the old (no-weight) and new parameter lists
  // interleaved, yielding a duplicated, ill-formed signature.
  static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
                      const PredType* XData, const LabelType* labelData, const PredType* weight,
                      PredType* accuracyData);
};
} // namespace oneflow
......
......@@ -7,6 +7,10 @@ void AccuracyOp::InitFromOpConf() {
EnrollInputBn("label", false);
EnrollOutputBn("accuracy", false);
EnrollOutputBn("accuracy_instance_num", false);
if (op_conf().accuracy_conf().has_weight()) {
EnrollInputBn("weight", false);
EnrollDataTmpBn("weight_reduce_tmp");
}
}
const PbMessage& AccuracyOp::GetCustomizedConf() const {
  // The operator-specific configuration is the accuracy_conf submessage.
  return op_conf().accuracy_conf();
}
......@@ -34,6 +38,20 @@ void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> Get
CHECK_GE(pred_blob_desc->shape().NumAxes(), 2);
CHECK_EQ(label_blob_desc->shape(), Shape({pred_blob_desc->shape().At(0)}));
if (op_conf().accuracy_conf().has_weight()) {
const BlobDesc* weight = GetBlobDesc4BnInOp("weight");
CHECK_EQ(weight->shape(), label_blob_desc->shape());
CHECK_EQ(weight->data_type(), pred_blob_desc->data_type());
CHECK_EQ(weight->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field());
CHECK_EQ(weight->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape());
if (label_blob_desc->has_dim0_inner_shape()) {
CHECK_EQ(weight->dim0_inner_shape(), label_blob_desc->dim0_inner_shape());
}
BlobDesc* weight_reduce_tmp = GetBlobDesc4BnInOp("weight_reduce_tmp");
weight_reduce_tmp->mut_shape() = weight->shape();
weight_reduce_tmp->set_data_type(weight->data_type());
}
// accuracy
BlobDesc* accuracy_blob_desc = GetBlobDesc4BnInOp("accuracy");
*accuracy_blob_desc = *pred_blob_desc;
......
......@@ -744,6 +744,7 @@ message AccuracyOpConf {
required string label = 2;
optional int32 top_k = 3 [default = 1];
required string accuracy = 4;
optional string weight = 5;
}
message MatmulOpConf {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册