提交 251c2fd5 编写于 作者: G gaoyuan

Update according to the code review

上级 e14272bb
......@@ -26,6 +26,8 @@ class BoxCoderOp : public framework::OperatorWithKernel {
"Input(PriorBoxVar) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(TargetBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutputBox"),
"Output(OutputBox) of BoxCoderOp should not be null.");
auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
......
......@@ -109,7 +109,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<Tensor>("OutputBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
......
......@@ -16,9 +16,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 };
inline BoxCodeType GetBoxCodeType(const std::string& type) {
......@@ -33,8 +30,10 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) {
template <typename T>
class BoxCoderKernel : public framework::OpKernel<T> {
public:
void EncodeCenterSize(const Tensor& target_box, const Tensor& prior_box,
const Tensor& prior_box_var, T* output) const {
void EncodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var,
T* output) const {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
......@@ -76,8 +75,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
}
}
}
void DecodeCenterSize(const Tensor& target_box, const Tensor& prior_box,
const Tensor& prior_box_var, T* output) const {
void DecodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var,
T* output) const {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
......@@ -124,7 +125,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<Tensor>("OutputBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册