From a39240c3b6af17b05e5a55bf8bbb199775498696 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Fri, 25 Jan 2019 07:46:48 +0000 Subject: [PATCH] add attr variance for box coder, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 7 + .../fluid/operators/detection/box_coder_op.cu | 59 +++++--- .../fluid/operators/detection/box_coder_op.h | 38 +++++- python/paddle/fluid/layers/detection.py | 126 +++++++++++++++--- python/paddle/fluid/tests/test_detection.py | 2 +- .../tests/unittests/test_box_coder_op.py | 57 ++++++-- 6 files changed, 236 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index f89f87663..fdcff62e1 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detection/box_coder_op.h" +#include namespace paddle { namespace operators { @@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "when code type is decode_center_size") .SetDefault(0) .InEnum({0, 1}); + AddAttr>( + "variance", + "(vector, default {})," + "variance of prior box with shape [4]. PriorBoxVar and variance can" + "not be provided at the same time.") + .SetDefault(std::vector{}); AddOutput("OutputBox", "(LoDTensor or Tensor) " "When code_type is 'encode_center_size', the output tensor of " diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index 0b64224e1..9b7357227 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -16,12 +18,11 @@ namespace paddle { namespace operators { template -__global__ void EncodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, const int row, - const int col, const int len, - const bool normalized, - const T prior_box_var_size, T* output) { +__global__ void EncodeCenterSizeKernel( + const T* prior_box_data, const T* prior_box_var_data, + const T* target_box_data, const int row, const int col, const int len, + const bool normalized, const T prior_box_var_size, const float* variance, + const int var_size, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; @@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (var_size == 4) { + for (int k = 0; k < 4; ++k) { + output[idx * len + k] /= static_cast(variance[k]); + } } } } template -__global__ void DecodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, const int row, - const int col, const int len, - const bool normalized, - const T prior_box_var_size, - const int axis, T* output) { +__global__ void DecodeCenterSizeKernel( + const T* prior_box_data, const T* prior_box_var_data, + const T* target_box_data, const int row, const int col, const int len, + const bool normalized, const T prior_box_var_size, const float* variance, + const int var_size, const int axis, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; int prior_box_offset = 0; if (idx < row * col) { @@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, target_box_data[idx * len + 1] * prior_box_height + prior_box_center_y; + } else if (var_size == 4) { + target_box_width = + exp(static_cast(variance[2]) * target_box_data[idx * len + 2]) * + prior_box_width; + target_box_height = + exp(static_cast(variance[3]) * target_box_data[idx * len + 3]) * + prior_box_height; + target_box_center_x = static_cast(variance[0]) * + target_box_data[idx * len] * prior_box_width + + prior_box_center_x; + target_box_center_y = static_cast(variance[1]) * + target_box_data[idx * len + 1] * + prior_box_height + + prior_box_center_y; } else { target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width; target_box_height = @@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); - + std::vector variance = context.Attr>("variance"); const T* prior_box_data = prior_box->data(); const T* target_box_data = target_box->data(); const T* prior_box_var_data = nullptr; auto prior_box_var_size = 0; if (prior_box_var) { + PADDLE_ENFORCE(variance.empty(), + "Input 'PriorBoxVar' and attribute 'variance' should not" + "be used at the same time."); prior_box_var_data = prior_box_var->data(); prior_box_var_size = prior_box_var->dims().size(); } + if (!(variance.empty())) { + PADDLE_ENFORCE(static_cast(variance.size()) == 4, + "Size of attribute 'variance' should be 4"); + } if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } + const int var_size = static_cast(variance.size()); + thrust::device_vector dev_variance(variance.begin(), variance.end()); + const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data()); auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); int axis = context.Attr("axis"); @@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel { if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, prior_box_var_size, output); + normalized, prior_box_var_size, dev_var_data, var_size, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, prior_box_var_size, axis, output); + normalized, prior_box_var_size, dev_var_data, var_size, axis, output); } } }; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index 986869d8a..b61cff1b1 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -11,6 +11,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel { void EncodeCenterSize(const framework::Tensor* target_box, const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, - const bool normalized, T* output) const { + const bool normalized, + const std::vector variance, T* output) const { int64_t row = target_box->dims()[0]; int64_t col = prior_box->dims()[0]; int64_t len = prior_box->dims()[1]; @@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel { output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; output[offset + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (!(variance.empty())) { + for (int k = 0; k < 4; ++k) { + output[offset + k] /= static_cast(variance[k]); + } } } } @@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel { const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, const bool normalized, const int axis, - T* output) const { + const std::vector variance, T* output) const { int64_t row = target_box->dims()[0]; int64_t col = target_box->dims()[1]; int64_t len = target_box->dims()[2]; @@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel { std::exp(prior_box_var_data[prior_var_offset + 3] * target_box_data[offset + 3]) * prior_box_height; + } else if (!(variance.empty())) { + target_box_center_x = static_cast(variance[0]) * + target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = static_cast(variance[1]) * + target_box_data[offset + 1] * + prior_box_height + + prior_box_center_y; + target_box_width = std::exp(static_cast(variance[2]) * + target_box_data[offset + 2]) * + prior_box_width; + target_box_height = std::exp(static_cast(variance[3]) * + target_box_data[offset + 3]) * + prior_box_height; } else { target_box_center_x = target_box_data[offset] * prior_box_width + prior_box_center_x; @@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); + std::vector variance = context.Attr>("variance"); const int axis = context.Attr("axis"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, "Only support 1 level of LoD."); } + if (prior_box_var) { + PADDLE_ENFORCE(variance.empty(), + "Input 'PriorBoxVar' and attribute 'variance' should not" + "be used at the same time."); + } + if (!(variance.empty())) { + PADDLE_ENFORCE(static_cast(variance.size()) == 4, + "Size of attribute 'variance' should be 4"); + } auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); @@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel { T* output = output_box->data(); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, - output); + variance, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, - output); + variance, output); } } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 1eb876cfa..854b34d2a 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -346,18 +346,104 @@ def box_coder(prior_box, name=None, axis=0): """ - ${comment} + **Box Coder Layer** + + Encode/Decode the target bounding box with the priorbox information. + + The Encoding schema described below: + + .. math:: + + ox = (tx - px) / pw / pxv + + oy = (ty - py) / ph / pyv + + ow = \log(\abs(tw / pw)) / pwv + + oh = \log(\abs(th / ph)) / phv + + The Decoding schema described below: + + .. math:: + + ox = (pw * pxv * tx * + px) - tw / 2 + + oy = (ph * pyv * ty * + py) - th / 2 + + ow = \exp(pwv * tw) * pw + tw / 2 + + oh = \exp(phv * th) * ph + th / 2 + + where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, + width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote + the priorbox's (anchor) center coordinates, width and height. `pxv`, + `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`, + `ow`, `oh` denote the encoded/decoded coordinates, width and height. + + During Box Decoding, two modes for broadcast are supported. Say target + box has shape [N, M, 4], and the shape of prior box can be [N, 4] or + [M, 4]. Then prior box will broadcast to target box along the + assigned axis. Args: - prior_box(${prior_box_type}): ${prior_box_comment} - prior_box_var(${prior_box_var_type}): ${prior_box_var_comment} - target_box(${target_box_type}): ${target_box_comment} - code_type(${code_type_type}): ${code_type_comment} - box_normalized(${box_normalized_type}): ${box_normalized_comment} - axis(${axis_type}): ${axis_comment} + prior_box(Variable): Box list prior_box is a 2-D Tensor with shape + [M, 4] holds M boxes, each box is represented as + [xmin, ymin, xmax, ymax], [xmin, ymin] is the + left top coordinate of the anchor box, if the + input is image feature map, they are close to + the origin of the coordinate system. [xmax, ymax] + is the right bottom coordinate of the anchor box. + prior_box_var(Variable|list): prior_box_var supports two types of input. + One is variable with shape [M, 4] holds M group. + The other one is list consist of 4 elements + shared by all boxes. + target_box(Variable): This input can be a 2-D LoDTensor with shape + [N, 4] when code_type is 'encode_center_size'. + This input also can be a 3-D Tensor with shape + [N, M, 4] when code_type is 'decode_center_size'. + Each box is represented as + [xmin, ymin, xmax, ymax]. This tensor can + contain LoD information to represent a batch + of inputs. + code_type(string): The code type used with the target box. It can be + encode_center_size or decode_center_size + box_normalized(int): Whether treat the priorbox as a noramlized box. + Set true by default. + name(string): The name of box coder. + axis(int): Which axis in PriorBox to broadcast for box decode, + for example, if axis is 0 and TargetBox has shape + [N, M, 4] and PriorBox has shape [M, 4], then PriorBox + will broadcast to [N, M, 4] for decoding. It is only valid + when code type is decode_center_size. Set 0 by default. Returns: - output_box(${output_box_type}): ${output_box_comment} + output_box(Variable): When code_type is 'encode_center_size', the + output tensor of box_coder_op with shape + [N, M, 4] representing the result of N target + boxes encoded with M Prior boxes and variances. + When code_type is 'decode_center_size', + N represents the batch size and M represents + the number of deocded boxes. + + Examples: + + .. code-block:: python + + prior_box = fluid.layers.data(name='prior_box', + shape=[512, 4], + dtype='float32', + append_batch_size=False) + target_box = fluid.layers.data(name='target_box', + shape=[512,81,4], + dtype='float32', + append_batch_size=False) + output = fluid.layers.box_coder(prior_box=prior_box, + prior_box_var=[0.1,0.1,0.2,0.2], + target_box=target_box, + code_type="decode_center_size", + box_normalized=False, + axis=1) + """ helper = LayerHelper("box_coder", **locals()) @@ -368,18 +454,22 @@ def box_coder(prior_box, output_box = helper.create_variable( name=name, dtype=prior_box.dtype, persistable=False) + inputs = {"PriorBox": prior_box, "TargetBox": target_box} + attrs = { + "code_type": code_type, + "box_normalized": box_normalized, + "axis": axis + } + if isinstance(prior_box_var, Variable): + inputs['PriorBoxVar'] = prior_box_var + elif isinstance(prior_box_var, list): + attrs['variance'] = prior_box_var + else: + raise TypeError("Input variance of box_coder must be Variable or lisz") helper.append_op( type="box_coder", - inputs={ - "PriorBox": prior_box, - "PriorBoxVar": prior_box_var, - "TargetBox": target_box - }, - attrs={ - "code_type": code_type, - "box_normalized": box_normalized, - "axis": axis - }, + inputs=inputs, + attrs=attrs, outputs={"OutputBox": output_box}) return output_box diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 2d9ed9f9c..2dbcfa31f 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase): iou = layers.iou_similarity(x=x, y=y) bcoder = layers.box_coder( prior_box=x, - prior_box_var=y, + prior_box_var=[0.2, 0.3, 0.3, 0.2], target_box=z, code_type='encode_center_size') self.assertIsNotNone(iou) diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 6f7930c92..6156268bf 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((5, 10, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') + prior_box_var = np.random.random((81, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((6, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((3, 6, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[0, 1, 2, 3, 4, 5]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.ones((10, 4)).astype('float32') - target_box = np.random.random((5, 10, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') + prior_box_var = np.ones((81, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest): def setUp(self): self.op_type = "box_coder" - lod = [[4, 8, 8]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((20, 4)).astype('float32') + lod = [[10, 20, 20]] + prior_box = np.random.random((20, 4)).astype('float32') + prior_box_var = np.random.random((20, 4)).astype('float32') + target_box = np.random.random((50, 4)).astype('float32') code_type = "EncodeCenterSize" box_normalized = True output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((5, 4)).astype('float32') + prior_box = np.random.random((30, 4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((5, 6, 4)).astype('float32') + target_box = np.random.random((30, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False axis = 1 @@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest): self.outputs = {'OutputBox': output_box} +class TestBoxCoderOpWithVariance(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((30, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((30, 81, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False + axis = 1 + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized, axis) + + self.inputs = { + 'PriorBox': prior_box, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False, + 'variance': prior_box_var.astype(np.float).flatten(), + 'axis': axis + } + self.outputs = {'OutputBox': output_box} + + if __name__ == '__main__': unittest.main() -- GitLab