Unverified. Commit 9047ac68, authored by qingqing01, committed by GitHub.

Simplify multi_box_head API in detection.py and remove assign op. (#18310)

* Simplify multi_box_head API in detection.py and remove assign op.
Parent e42057cd
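In short: for every prediction head, `multi_box_head` used to build a compile-time shape list plus a runtime shape tensor produced by an `assign` op, only to flatten the per-head output; the new code gets the same result from a single `flatten` call. A minimal sketch under Paddle 1.x `fluid.layers` semantics (tensor names and sizes below are illustrative, not taken from the diff):

```python
# Sketch of the before/after per-head flattening; shapes are made up.
import numpy
import paddle.fluid as fluid

# One head's location output after transpose to NHWC; batch size is -1 here.
mbox_loc = fluid.layers.data(name='mbox_loc', shape=[3, 3, 16], dtype='float32')

# Before: an extra assign op materializes the runtime shape [0, -1, 4]
# (0 = copy the batch dim, -1 = infer the prior-box count).
run_shape = fluid.layers.assign(numpy.array([0, -1, 4]).astype("int32"))
before = fluid.layers.reshape(mbox_loc, shape=[0, -1, 4], actual_shape=run_shape)

# After: flatten everything behind the batch dimension; no assign op needed.
after = fluid.layers.flatten(mbox_loc, axis=1)  # [N, 3*3*16]
```

One `assign` per head and per branch adds up across the six-odd feature maps of a typical SSD, so removing it shrinks the program and the per-iteration op launches.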
@@ -195,13 +195,17 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
       auto loc_loss_dims = ctx->GetInputDim("LocLoss");
       PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL,
                         "The shape of LocLoss is [N, Np].");
-      PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0],
+      if (ctx->IsRuntime()) {
+        PADDLE_ENFORCE_EQ(
+            cls_loss_dims[0], loc_loss_dims[0],
             "Batch size of ClsLoss and LocLoss must be the same.");
         PADDLE_ENFORCE_EQ(
             cls_loss_dims[1], loc_loss_dims[1],
             "Prior box number of ClsLoss and LocLoss must be the same.");
+      }
     }
 
+    if (ctx->IsRuntime()) {
       PADDLE_ENFORCE_EQ(
           cls_loss_dims[0], idx_dims[0],
           "Batch size of ClsLoss and MatchIndices must be the same.");
@@ -209,11 +213,13 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
           cls_loss_dims[1], idx_dims[1],
           "Prior box number of ClsLoss and MatchIndices must be the same.");
-      PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0],
+      PADDLE_ENFORCE_EQ(
+          cls_loss_dims[0], dis_dims[0],
           "Batch size of ClsLoss and MatchDist must be the same.");
       PADDLE_ENFORCE_EQ(
           cls_loss_dims[1], idx_dims[1],
           "Prior box number of ClsLoss and MatchDist must be the same.");
+    }
 
     auto mining_type =
         GetMiningType(ctx->Attrs().Get<std::string>("mining_type"));
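The C++ hunks defer the cross-input dimension checks to `ctx->IsRuntime()`. The reason: while the program is being built, the batch dimension of each input is only the placeholder -1, so an equality test such as `cls_loss_dims[0] == loc_loss_dims[0]` cannot be decided yet; it only becomes meaningful once the executor feeds real data. A hedged Python illustration (the variable names are examples, not the operator's actual inputs):

```python
# At graph-construction time dim 0 is -1 for both tensors, so comparing
# batch sizes here would be meaningless; the real sizes exist only at run time.
import paddle.fluid as fluid

cls_loss = fluid.layers.data(name='cls_loss', shape=[100], dtype='float32')
loc_loss = fluid.layers.data(name='loc_loss', shape=[100], dtype='float32')
print(cls_loss.shape, loc_loss.shape)  # (-1, 100) (-1, 100)
```

The remaining hunks touch `ssd_loss` and `multi_box_head` in detection.py, the file named in the commit title.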
@@ -1393,8 +1393,10 @@ def ssd_loss(location,
     # 3. Mining hard examples
     actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2])
     actual_shape.stop_gradient = True
+    # shape=(-1, 0) is set for compile-time, the correct shape is set by
+    # actual_shape in runtime.
     conf_loss = nn.reshape(
-        x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
+        x=conf_loss, shape=(-1, 0), actual_shape=actual_shape)
     conf_loss.stop_gradient = True
     neg_indices = helper.create_variable_for_type_inference(dtype='int32')
     dtype = matched_indices.dtype
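The `(-1, 0)` target shape follows the documented `fluid.layers.reshape` convention in 1.x: `0` copies the corresponding dimension from the input, at most one `-1` is inferred from the remaining elements, and when `actual_shape` is also passed, that runtime tensor takes precedence during execution. A small sketch:

```python
# shape=(-1, 0): dim 1 is copied from the input, dim 0 is inferred, so the
# compile-time result keeps the -1 batch placeholder without hard-coding it.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[21], dtype='float32')  # compile shape (-1, 21)
y = fluid.layers.reshape(x, shape=(-1, 0))
print(y.shape)  # (-1, 21)
```

This is why the hard-coded `(num, num_prior)` could go: those Python ints were only compile-time values anyway, and `actual_shape` already carried the true shape at run time.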
@@ -1464,7 +1466,9 @@ def ssd_loss(location,
     # 5.3 Compute overall weighted loss.
     loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
     # reshape to [N, Np], N is the batch size and Np is the prior box number.
-    loss = nn.reshape(x=loss, shape=(num, num_prior), actual_shape=actual_shape)
+    # shape=(-1, 0) is set for compile-time, the correct shape is set by
+    # actual_shape in runtime.
+    loss = nn.reshape(x=loss, shape=(-1, 0), actual_shape=actual_shape)
     loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
     if normalize:
         normalizer = nn.reduce_sum(target_loc_weight)
@@ -1927,13 +1931,7 @@ def multi_box_head(inputs,
             stride=stride)
         mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
-        compile_shape = [
-            mbox_loc.shape[0], cpt.floor_division(
-                mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4
-        ]
-        run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32"))
-        mbox_loc_flatten = nn.reshape(
-            mbox_loc, shape=compile_shape, actual_shape=run_shape)
+        mbox_loc_flatten = nn.flatten(mbox_loc, axis=1)
         mbox_locs.append(mbox_loc_flatten)
 
         # get conf
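The design choice behind the `flatten` swap, with illustrative sizes (a 3x3 head with 16 channels, i.e. 4 boxes times 4 coordinates): the deleted code made each head three-dimensional, `[N, H*W*C // 4, 4]`, before concatenation, whereas the new code keeps each head flat as `[N, H*W*C]` and restores the three-dimensional layout once, after the concat (see the last hunk). The element counts agree, so the final tensor is unchanged:

```python
# Per-head shapes, old vs. new; plain arithmetic with made-up sizes.
h, w, c = 3, 3, 16                  # one head: 3x3 locations, 4 boxes * 4 coords
old_per_head = (h * w * c // 4, 4)  # (36, 4), built via cpt.floor_division
new_per_head = (h * w * c,)         # (144,), from flatten(axis=1)
assert old_per_head[0] * old_per_head[1] == new_per_head[0]
```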
@@ -1945,16 +1943,7 @@ def multi_box_head(inputs,
             padding=pad,
             stride=stride)
         conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
-        new_shape = [0, -1, num_classes]
-        compile_shape = [
-            conf_loc.shape[0],
-            cpt.floor_division(conf_loc.shape[1] * conf_loc.shape[2] *
-                               conf_loc.shape[3], num_classes), num_classes
-        ]
-        run_shape = tensor.assign(
-            numpy.array([0, -1, num_classes]).astype("int32"))
-        conf_loc_flatten = nn.reshape(
-            conf_loc, shape=compile_shape, actual_shape=run_shape)
+        conf_loc_flatten = nn.flatten(conf_loc, axis=1)
         mbox_confs.append(conf_loc_flatten)
 
     if len(box_results) == 1:
@@ -1972,7 +1961,10 @@ def multi_box_head(inputs,
     box = tensor.concat(reshaped_boxes)
     var = tensor.concat(reshaped_vars)
     mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
+    mbox_locs_concat = nn.reshape(mbox_locs_concat, shape=[0, -1, 4])
     mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
+    mbox_confs_concat = nn.reshape(
+        mbox_confs_concat, shape=[0, -1, num_classes])
     box.stop_gradient = True
     var.stop_gradient = True
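The final hunk is where the per-head flattening is paid back: each head now contributes a flat `[N, Hi*Wi*Ci]` tensor, so a single `0`/`-1` reshape after the concat restores the `[N, num_priors, 4]` (and `[N, num_priors, num_classes]`) layout that callers of `multi_box_head` expect. A sketch with two hypothetical heads:

```python
# Two flat heads concatenated, then regrouped into 4-coordinate boxes.
import paddle.fluid as fluid

a = fluid.layers.data(name='a', shape=[144], dtype='float32')  # head 1: [N, 144]
b = fluid.layers.data(name='b', shape=[400], dtype='float32')  # head 2: [N, 400]
locs = fluid.layers.concat([a, b], axis=1)                     # [N, 544]
locs = fluid.layers.reshape(locs, shape=[0, -1, 4])            # [N, 136, 4]
print(locs.shape)  # (-1, 136, 4)
```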