提交 a39240c3 编写于 作者: J jerrywgz

add attr variance for box coder, test=develop

上级 6928f831
...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/operators/detection/box_coder_op.h"
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"when code type is decode_center_size") "when code type is decode_center_size")
.SetDefault(0) .SetDefault(0)
.InEnum({0, 1}); .InEnum({0, 1});
AddAttr<std::vector<float>>(
"variance",
"(vector<float>, default {}),"
"variance of prior box with shape [4]. PriorBoxVar and variance can"
"not be provided at the same time.")
.SetDefault(std::vector<float>{});
AddOutput("OutputBox", AddOutput("OutputBox",
"(LoDTensor or Tensor) " "(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of " "When code_type is 'encode_center_size', the output tensor of "
......
...@@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
...@@ -16,12 +18,11 @@ namespace paddle { ...@@ -16,12 +18,11 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data, __global__ void EncodeCenterSizeKernel(
const T* prior_box_var_data, const T* prior_box_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const T* target_box_data, const int row, const int col, const int len,
const int col, const int len, const bool normalized, const T prior_box_var_size, const float* variance,
const bool normalized, const int var_size, T* output) {
const T prior_box_var_size, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) { if (idx < row * col) {
const int row_idx = idx / col; const int row_idx = idx / col;
...@@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, ...@@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1];
output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2];
output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3];
} else if (var_size == 4) {
for (int k = 0; k < 4; ++k) {
output[idx * len + k] /= static_cast<T>(variance[k]);
}
} }
} }
} }
template <typename T> template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data, __global__ void DecodeCenterSizeKernel(
const T* prior_box_var_data, const T* prior_box_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const T* target_box_data, const int row, const int col, const int len,
const int col, const int len, const bool normalized, const T prior_box_var_size, const float* variance,
const bool normalized, const int var_size, const int axis, T* output) {
const T prior_box_var_size,
const int axis, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
int prior_box_offset = 0; int prior_box_offset = 0;
if (idx < row * col) { if (idx < row * col) {
...@@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, ...@@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
target_box_data[idx * len + 1] * target_box_data[idx * len + 1] *
prior_box_height + prior_box_height +
prior_box_center_y; prior_box_center_y;
} else if (var_size == 4) {
target_box_width =
exp(static_cast<T>(variance[2]) * target_box_data[idx * len + 2]) *
prior_box_width;
target_box_height =
exp(static_cast<T>(variance[3]) * target_box_data[idx * len + 3]) *
prior_box_height;
target_box_center_x = static_cast<T>(variance[0]) *
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y = static_cast<T>(variance[1]) *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else { } else {
target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width; target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width;
target_box_height = target_box_height =
...@@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox"); auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const T* prior_box_data = prior_box->data<T>(); const T* prior_box_data = prior_box->data<T>();
const T* target_box_data = target_box->data<T>(); const T* target_box_data = target_box->data<T>();
const T* prior_box_var_data = nullptr; const T* prior_box_var_data = nullptr;
auto prior_box_var_size = 0; auto prior_box_var_size = 0;
if (prior_box_var) { if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
prior_box_var_data = prior_box_var->data<T>(); prior_box_var_data = prior_box_var->data<T>();
prior_box_var_size = prior_box_var->dims().size(); prior_box_var_size = prior_box_var->dims().size();
} }
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
"Size of attribute 'variance' should be 4");
}
if (target_box->lod().size()) { if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
"Only support 1 level of LoD."); "Only support 1 level of LoD.");
} }
const int var_size = static_cast<T>(variance.size());
thrust::device_vector<float> dev_variance(variance.begin(), variance.end());
const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized"); bool normalized = context.Attr<bool>("box_normalized");
int axis = context.Attr<int>("axis"); int axis = context.Attr<int>("axis");
...@@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, prior_box_var_size, output); normalized, prior_box_var_size, dev_var_data, var_size, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, prior_box_var_size, axis, output); normalized, prior_box_var_size, dev_var_data, var_size, axis, output);
} }
} }
}; };
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void EncodeCenterSize(const framework::Tensor* target_box, void EncodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box, const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var, const framework::Tensor* prior_box_var,
const bool normalized, T* output) const { const bool normalized,
const std::vector<float> variance, T* output) const {
int64_t row = target_box->dims()[0]; int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0]; int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1]; int64_t len = prior_box->dims()[1];
...@@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
output[offset + 3] /= prior_box_var_data[prior_var_offset + 3]; output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
} else if (!(variance.empty())) {
for (int k = 0; k < 4; ++k) {
output[offset + k] /= static_cast<T>(variance[k]);
}
} }
} }
} }
...@@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
const framework::Tensor* prior_box, const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var, const framework::Tensor* prior_box_var,
const bool normalized, const int axis, const bool normalized, const int axis,
T* output) const { const std::vector<float> variance, T* output) const {
int64_t row = target_box->dims()[0]; int64_t row = target_box->dims()[0];
int64_t col = target_box->dims()[1]; int64_t col = target_box->dims()[1];
int64_t len = target_box->dims()[2]; int64_t len = target_box->dims()[2];
...@@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel<T> {
std::exp(prior_box_var_data[prior_var_offset + 3] * std::exp(prior_box_var_data[prior_var_offset + 3] *
target_box_data[offset + 3]) * target_box_data[offset + 3]) *
prior_box_height; prior_box_height;
} else if (!(variance.empty())) {
target_box_center_x = static_cast<T>(variance[0]) *
target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y = static_cast<T>(variance[1]) *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
target_box_width = std::exp(static_cast<T>(variance[2]) *
target_box_data[offset + 2]) *
prior_box_width;
target_box_height = std::exp(static_cast<T>(variance[3]) *
target_box_data[offset + 3]) *
prior_box_height;
} else { } else {
target_box_center_x = target_box_center_x =
target_box_data[offset] * prior_box_width + prior_box_center_x; target_box_data[offset] * prior_box_width + prior_box_center_x;
...@@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox"); auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const int axis = context.Attr<int>("axis"); const int axis = context.Attr<int>("axis");
if (target_box->lod().size()) { if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
"Only support 1 level of LoD."); "Only support 1 level of LoD.");
} }
if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
}
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
"Size of attribute 'variance' should be 4");
}
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized"); bool normalized = context.Attr<bool>("box_normalized");
...@@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T* output = output_box->data<T>(); T* output = output_box->data<T>();
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output); variance, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis,
output); variance, output);
} }
} }
}; };
......
...@@ -346,18 +346,104 @@ def box_coder(prior_box, ...@@ -346,18 +346,104 @@ def box_coder(prior_box,
name=None, name=None,
axis=0): axis=0):
""" """
${comment} **Box Coder Layer**
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
.. math::
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = \log(\abs(tw / pw)) / pwv
oh = \log(\abs(th / ph)) / phv
The Decoding schema described below:
.. math::
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = \exp(pwv * tw) * pw + tw / 2
oh = \exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
the priorbox's (anchor) center coordinates, width and height. `pxv`,
`pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
`ow`, `oh` denote the encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target
box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
[M, 4]. Then prior box will broadcast to target box along the
assigned axis.
Args: Args:
prior_box(${prior_box_type}): ${prior_box_comment} prior_box(Variable): Box list prior_box is a 2-D Tensor with shape
prior_box_var(${prior_box_var_type}): ${prior_box_var_comment} [M, 4] holds M boxes, each box is represented as
target_box(${target_box_type}): ${target_box_comment} [xmin, ymin, xmax, ymax], [xmin, ymin] is the
code_type(${code_type_type}): ${code_type_comment} left top coordinate of the anchor box, if the
box_normalized(${box_normalized_type}): ${box_normalized_comment} input is image feature map, they are close to
axis(${axis_type}): ${axis_comment} the origin of the coordinate system. [xmax, ymax]
is the right bottom coordinate of the anchor box.
prior_box_var(Variable|list): prior_box_var supports two types of input.
One is variable with shape [M, 4] holds M group.
The other one is list consist of 4 elements
shared by all boxes.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'.
This input also can be a 3-D Tensor with shape
[N, M, 4] when code_type is 'decode_center_size'.
Each box is represented as
[xmin, ymin, xmax, ymax]. This tensor can
contain LoD information to represent a batch
of inputs.
code_type(string): The code type used with the target box. It can be
encode_center_size or decode_center_size
box_normalized(int): Whether treat the priorbox as a noramlized box.
Set true by default.
name(string): The name of box coder.
axis(int): Which axis in PriorBox to broadcast for box decode,
for example, if axis is 0 and TargetBox has shape
[N, M, 4] and PriorBox has shape [M, 4], then PriorBox
will broadcast to [N, M, 4] for decoding. It is only valid
when code type is decode_center_size. Set 0 by default.
Returns: Returns:
output_box(${output_box_type}): ${output_box_comment} output_box(Variable): When code_type is 'encode_center_size', the
output tensor of box_coder_op with shape
[N, M, 4] representing the result of N target
boxes encoded with M Prior boxes and variances.
When code_type is 'decode_center_size',
N represents the batch size and M represents
the number of deocded boxes.
Examples:
.. code-block:: python
prior_box = fluid.layers.data(name='prior_box',
shape=[512, 4],
dtype='float32',
append_batch_size=False)
target_box = fluid.layers.data(name='target_box',
shape=[512,81,4],
dtype='float32',
append_batch_size=False)
output = fluid.layers.box_coder(prior_box=prior_box,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box,
code_type="decode_center_size",
box_normalized=False,
axis=1)
""" """
helper = LayerHelper("box_coder", **locals()) helper = LayerHelper("box_coder", **locals())
...@@ -368,18 +454,22 @@ def box_coder(prior_box, ...@@ -368,18 +454,22 @@ def box_coder(prior_box,
output_box = helper.create_variable( output_box = helper.create_variable(
name=name, dtype=prior_box.dtype, persistable=False) name=name, dtype=prior_box.dtype, persistable=False)
helper.append_op( inputs = {"PriorBox": prior_box, "TargetBox": target_box}
type="box_coder", attrs = {
inputs={
"PriorBox": prior_box,
"PriorBoxVar": prior_box_var,
"TargetBox": target_box
},
attrs={
"code_type": code_type, "code_type": code_type,
"box_normalized": box_normalized, "box_normalized": box_normalized,
"axis": axis "axis": axis
}, }
if isinstance(prior_box_var, Variable):
inputs['PriorBoxVar'] = prior_box_var
elif isinstance(prior_box_var, list):
attrs['variance'] = prior_box_var
else:
raise TypeError("Input variance of box_coder must be Variable or lisz")
helper.append_op(
type="box_coder",
inputs=inputs,
attrs=attrs,
outputs={"OutputBox": output_box}) outputs={"OutputBox": output_box})
return output_box return output_box
......
...@@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase):
iou = layers.iou_similarity(x=x, y=y) iou = layers.iou_similarity(x=x, y=y)
bcoder = layers.box_coder( bcoder = layers.box_coder(
prior_box=x, prior_box=x,
prior_box_var=y, prior_box_var=[0.2, 0.3, 0.3, 0.2],
target_box=z, target_box=z,
code_type='encode_center_size') code_type='encode_center_size')
self.assertIsNotNone(iou) self.assertIsNotNone(iou)
......
...@@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest): ...@@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "box_coder" self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]] lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((10, 4)).astype('float32') prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32') prior_box_var = np.random.random((81, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32') target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize" code_type = "DecodeCenterSize"
box_normalized = False box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
...@@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest): ...@@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest):
def setUp(self): def setUp(self):
self.op_type = "box_coder" self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]] lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((6, 4)).astype('float32') prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((3, 6, 4)).astype('float32') target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize" code_type = "DecodeCenterSize"
box_normalized = False box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
...@@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest): ...@@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
def setUp(self): def setUp(self):
self.op_type = "box_coder" self.op_type = "box_coder"
lod = [[0, 1, 2, 3, 4, 5]] lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((10, 4)).astype('float32') prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.ones((10, 4)).astype('float32') prior_box_var = np.ones((81, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32') target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize" code_type = "DecodeCenterSize"
box_normalized = False box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
...@@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest): ...@@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest):
def setUp(self): def setUp(self):
self.op_type = "box_coder" self.op_type = "box_coder"
lod = [[4, 8, 8]] lod = [[10, 20, 20]]
prior_box = np.random.random((10, 4)).astype('float32') prior_box = np.random.random((20, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32') prior_box_var = np.random.random((20, 4)).astype('float32')
target_box = np.random.random((20, 4)).astype('float32') target_box = np.random.random((50, 4)).astype('float32')
code_type = "EncodeCenterSize" code_type = "EncodeCenterSize"
box_normalized = True box_normalized = True
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
...@@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest): ...@@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest):
def setUp(self): def setUp(self):
self.op_type = "box_coder" self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]] lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((5, 4)).astype('float32') prior_box = np.random.random((30, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((5, 6, 4)).astype('float32') target_box = np.random.random((30, 81, 4)).astype('float32')
code_type = "DecodeCenterSize" code_type = "DecodeCenterSize"
box_normalized = False box_normalized = False
axis = 1 axis = 1
...@@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest): ...@@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest):
self.outputs = {'OutputBox': output_box} self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithVariance(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((30, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((30, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
axis = 1
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized, axis)
self.inputs = {
'PriorBox': prior_box,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False,
'variance': prior_box_var.astype(np.float).flatten(),
'axis': axis
}
self.outputs = {'OutputBox': output_box}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册