From 251c2fd50a787b474e49db7f7be9aab27fcd3ccb Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Fri, 2 Feb 2018 13:35:00 +0800 Subject: [PATCH] Update according to the code review --- paddle/operators/box_coder_op.cc | 2 ++ paddle/operators/box_coder_op.cu | 2 +- paddle/operators/box_coder_op.h | 17 +++++++++-------- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/paddle/operators/box_coder_op.cc b/paddle/operators/box_coder_op.cc index 41123f9b6e..3836cef96d 100644 --- a/paddle/operators/box_coder_op.cc +++ b/paddle/operators/box_coder_op.cc @@ -26,6 +26,8 @@ class BoxCoderOp : public framework::OperatorWithKernel { "Input(PriorBoxVar) of BoxCoderOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("PriorBox"), "Input(TargetBox) of BoxCoderOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("OutputBox"), + "Output(OutputBox) of BoxCoderOp should not be null."); auto prior_box_dims = ctx->GetInputDim("PriorBox"); auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); diff --git a/paddle/operators/box_coder_op.cu b/paddle/operators/box_coder_op.cu index 883cc54305..98bd93457f 100644 --- a/paddle/operators/box_coder_op.cu +++ b/paddle/operators/box_coder_op.cu @@ -109,7 +109,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel { auto* prior_box = context.Input("PriorBox"); auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); - auto* output_box = context.Output("OutputBox"); + auto* output_box = context.Output("OutputBox"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, diff --git a/paddle/operators/box_coder_op.h b/paddle/operators/box_coder_op.h index d1c9a40459..086251f6e0 100644 --- a/paddle/operators/box_coder_op.h +++ b/paddle/operators/box_coder_op.h @@ -16,9 +16,6 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 }; inline BoxCodeType GetBoxCodeType(const std::string& type) { @@ -33,8 +30,10 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) { template class BoxCoderKernel : public framework::OpKernel { public: - void EncodeCenterSize(const Tensor& target_box, const Tensor& prior_box, - const Tensor& prior_box_var, T* output) const { + void EncodeCenterSize(const framework::Tensor& target_box, + const framework::Tensor& prior_box, + const framework::Tensor& prior_box_var, + T* output) const { int64_t row = target_box.dims()[0]; int64_t col = prior_box.dims()[0]; int64_t len = prior_box.dims()[1]; @@ -76,8 +75,10 @@ class BoxCoderKernel : public framework::OpKernel { } } } - void DecodeCenterSize(const Tensor& target_box, const Tensor& prior_box, - const Tensor& prior_box_var, T* output) const { + void DecodeCenterSize(const framework::Tensor& target_box, + const framework::Tensor& prior_box, + const framework::Tensor& prior_box_var, + T* output) const { int64_t row = target_box.dims()[0]; int64_t col = prior_box.dims()[0]; int64_t len = prior_box.dims()[1]; @@ -124,7 +125,7 @@ class BoxCoderKernel : public framework::OpKernel { auto* prior_box = context.Input("PriorBox"); auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); - auto* output_box = context.Output("OutputBox"); + auto* output_box = context.Output("OutputBox"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, -- GitLab