Unverified commit 83578cfa, authored by zhulei, committed by GitHub

[npu] add box coder (#36171)

* [npu] add box coder

* [npu] add box coder
Parent 2b8fd704
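
For reference, a minimal sketch of exercising the new kernel from Python (not part of this commit; it assumes a WITH_ASCEND_CL build with an NPU visible as device 0, and uses the existing `paddle.fluid.layers.box_coder` wrapper):

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

# Static-graph inputs matching the decode_center_size layout used in the tests.
prior_box = fluid.data(name='prior_box', shape=[81, 4], dtype='float32')
target_box = fluid.data(name='target_box', shape=[20, 81, 4], dtype='float32')

# Passing a 4-float list routes the variance through the `variance` attribute.
decoded = fluid.layers.box_coder(
    prior_box=prior_box,
    prior_box_var=[0.1, 0.1, 0.2, 0.2],
    target_box=target_box,
    code_type='decode_center_size',
    box_normalized=False)

exe = fluid.Executor(paddle.NPUPlace(0))  # assumption: NPU device 0 is available
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'prior_box': np.random.random((81, 4)).astype('float32'),
                     'target_box': np.random.random((20, 81, 4)).astype('float32')},
               fetch_list=[decoded])
print(out.shape)  # (20, 81, 4)
```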
......@@ -15,8 +15,13 @@ function(detection_library TARGET_NAME)
             PARENT_SCOPE)
 endfunction()
 
+if (WITH_ASCEND_CL)
+  detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu box_coder_op_npu.cc)
+else()
+  detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
+endif()
+
 detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
-detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
 detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
 detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
 detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
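
// Thin helper around NpuOpRunner: every method below allocates its result on
// the NPU and launches the corresponding Ascend operator on the context stream.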
template <typename T>
struct BoxCoderFunction {
public:
explicit BoxCoderFunction(const framework::ExecutionContext& ctx) : ctx(ctx) {
place = ctx.GetPlace();
stream = ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
}
Tensor Adds(const Tensor& x, float scalar) {
Tensor y;
y.mutable_data<T>(x.dims(), place);
const auto& runner = NpuOpRunner("Adds", {x}, {y}, {{"value", scalar}});
runner.Run(stream);
return y;
}
Tensor Muls(const Tensor& x, float scalar) {
Tensor y;
y.mutable_data<T>(x.dims(), place);
const auto& runner = NpuOpRunner("Muls", {x}, {y}, {{"value", scalar}});
runner.Run(stream);
return y;
}
Tensor Mul(const Tensor& x, const Tensor& y) {
Tensor z;
z.mutable_data<T>(x.dims(), place);
const auto& runner = NpuOpRunner("Mul", {x, y}, {z}, {});
runner.Run(stream);
return z;
}
Tensor SubWithBroadCast(const Tensor& x, const Tensor& y,
const framework::DDim& shape) {
Tensor z;
z.mutable_data<T>(shape, place);
const auto& runner = NpuOpRunner("Sub", {x, y}, {z}, {});
runner.Run(stream);
return z;
}
void DivWithBroadCastVoid(const Tensor& x, const Tensor& y,
const framework::DDim& shape, Tensor* z) {
z->mutable_data<T>(shape, place);
const auto& runner = NpuOpRunner("Div", {x, y}, {*z}, {});
runner.Run(stream);
}
Tensor DivWithBroadCast(const Tensor& x, const Tensor& y,
const framework::DDim& shape) {
Tensor z;
DivWithBroadCastVoid(x, y, shape, &z);
return z;
}
void MulWithBroadCastVoid(const Tensor& x, const Tensor& y,
const framework::DDim& shape, Tensor* z) {
z->mutable_data<T>(shape, place);
const auto& runner = NpuOpRunner("Mul", {x, y}, {*z}, {});
runner.Run(stream);
}
Tensor MulWithBroadCast(const Tensor& x, const Tensor& y,
const framework::DDim& shape) {
Tensor z;
MulWithBroadCastVoid(x, y, shape, &z);
return z;
}
void AddWithBroadCastVoid(const Tensor& x, const Tensor& y,
const framework::DDim& shape, Tensor* z) {
z->mutable_data<T>(shape, place);
const auto& runner = NpuOpRunner("AddV2", {x, y}, {*z}, {});
runner.Run(stream);
}
Tensor AddWithBroadCast(const Tensor& x, const Tensor& y,
const framework::DDim& shape) {
Tensor z;
AddWithBroadCastVoid(x, y, shape, &z);
return z;
}
Tensor Abs(const Tensor& x) {
Tensor y;
y.mutable_data<T>(x.dims(), place);
const auto& runner = NpuOpRunner("Abs", {x}, {y}, {});
runner.Run(stream);
return y;
}
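// log(x) is computed as log1p(x - 1) with the Ascend Log1p operator.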
Tensor Log(const Tensor& x) {
Tensor t_x_m1 = Adds(x, -1);
Tensor y;
y.mutable_data<T>(x.dims(), place);
const auto& runner = NpuOpRunner("Log1p", {t_x_m1}, {y}, {});
runner.Run(stream);
return y;
}
Tensor Exp(const Tensor& x) {
Tensor y;
y.mutable_data<T>(x.dims(), place);
const auto& runner = NpuOpRunner("Exp", {x}, {y}, {});
runner.Run(stream);
return y;
}
Tensor Dot(const Tensor& x, const Tensor& y) {
auto dim_x = x.dims();
auto dim_y = y.dims();
PADDLE_ENFORCE_EQ(
dim_x.size(), 2,
platform::errors::InvalidArgument(
"x should be a 2-dim tensor, but got %d-dim.", dim_x.size()));
PADDLE_ENFORCE_EQ(
dim_y.size(), 2,
platform::errors::InvalidArgument(
"y should be a 2-dim tensor, but got %d-dim.", dim_y.size()));
PADDLE_ENFORCE_EQ(
dim_x[1], dim_y[0],
platform::errors::InvalidArgument("Expect dim_x[1] == dim_y[0], but "
"got dim_x[1] = %d, dim_y[0] = %d.",
dim_x[1], dim_y[0]));
Tensor z;
z.mutable_data<T>({dim_x[0], dim_y[1]}, place);
const auto& runner =
NpuOpRunner("MatMul", {x, y}, {z},
{{"transpose_x1", false}, {"transpose_x2", false}});
runner.Run(stream);
return z;
}
void ConcatVoid(const std::vector<Tensor>& inputs,
const framework::DDim& shape_out, int axis, Tensor* output) {
output->mutable_data<T>(shape_out, place);
std::vector<std::string> names;
for (size_t i = 0; i < inputs.size(); i++) {
names.push_back("x" + std::to_string(i));
}
NpuOpRunner runner{
"ConcatD",
{inputs},
{*output},
{{"concat_dim", axis}, {"N", static_cast<int>(inputs.size())}}};
runner.AddInputNames(names);
runner.Run(stream);
}
Tensor Concat(const std::vector<Tensor>& inputs,
const framework::DDim& shape_out, int axis) {
Tensor output;
ConcatVoid(inputs, shape_out, axis, &output);
return output;
}
Tensor Slice(const Tensor& x, const std::vector<int>& offsets,
const std::vector<int>& size, const framework::DDim& shape) {
Tensor y;
y.mutable_data<T>(shape, place);
const auto& runner =
NpuOpRunner("SliceD", {x}, {y}, {{"offsets", offsets}, {"size", size}});
runner.Run(stream);
return y;
}
private:
platform::Place place;
aclrtStream stream;
const framework::ExecutionContext& ctx;
};
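
// Copies a host-side vector into an NPU tensor of shape `ddim`; Wait() ensures
// the host-to-device copy has completed before the tensor is used.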
template <typename T>
void Vector2Tensor(const framework::ExecutionContext& ctx,
const std::vector<T>& vec, const framework::DDim& ddim,
Tensor* tsr) {
framework::TensorFromVector<T>(vec, ctx.device_context(), tsr);
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
tsr->Resize(ddim);
}
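
// Encode path: out = ((tb_center - pb_center) / pb_wh, log|tb_wh / pb_wh|),
// elementwise divided by the prior-box variance (input tensor or attribute).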
template <typename T>
void BoxCoderEnc(const framework::ExecutionContext& ctx, const Tensor* tb,
const Tensor* pb, const Tensor* pbv, const bool norm,
const std::vector<float>& variance, Tensor* out) {
auto M = pb->dims()[0];
auto N = tb->dims()[0];
auto shape_0 = framework::make_ddim({4, 2});
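  // For boxes stored as (xmin, ymin, xmax, ymax), Dot(box, m_aver) yields the
  // centers (cx, cy) and Dot(box, m_diff) yields the sizes (w, h).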
Tensor m_diff;
Tensor m_aver;
std::vector<T> vec_diff = {static_cast<T>(-1), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(-1),
static_cast<T>(1), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(1)};
std::vector<T> vec_aver = {static_cast<T>(0.5), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(0.5),
static_cast<T>(0.5), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(0.5)};
Vector2Tensor<T>(ctx, vec_diff, shape_0, &m_diff);
Vector2Tensor<T>(ctx, vec_aver, shape_0, &m_aver);
BoxCoderFunction<T> F(ctx);
Tensor pb_xy = F.Adds(F.Dot(*pb, m_aver), (norm ? 0 : 0.5));
Tensor pb_wh = F.Adds(F.Dot(*pb, m_diff), (norm ? 0 : 1));
Tensor tb_xy = F.Dot(*tb, m_aver);
Tensor tb_wh = F.Adds(F.Dot(*tb, m_diff), (norm ? 0 : 1));
pb_xy.Resize({1, M, 2});
pb_wh.Resize({1, M, 2});
tb_xy.Resize({N, 1, 2});
tb_wh.Resize({N, 1, 2});
auto shape_half = framework::make_ddim({N, M, 2});
auto shape_full = framework::make_ddim({N, M, 4});
Tensor out_xy_0 = F.DivWithBroadCast(
F.SubWithBroadCast(tb_xy, pb_xy, shape_half), pb_wh, shape_half);
Tensor out_wh_0 = F.Log(F.Abs(F.DivWithBroadCast(tb_wh, pb_wh, shape_half)));
Tensor out_0 = F.Concat({out_xy_0, out_wh_0}, shape_full, 2);
if (pbv) {
F.DivWithBroadCastVoid(out_0, *pbv, shape_full, out);
} else {
Tensor t_var;
std::vector<T> vec_var(4);
for (auto i = 0; i < 4; i++) {
vec_var[i] = static_cast<T>(variance[i]);
}
Vector2Tensor(ctx, vec_var, framework::make_ddim({1, 1, 4}), &t_var);
F.DivWithBroadCastVoid(out_0, t_var, shape_full, out);
}
}
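
// Decode path: tb_center = t01 * var01 * pb_wh + pb_center and
// tb_wh = exp(t23 * var23) * pb_wh, emitted as the corner pair
// (tb_center - tb_wh / 2, tb_center + tb_wh / 2); the variance comes from
// PriorBoxVar, from the `variance` attribute, or defaults to 1.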
template <typename T>
void BoxCoderDec(const framework::ExecutionContext& ctx, const Tensor* tb,
const Tensor* pb, const Tensor* pbv, const bool norm,
const std::vector<float>& variance, int axis, Tensor* out) {
auto shape_0 = framework::make_ddim({4, 2});
Tensor m_diff;
Tensor m_aver;
std::vector<T> vec_diff = {static_cast<T>(-1), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(-1),
static_cast<T>(1), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(1)};
std::vector<T> vec_aver = {static_cast<T>(0.5), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(0.5),
static_cast<T>(0.5), static_cast<T>(0),
static_cast<T>(0), static_cast<T>(0.5)};
Vector2Tensor<T>(ctx, vec_diff, shape_0, &m_diff);
Vector2Tensor<T>(ctx, vec_aver, shape_0, &m_aver);
BoxCoderFunction<T> F(ctx);
Tensor pb_xy = F.Adds(F.Dot(*pb, m_aver), (norm ? 0 : 0.5));
Tensor pb_wh = F.Adds(F.Dot(*pb, m_diff), (norm ? 0 : 1));
auto pb_resize_shape = axis == 0
? framework::make_ddim({1, pb->dims()[0], 2})
: framework::make_ddim({pb->dims()[0], 1, 2});
pb_xy.Resize(pb_resize_shape);
pb_wh.Resize(pb_resize_shape);
auto tbox_slice_shape =
framework::make_ddim({tb->dims()[0], tb->dims()[1], 2});
std::vector<int> tbox_slice_size = {static_cast<int>(tb->dims()[0]),
static_cast<int>(tb->dims()[1]), 2};
Tensor tbox01 = F.Slice(*tb, {0, 0, 0}, tbox_slice_size, tbox_slice_shape);
Tensor tbox23 = F.Slice(*tb, {0, 0, 2}, tbox_slice_size, tbox_slice_shape);
Tensor tb_xy;
Tensor tb_wh;
if (pbv) {
auto pbvt_slice_shape = framework::make_ddim({pbv->dims()[0], 2});
auto pbvt_resize_shape = axis == 0
? framework::make_ddim({1, pbv->dims()[0], 2})
: framework::make_ddim({pbv->dims()[0], 1, 2});
std::vector<int> pbvt_slice_size = {static_cast<int>(pbv->dims()[0]), 2};
Tensor pbv_t01 = F.Slice(*pbv, {0, 0}, pbvt_slice_size, pbvt_slice_shape);
Tensor pbv_t23 = F.Slice(*pbv, {0, 2}, pbvt_slice_size, pbvt_slice_shape);
pbv_t01.Resize(pbvt_resize_shape);
pbv_t23.Resize(pbvt_resize_shape);
F.AddWithBroadCastVoid(
F.MulWithBroadCast(tbox01, F.Mul(pb_wh, pbv_t01), tbox_slice_shape),
pb_xy, tbox_slice_shape, &tb_xy);
F.MulWithBroadCastVoid(
F.Exp(F.MulWithBroadCast(pbv_t23, tbox23, tbox_slice_shape)), pb_wh,
tbox_slice_shape, &tb_wh);
} else if (variance.empty()) {
F.AddWithBroadCastVoid(F.MulWithBroadCast(tbox01, pb_wh, tbox_slice_shape),
pb_xy, tbox_slice_shape, &tb_xy);
F.MulWithBroadCastVoid(F.Exp(tbox23), pb_wh, tbox_slice_shape, &tb_wh);
} else {
Tensor t_var01, t_var23;
auto t_var_shape = framework::make_ddim({1, 1, 2});
std::vector<T> vec_var01 = {static_cast<T>(variance[0]),
static_cast<T>(variance[1])};
std::vector<T> vec_var23 = {static_cast<T>(variance[2]),
static_cast<T>(variance[3])};
Vector2Tensor(ctx, vec_var01, t_var_shape, &t_var01);
Vector2Tensor(ctx, vec_var23, t_var_shape, &t_var23);
F.AddWithBroadCastVoid(
F.MulWithBroadCast(tbox01,
F.MulWithBroadCast(pb_wh, t_var01, pb_resize_shape),
tbox_slice_shape),
pb_xy, tbox_slice_shape, &tb_xy);
F.MulWithBroadCastVoid(
F.Exp(F.MulWithBroadCast(t_var23, tbox23, tbox_slice_shape)), pb_wh,
tbox_slice_shape, &tb_wh);
}
Tensor obox01 =
F.AddWithBroadCast(tb_xy, F.Muls(tb_wh, -0.5), tbox_slice_shape);
Tensor obox23 =
F.Adds(F.AddWithBroadCast(tb_xy, F.Muls(tb_wh, 0.5), tbox_slice_shape),
(norm ? 0 : -1));
F.ConcatVoid({obox01, obox23}, out->dims(), 2, out);
}
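
// Kernel entry: validates the inputs and dispatches to the encode or decode
// routine above according to the `code_type` attribute.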
template <typename T>
class BoxCoderNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* prior_box = ctx.Input<Tensor>("PriorBox");
auto* prior_box_var = ctx.Input<Tensor>("PriorBoxVar");
auto* target_box = ctx.Input<framework::LoDTensor>("TargetBox");
auto* output_box = ctx.Output<Tensor>("OutputBox");
std::vector<float> variance = ctx.Attr<std::vector<float>>("variance");
const int axis = ctx.Attr<int>("axis");
if (prior_box_var) {
PADDLE_ENFORCE_EQ(variance.empty(), true,
platform::errors::InvalidArgument(
"Input 'PriorBoxVar' and attribute 'variance'"
" of BoxCoder operator should not be used at the "
"same time."));
}
if (!(variance.empty())) {
PADDLE_ENFORCE_EQ(static_cast<int>(variance.size()), 4,
platform::errors::InvalidArgument(
"Size of attribute 'variance' in BoxCoder operator"
" should be 4. But received size is %d",
variance.size()));
}
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
platform::errors::InvalidArgument(
"Input 'TargetBox' of BoxCoder operator only"
" supports LoD with one level."));
}
auto code_type = GetBoxCodeType(ctx.Attr<std::string>("code_type"));
bool normalized = ctx.Attr<bool>("box_normalized");
if (code_type == BoxCodeType::kEncodeCenterSize) {
BoxCoderEnc<T>(ctx, target_box, prior_box, prior_box_var, normalized,
variance, output_box);
} else {
BoxCoderDec<T>(ctx, target_box, prior_box, prior_box_var, normalized,
variance, axis, output_box);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(box_coder, ops::BoxCoderNPUKernel<float>,
ops::BoxCoderNPUKernel<plat::float16>);
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import math
import paddle
from op_test import OpTest
paddle.enable_static()
np.random.seed(2021)
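
# NumPy reference for decode_center_size; pb_v is either the per-prior-box
# variance tensor (2-D) or the 4-element `variance` attribute (1-D).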
def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0):
    pb_w = p_box[:, 2] - p_box[:, 0] + (not norm)
    pb_h = p_box[:, 3] - p_box[:, 1] + (not norm)
pb_x = pb_w * 0.5 + p_box[:, 0]
pb_y = pb_h * 0.5 + p_box[:, 1]
shape = (1, p_box.shape[0]) if axis == 0 else (p_box.shape[0], 1)
pb_w = pb_w.reshape(shape)
pb_h = pb_h.reshape(shape)
pb_x = pb_x.reshape(shape)
pb_y = pb_y.reshape(shape)
if pb_v.ndim == 2:
var_shape = (1, pb_v.shape[0], pb_v.shape[1]) if axis == 0 else (
pb_v.shape[0], 1, pb_v.shape[1])
pb_v = pb_v.reshape(var_shape)
if pb_v.ndim == 1:
tb_x = pb_v[0] * t_box[:, :, 0] * pb_w + pb_x
tb_y = pb_v[1] * t_box[:, :, 1] * pb_h + pb_y
tb_w = np.exp(pb_v[2] * t_box[:, :, 2]) * pb_w
tb_h = np.exp(pb_v[3] * t_box[:, :, 3]) * pb_h
else:
tb_x = pb_v[:, :, 0] * t_box[:, :, 0] * pb_w + pb_x
tb_y = pb_v[:, :, 1] * t_box[:, :, 1] * pb_h + pb_y
tb_w = np.exp(pb_v[:, :, 2] * t_box[:, :, 2]) * pb_w
tb_h = np.exp(pb_v[:, :, 3] * t_box[:, :, 3]) * pb_h
output_box[:, :, 0] = tb_x - tb_w / 2
output_box[:, :, 1] = tb_y - tb_h / 2
output_box[:, :, 2] = tb_x + tb_w / 2 - (not norm)
output_box[:, :, 3] = tb_y + tb_h / 2 - (not norm)
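
# NumPy reference for encode_center_size.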
def box_encoder(t_box, p_box, pb_v, output_box, norm):
    pb_w = p_box[:, 2] - p_box[:, 0] + (not norm)
    pb_h = p_box[:, 3] - p_box[:, 1] + (not norm)
pb_x = pb_w * 0.5 + p_box[:, 0]
pb_y = pb_h * 0.5 + p_box[:, 1]
shape = (1, p_box.shape[0])
pb_w = pb_w.reshape(shape)
pb_h = pb_h.reshape(shape)
pb_x = pb_x.reshape(shape)
pb_y = pb_y.reshape(shape)
if pb_v.ndim == 2:
pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1])
tb_x = ((t_box[:, 2] + t_box[:, 0]) / 2).reshape(t_box.shape[0], 1)
tb_y = ((t_box[:, 3] + t_box[:, 1]) / 2).reshape(t_box.shape[0], 1)
tb_w = (t_box[:, 2] - t_box[:, 0]).reshape(t_box.shape[0], 1) + (not norm)
tb_h = (t_box[:, 3] - t_box[:, 1]).reshape(t_box.shape[0], 1) + (not norm)
if pb_v.ndim == 1:
output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[0]
output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[1]
output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[2]
output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[3]
else:
output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[:, :, 0]
output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[:, :, 1]
output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[:, :, 2]
output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[:, :, 3]
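
# Runs the reference coder; encoding is applied per LoD segment of the target
# boxes, while decoding ignores the LoD.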
def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0):
n = t_box.shape[0]
m = p_box.shape[0]
if code_type == "decode_center_size":
m = t_box.shape[1]
output_box = np.zeros((n, m, 4), dtype=np.float32)
cur_offset = 0
for i in range(len(lod)):
if (code_type == "encode_center_size"):
box_encoder(t_box[cur_offset:(cur_offset + lod[i]), :], p_box, pb_v,
output_box[cur_offset:(cur_offset + lod[i]), :, :],
norm)
elif (code_type == "decode_center_size"):
box_decoder(t_box, p_box, pb_v, output_box, norm, axis)
cur_offset += lod[i]
return output_box
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestBoxCoderOp(OpTest):
def setUp(self):
self.op_type = "box_coder"
self.set_npu()
self.init_dtype()
self.set_init_config()
self.set_inputs()
self.set_attrs()
self.set_outputs()
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def set_init_config(self):
self.M = 81
self.N = 20
self.code_type = 'decode_center_size'
self.box_normalized = False
self.lod = [[1, 1, 1, 1, 1]]
self.axis = 0
self.use_variance = False
self.without_prior_box_var = False
self.atol = 1e-5
def set_inputs(self):
self.inputs = {}
assert (self.code_type in ['decode_center_size', 'encode_center_size'])
assert (self.axis in [0, 1])
if self.code_type == 'decode_center_size':
assert (not self.use_variance or not self.without_prior_box_var)
self.prior_box = np.random.random((self.M, 4)).astype(self.dtype)
if self.use_variance:
self.prior_box_var = np.random.random(4).astype(self.dtype)
else:
if self.without_prior_box_var:
self.prior_box_var = np.ones((self.M, 4)).astype(self.dtype)
else:
self.prior_box_var = np.random.random(
(self.M, 4)).astype(self.dtype)
if self.axis == 0:
self.target_box = np.random.random(
(self.N, self.M, 4)).astype(self.dtype)
else:
self.target_box = np.random.random(
(self.M, self.N, 4)).astype(self.dtype)
self.inputs['PriorBox'] = self.prior_box
self.inputs['TargetBox'] = self.target_box
if (not self.use_variance and not self.without_prior_box_var):
self.inputs['PriorBoxVar'] = self.prior_box_var
else:
#encode_center_size
self.prior_box = np.random.random((self.M, 4)).astype(self.dtype)
if self.use_variance:
self.prior_box_var = np.random.random(4).astype(self.dtype)
else:
self.prior_box_var = np.random.random(
(self.M, 4)).astype(self.dtype)
self.target_box = np.random.random((self.N, 4)).astype(self.dtype)
self.inputs['PriorBox'] = self.prior_box
#self.inputs['PriorBoxVar'] = self.prior_box_var
self.inputs['TargetBox'] = (self.target_box, self.lod)
if (not self.use_variance):
self.inputs['PriorBoxVar'] = self.prior_box_var
def set_attrs(self):
self.attrs = {
'code_type': self.code_type,
'box_normalized': self.box_normalized
}
if self.use_variance:
            self.attrs['variance'] = self.prior_box_var.astype(
                np.float64).flatten()
if self.axis != 0:
self.attrs['axis'] = self.axis
def set_outputs(self):
output_box = batch_box_coder(
self.prior_box, self.prior_box_var, self.target_box, self.lod[0],
self.code_type, self.box_normalized, self.axis)
self.outputs = {'OutputBox': output_box.astype(self.dtype)}
def test_check_output(self):
self.check_output_with_place(self.place, atol=self.atol)
class TestBoxCoderOpWithoutBoxVar(TestBoxCoderOp):
def set_init_config(self):
super(TestBoxCoderOpWithoutBoxVar, self).set_init_config()
self.without_prior_box_var = True
self.lod = [[0, 1, 2, 3, 4, 5]]
class TestBoxCoderOpWithLoD(TestBoxCoderOp):
def set_init_config(self):
super(TestBoxCoderOpWithLoD, self).set_init_config()
self.M = 20
self.N = 50
self.lod = [[10, 20, 20]]
self.code_type = 'encode_center_size'
self.box_normalized = True
class TestBoxCoderOpWithLoDWithVariance(TestBoxCoderOpWithLoD):
def set_init_config(self):
super(TestBoxCoderOpWithLoDWithVariance, self).set_init_config()
self.use_variance = True
class TestBoxCoderOpWithAxis(TestBoxCoderOp):
def set_init_config(self):
super(TestBoxCoderOpWithAxis, self).set_init_config()
self.axis = 1
class TestBoxCoderOpWithVariance(TestBoxCoderOp):
def set_init_config(self):
super(TestBoxCoderOpWithVariance, self).set_init_config()
self.use_variance = True
class TestBoxCoderOpFP16(TestBoxCoderOp):
def init_dtype(self):
self.dtype = np.float16
def set_init_config(self):
super(TestBoxCoderOpFP16, self).set_init_config()
self.atol = 1e-2
if __name__ == '__main__':
unittest.main()