From 4793e86b9247d1c5a7d1b1534a7d7d971a73fd79 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 6 Feb 2018 19:38:22 +0800 Subject: [PATCH] Add target_assign_op for SSD detection. --- paddle/framework/mixed_vector.h | 8 + paddle/operators/target_assign_op.cc | 172 ++++++++++++++++++ paddle/operators/target_assign_op.cu | 61 +++++++ paddle/operators/target_assign_op.h | 155 ++++++++++++++++ paddle/platform/assert.h | 26 +-- .../v2/fluid/tests/test_target_assign_op.py | 126 +++++++++++++ 6 files changed, 535 insertions(+), 13 deletions(-) create mode 100644 paddle/operators/target_assign_op.cc create mode 100644 paddle/operators/target_assign_op.cu create mode 100644 paddle/operators/target_assign_op.h create mode 100755 python/paddle/v2/fluid/tests/test_target_assign_op.py diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 85caac8dcd..422fbbac48 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -60,6 +60,14 @@ class Vector : public std::vector { T *data() { return std::vector::data(); } const T *data() const { return std::vector::data(); } + T *data(const platform::Place &place) { + if (platform::is_cpu_place(place)) { + return data(); + } else { + return cuda_data(); + } + } + /* Synchronize host vector to device vector */ void CopyToCUDA(); /* Synchronize device vector to host vector */ diff --git a/paddle/operators/target_assign_op.cc b/paddle/operators/target_assign_op.cc new file mode 100644 index 0000000000..9c7d625136 --- /dev/null +++ b/paddle/operators/target_assign_op.cc @@ -0,0 +1,172 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/target_assign_op.h" + +namespace paddle { +namespace operators { + +class TargetAssignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + // checkout inputs + PADDLE_ENFORCE(ctx->HasInput("EncodedGTBBox"), + "Input(EncodedGTBBox) of TargetAssignOp should not be null"); + PADDLE_ENFORCE(ctx->HasInput("GTScoreLabel"), + "Input(GTScoreLabel) of TargetAssignOp should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MatchIndices"), + "Input(MatchIndices) of TargetAssignOp should not be null"); + PADDLE_ENFORCE(ctx->HasInput("NegIndices"), + "Input(NegIndices) of TargetAssignOp should not be null"); + + // checkout outputs + PADDLE_ENFORCE( + ctx->HasOutput("PredBBoxLabel"), + "Output(PredBBoxLabel) of TargetAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("PredBBoxWeight"), + "Output(PredBBoxWeight) of TargetAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("PredScoreLabel"), + "Output(PredScoreLabel) of TargetAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("PredScoreWeight"), + "Output(PredScoreWeight) of TargetAssignOp should not be null."); + + auto blabel_dims = ctx->GetInputDim("EncodedGTBBox"); + auto slabel_dims = ctx->GetInputDim("GTScoreLabel"); + auto mi_dims = ctx->GetInputDim("MatchIndices"); + auto neg_dims = ctx->GetInputDim("NegIndices"); + + PADDLE_ENFORCE_EQ(blabel_dims.size(), 3UL, + "The rank of Input(EncodedGTBBox) must be 3."); + PADDLE_ENFORCE_EQ(slabel_dims.size(), 2UL, + "The rank of Input(GTScoreLabel) must be 2."); + PADDLE_ENFORCE_EQ(mi_dims.size(), 2UL, + "The rank of Input(MatchIndices) must be 2."); + PADDLE_ENFORCE_EQ(neg_dims.size(), 2UL, + "The rank of Input(NegIndices) must be 2."); + + PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0], + "The 1st dimension of Input(EncodedGTBBox) and " + "Input(GTScoreLabel) must be the same."); + PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1], + "The 2nd dimension of Input(EncodedGTBBox) and " + "Input(MatchIndices) must be the same."); + PADDLE_ENFORCE_EQ(blabel_dims[2], 4, + "The 3rd dimension of Input(EncodedGTBBox) must be 4."); + + auto n = mi_dims[0]; + auto np = mi_dims[1]; + ctx->SetOutputDim("PredBBoxLabel", {n, np, 4}); + ctx->SetOutputDim("PredBBoxWeight", {n, np, 1}); + ctx->SetOutputDim("PredScoreLabel", {n, np, 1}); + ctx->SetOutputDim("PredScoreWeight", {n, np, 1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("EncodedGTBBox")->type()), + ctx.device_context()); + } +}; + +class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("EncodedGTBBox", + "(LoDTensor), The encoded ground-truth bounding boxes with shape " + "[Ng, Np, 4], where Ng is the total number of ground-truth boxes " + "in this mini-batch, Np the number of predictions, 4 is the " + "number of coordinate in [xmin, ymin, xmax, ymax] layout."); + AddInput("GTScoreLabel", + "(LoDTensor, default LoDTensor), The input ground-truth " + "labels with shape [Ng, 1], where the Ng is the same as it in " + "the input of EncodedGTBBox."); + AddInput("MatchIndices", + "(Tensor, default LoDTensor), The input matched indices " + "with shape [N, Np], where N is the batch size, Np is the same " + "as it in the input of EncodedGTBBox. If MatchIndices[i][j] " + "is -1, the j-th prior box is not matched to any ground-truh " + "box in i-th instance."); + AddInput("NegIndices", + "(LoDTensor, default LoDTensor), The input negative example " + "indics with shape [Neg, 1], where is the total number of " + "negative example indices."); + AddAttr("background_label", + "(int, default 0), Label id for background class.") + .SetDefault(0); + AddOutput("PredBBoxLabel", + "(Tensor), The output encoded ground-truth labels " + "with shape [N, Np, 4], N is the batch size and Np, 4 is the " + "same as they in input of EncodedGTBBox. If MatchIndices[i][j] " + "is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth " + "box for background_label_id in i-th instance."); + AddOutput("PredBBoxWeight", + "(Tensor), The weight for PredBBoxLabel with the shape " + "of [N, Np, 1]"); + AddOutput("PredScoreLabel", + "(Tensor, default Tensor), The output score labels for " + "each predictions with shape [N, Np, 1]. If MatchIndices[i][j] " + "is -1, PredScoreLabel[i][j] = background_label_id."); + AddOutput("PredScoreWeight", + "(Tensor), The weight for PredScoreLabel with the shape " + "of [N, Np, 1]"); + AddComment(R"DOC( +This operator is, for given the encoded boxes between prior boxes and +ground-truth boxes and ground-truth class labels, to assign classification +and regression targets to each prior box as well as weights to each +prior box. The weights is used to specify which prior box would not contribute +to training loss. + +TODO(dang qingqing) add an example. + + )DOC"); + } +}; + +template +struct UpdateTargetLabelFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices, + const size_t* lod, const int num, const int num_prior_box, + const int background_label, int* out_label, T* out_label_wt) { + for (int i = 0; i < num; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + int id = neg_indices[j]; + out_label[i * num_prior_box + id] = background_label; + out_label_wt[i * num_prior_box + id] = static_cast(1.0); + } + } + } +}; + +template struct UpdateTargetLabelFunctor; +template struct UpdateTargetLabelFunctor; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(target_assign, ops::TargetAssignOp, + ops::TargetAssignOpMaker); +REGISTER_OP_CPU_KERNEL( + target_assign, + ops::TargetAssignKernel, + ops::TargetAssignKernel); diff --git a/paddle/operators/target_assign_op.cu b/paddle/operators/target_assign_op.cu new file mode 100644 index 0000000000..c04de86ec5 --- /dev/null +++ b/paddle/operators/target_assign_op.cu @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/target_assign_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void UpdateTargetLabelKernel(const int* neg_indices, + const size_t* lod, const int num, + const int num_prior_box, + const int background_label, + int* out_label, T* out_label_wt) { + int bidx = blockIdx.x; + int st = lod[bidx]; + int ed = lod[bidx + 1]; + + for (int i = st + threadIdx.x; i < ed; i += blockDim.x) { + int id = neg_indices[i]; + out_label[bidx * num_prior_box + id] = background_label; + out_label_wt[bidx * num_prior_box + id] = 1.; + } +} + +template +struct UpdateTargetLabelFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const int* neg_indices, const size_t* lod, const int num, + const int num_prior_box, const int background_label, + int* out_label, T* out_label_wt) { + const int block_size = 256; + const int grid_size = num; + UpdateTargetLabelKernel<<>>( + neg_indices, lod, num, num_prior_box, background_label, out_label, + out_label_wt); + } +}; + +template struct UpdateTargetLabelFunctor; +template struct UpdateTargetLabelFunctor; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + target_assign, + ops::TargetAssignKernel, + ops::TargetAssignKernel); diff --git a/paddle/operators/target_assign_op.h b/paddle/operators/target_assign_op.h new file mode 100644 index 0000000000..267bdbf1ef --- /dev/null +++ b/paddle/operators/target_assign_op.h @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +struct TargetAssignFunctor { + const T* gt_box_; + const int* gt_label_; + const int* match_indices_; + const size_t* lod_; + const int background_label_; + const int64_t num_; + const int64_t num_prior_box_; + + T* out_box_; + T* out_box_wt_; + int* out_label_; + T* out_label_wt_; + + TargetAssignFunctor(const T* gt_box, const int* gt_label, + const int* match_indices, const size_t* lod, + const int background_label, const int64_t num, + const int64_t np, T* out_box, T* out_box_wt, + int* out_label, T* out_label_wt) + : gt_box_(gt_box), + gt_label_(gt_label), + match_indices_(match_indices), + lod_(lod), + background_label_(background_label), + num_(num), + num_prior_box_(np), + out_box_(out_box), + out_box_wt_(out_box_wt), + out_label_(out_label), + out_label_wt_(out_label_wt) {} + + HOSTDEVICE void operator()(size_t i) const { + int row = i / num_prior_box_; + int col = i - row * num_prior_box_; + + size_t off = lod_[row]; + + int id = match_indices_[row * num_prior_box_ + col]; + T* obox = out_box_ + (row * num_prior_box_ + col) * 4; + int* olabel = out_label_ + row * num_prior_box_ + col; + T* obox_wt = out_box_wt_ + row * num_prior_box_ + col; + T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col; + + if (id > -1) { + const T* gtbox = gt_box_ + ((off + id) * num_prior_box_ + col) * 4; + + obox[0] = gtbox[0]; + obox[1] = gtbox[1]; + obox[2] = gtbox[2]; + obox[3] = gtbox[3]; + + olabel[0] = gt_label_[off + id]; + obox_wt[0] = 1.; + olabel_wt[0] = 1.; + } else { + obox[0] = 0.; + obox[1] = 0.; + obox[2] = 0.; + obox[3] = 0.; + + olabel[0] = background_label_; + obox_wt[0] = 0.; + olabel_wt[0] = 0.; + } + } +}; + +template +struct UpdateTargetLabelFunctor { + void operator()(const platform::DeviceContext& ctx, const int* neg_indices, + const size_t* lod, const int num, const int num_prior_box, + const int background_label, int* out_label, + T* out_label_wt) const; +}; + +template +class TargetAssignKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* enc_gt_box = ctx.Input("EncodedGTBBox"); + auto* gt_label = ctx.Input("GTScoreLabel"); + auto* match_indices = ctx.Input("MatchIndices"); + auto* neg_indices = ctx.Input("NegIndices"); + + auto* out_box = ctx.Output("PredBBoxLabel"); + auto* out_box_wt = ctx.Output("PredBBoxWeight"); + auto* out_label = ctx.Output("PredScoreLabel"); + auto* out_label_wt = ctx.Output("PredScoreWeight"); + + PADDLE_ENFORCE_EQ(enc_gt_box->lod().size(), 1UL); + PADDLE_ENFORCE_EQ(gt_label->lod().size(), 1UL); + PADDLE_ENFORCE_EQ(neg_indices->lod().size(), 1UL); + + int background_label = ctx.Attr("background_label"); + + const T* box_data = enc_gt_box->data(); + const int* label_data = gt_label->data(); + const int* match_idx_data = match_indices->data(); + const int* neg_idx_data = neg_indices->data(); + + T* obox_data = out_box->mutable_data(ctx.GetPlace()); + T* obox_wt_data = out_box_wt->mutable_data(ctx.GetPlace()); + int* olabel_data = out_label->mutable_data(ctx.GetPlace()); + T* olabel_wt_data = out_label_wt->mutable_data(ctx.GetPlace()); + + int64_t num = match_indices->dims()[0]; + int64_t num_prior_box = match_indices->dims()[1]; + + auto gt_lod = enc_gt_box->lod().back(); + auto neg_lod = neg_indices->lod().back(); + + size_t* gt_lod_data = gt_lod.data(ctx.GetPlace()); + size_t* neg_lod_data = neg_lod.data(ctx.GetPlace()); + + TargetAssignFunctor functor(box_data, label_data, match_idx_data, + gt_lod_data, background_label, num, + num_prior_box, obox_data, obox_wt_data, + olabel_data, olabel_wt_data); + + auto& device_ctx = ctx.template device_context(); + platform::ForRange for_range(device_ctx, + num * num_prior_box); + for_range(functor); + + UpdateTargetLabelFunctor update_functor; + update_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box, + background_label, olabel_data, olabel_wt_data); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/assert.h b/paddle/platform/assert.h index d813b9529b..1f5a8f6a19 100644 --- a/paddle/platform/assert.h +++ b/paddle/platform/assert.h @@ -1,16 +1,16 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once diff --git a/python/paddle/v2/fluid/tests/test_target_assign_op.py b/python/paddle/v2/fluid/tests/test_target_assign_op.py new file mode 100755 index 0000000000..49edff5c7f --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_target_assign_op.py @@ -0,0 +1,126 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import math +import sys +import random +from op_test import OpTest + + +def gen_match_and_neg_indices(num_prior, gt_lod, neg_lod): + if len(gt_lod) != len(neg_lod): + raise AssertionError("The input arguments are illegal.") + + batch_size = len(gt_lod) - 1 + + match_indices = -1 * np.ones((batch_size, num_prior)).astype('int32') + neg_indices = np.zeros((neg_lod[-1], 1)).astype('int32') + + for n in range(batch_size): + gt_num = gt_lod[n + 1] - gt_lod[n] + ids = random.sample([i for i in range(num_prior)], gt_num) + match_indices[n, ids] = [i for i in range(gt_num)] + + ret_ids = set([i for i in range(num_prior)]) - set(ids) + s = neg_lod[n] + e = neg_lod[n + 1] + l = e - s + neg_ids = random.sample(ret_ids, l) + neg_indices[s:e, :] = np.array(neg_ids).astype('int32').reshape(l, 1) + + return match_indices, neg_indices + + +def target_assign(encoded_box, gt_label, match_indices, neg_indices, gt_lod, + neg_lod, background_label): + batch_size, num_prior = match_indices.shape + + # init target bbox + trg_box = np.zeros((batch_size, num_prior, 4)).astype('float32') + # init weight for target bbox + trg_box_wt = np.zeros((batch_size, num_prior, 1)).astype('float32') + # init target label + trg_label = np.ones((batch_size, num_prior, 1)).astype('int32') + trg_label = trg_label * background_label + # init weight for target label + trg_label_wt = np.zeros((batch_size, num_prior, 1)).astype('float32') + + for i in range(batch_size): + cur_indices = match_indices[i] + col_ids = np.where(cur_indices > -1) + col_val = cur_indices[col_ids] + + gt_start = gt_lod[i] + # target bbox + for v, c in zip(col_val + gt_start, col_ids[0].tolist()): + trg_box[i][c][:] = encoded_box[v][c][:] + + # weight for target bbox + trg_box_wt[i][col_ids] = 1.0 + + trg_label[i][col_ids] = gt_label[col_val + gt_start] + + trg_label_wt[i][col_ids] = 1.0 + # set target label weight to 1.0 for the negative samples + neg_ids = neg_indices[neg_lod[i]:neg_lod[i + 1]] + trg_label_wt[i][neg_ids] = 1.0 + + return trg_box, trg_box_wt, trg_label, trg_label_wt + + +class TestTargetAssginOp(OpTest): + def setUp(self): + self.op_type = "target_assign" + + num_prior = 120 + num_class = 21 + gt_lod = [0, 5, 11, 23] + neg_lod = [0, 4, 7, 13] + #gt_lod = [0, 2, 5] + #neg_lod = [0, 2, 4] + batch_size = len(gt_lod) - 1 + num_gt = gt_lod[-1] + background_label = 0 + + encoded_box = np.random.random((num_gt, num_prior, 4)).astype('float32') + gt_label = np.random.randint( + num_class, size=(num_gt, 1)).astype('int32') + match_indices, neg_indices = gen_match_and_neg_indices(num_prior, + gt_lod, neg_lod) + trg_box, trg_box_wt, trg_label, trg_label_wt = target_assign( + encoded_box, gt_label, match_indices, neg_indices, gt_lod, neg_lod, + background_label) + + self.inputs = { + 'EncodedGTBBox': (encoded_box, [gt_lod]), + 'GTScoreLabel': (gt_label, [gt_lod]), + 'MatchIndices': (match_indices), + 'NegIndices': (neg_indices, [neg_lod]), + } + self.attrs = {'background_label': background_label} + self.outputs = { + 'PredBBoxLabel': (trg_box), + 'PredBBoxWeight': (trg_box_wt), + 'PredScoreLabel': (trg_label), + 'PredScoreWeight': (trg_label_wt), + } + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() -- GitLab