未验证 提交 5b103c24 编写于 作者: Q qingqing01 提交者: GitHub

Simplify multi_box_head API in detection.py and remove assign op. (#18310) (#18388)

* Simplify multi_box_head API in detection.py and remove assign op.
上级 08b0a7b7
...@@ -195,25 +195,31 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -195,25 +195,31 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
auto loc_loss_dims = ctx->GetInputDim("LocLoss"); auto loc_loss_dims = ctx->GetInputDim("LocLoss");
PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL, PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL,
"The shape of LocLoss is [N, Np]."); "The shape of LocLoss is [N, Np].");
PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0], if (ctx->IsRuntime()) {
"Batch size of ClsLoss and LocLoss must be the same."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ( cls_loss_dims[0], loc_loss_dims[0],
cls_loss_dims[1], loc_loss_dims[1], "Batch size of ClsLoss and LocLoss must be the same.");
"Prior box number of ClsLoss and LocLoss must be the same."); PADDLE_ENFORCE_EQ(
cls_loss_dims[1], loc_loss_dims[1],
"Prior box number of ClsLoss and LocLoss must be the same.");
}
} }
PADDLE_ENFORCE_EQ( if (ctx->IsRuntime()) {
cls_loss_dims[0], idx_dims[0], PADDLE_ENFORCE_EQ(
"Batch size of ClsLoss and MatchIndices must be the same."); cls_loss_dims[0], idx_dims[0],
PADDLE_ENFORCE_EQ( "Batch size of ClsLoss and MatchIndices must be the same.");
cls_loss_dims[1], idx_dims[1], PADDLE_ENFORCE_EQ(
"Prior box number of ClsLoss and MatchIndices must be the same."); cls_loss_dims[1], idx_dims[1],
"Prior box number of ClsLoss and MatchIndices must be the same.");
PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0],
"Batch size of ClsLoss and MatchDist must be the same."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ( cls_loss_dims[0], dis_dims[0],
cls_loss_dims[1], idx_dims[1], "Batch size of ClsLoss and MatchDist must be the same.");
"Prior box number of ClsLoss and MatchDist must be the same."); PADDLE_ENFORCE_EQ(
cls_loss_dims[1], idx_dims[1],
"Prior box number of ClsLoss and MatchDist must be the same.");
}
auto mining_type = auto mining_type =
GetMiningType(ctx->Attrs().Get<std::string>("mining_type")); GetMiningType(ctx->Attrs().Get<std::string>("mining_type"));
......
...@@ -1393,8 +1393,10 @@ def ssd_loss(location, ...@@ -1393,8 +1393,10 @@ def ssd_loss(location,
# 3. Mining hard examples # 3. Mining hard examples
actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2]) actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2])
actual_shape.stop_gradient = True actual_shape.stop_gradient = True
# shape=(-1, 0) is set for compile-time, the correct shape is set by
# actual_shape in runtime.
conf_loss = nn.reshape( conf_loss = nn.reshape(
x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape) x=conf_loss, shape=(-1, 0), actual_shape=actual_shape)
conf_loss.stop_gradient = True conf_loss.stop_gradient = True
neg_indices = helper.create_variable_for_type_inference(dtype='int32') neg_indices = helper.create_variable_for_type_inference(dtype='int32')
dtype = matched_indices.dtype dtype = matched_indices.dtype
...@@ -1464,7 +1466,9 @@ def ssd_loss(location, ...@@ -1464,7 +1466,9 @@ def ssd_loss(location,
# 5.3 Compute overall weighted loss. # 5.3 Compute overall weighted loss.
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
# reshape to [N, Np], N is the batch size and Np is the prior box number. # reshape to [N, Np], N is the batch size and Np is the prior box number.
loss = nn.reshape(x=loss, shape=(num, num_prior), actual_shape=actual_shape) # shape=(-1, 0) is set for compile-time, the correct shape is set by
# actual_shape in runtime.
loss = nn.reshape(x=loss, shape=(-1, 0), actual_shape=actual_shape)
loss = nn.reduce_sum(loss, dim=1, keep_dim=True) loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
if normalize: if normalize:
normalizer = nn.reduce_sum(target_loc_weight) normalizer = nn.reduce_sum(target_loc_weight)
...@@ -1927,13 +1931,7 @@ def multi_box_head(inputs, ...@@ -1927,13 +1931,7 @@ def multi_box_head(inputs,
stride=stride) stride=stride)
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
compile_shape = [ mbox_loc_flatten = nn.flatten(mbox_loc, axis=1)
mbox_loc.shape[0], cpt.floor_division(
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4
]
run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32"))
mbox_loc_flatten = nn.reshape(
mbox_loc, shape=compile_shape, actual_shape=run_shape)
mbox_locs.append(mbox_loc_flatten) mbox_locs.append(mbox_loc_flatten)
# get conf # get conf
...@@ -1945,16 +1943,7 @@ def multi_box_head(inputs, ...@@ -1945,16 +1943,7 @@ def multi_box_head(inputs,
padding=pad, padding=pad,
stride=stride) stride=stride)
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
new_shape = [0, -1, num_classes] conf_loc_flatten = nn.flatten(conf_loc, axis=1)
compile_shape = [
conf_loc.shape[0],
cpt.floor_division(conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[3], num_classes), num_classes
]
run_shape = tensor.assign(
numpy.array([0, -1, num_classes]).astype("int32"))
conf_loc_flatten = nn.reshape(
conf_loc, shape=compile_shape, actual_shape=run_shape)
mbox_confs.append(conf_loc_flatten) mbox_confs.append(conf_loc_flatten)
if len(box_results) == 1: if len(box_results) == 1:
...@@ -1972,7 +1961,10 @@ def multi_box_head(inputs, ...@@ -1972,7 +1961,10 @@ def multi_box_head(inputs,
box = tensor.concat(reshaped_boxes) box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars) var = tensor.concat(reshaped_vars)
mbox_locs_concat = tensor.concat(mbox_locs, axis=1) mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
mbox_locs_concat = nn.reshape(mbox_locs_concat, shape=[0, -1, 4])
mbox_confs_concat = tensor.concat(mbox_confs, axis=1) mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
mbox_confs_concat = nn.reshape(
mbox_confs_concat, shape=[0, -1, num_classes])
box.stop_gradient = True box.stop_gradient = True
var.stop_gradient = True var.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册