提交 251c2fd5 编写于 作者: G gaoyuan

Update according to the code review

上级 e14272bb
...@@ -26,6 +26,8 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -26,6 +26,8 @@ class BoxCoderOp : public framework::OperatorWithKernel {
"Input(PriorBoxVar) of BoxCoderOp should not be null."); "Input(PriorBoxVar) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBox"), PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(TargetBox) of BoxCoderOp should not be null."); "Input(TargetBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutputBox"),
"Output(OutputBox) of BoxCoderOp should not be null.");
auto prior_box_dims = ctx->GetInputDim("PriorBox"); auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
......
...@@ -109,7 +109,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -109,7 +109,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto* prior_box = context.Input<framework::Tensor>("PriorBox"); auto* prior_box = context.Input<framework::Tensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<Tensor>("OutputBox"); auto* output_box = context.Output<framework::Tensor>("OutputBox");
if (target_box->lod().size()) { if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
......
...@@ -16,9 +16,6 @@ limitations under the License. */ ...@@ -16,9 +16,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 }; enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 };
inline BoxCodeType GetBoxCodeType(const std::string& type) { inline BoxCodeType GetBoxCodeType(const std::string& type) {
...@@ -33,8 +30,10 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) { ...@@ -33,8 +30,10 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) {
template <typename T> template <typename T>
class BoxCoderKernel : public framework::OpKernel<T> { class BoxCoderKernel : public framework::OpKernel<T> {
public: public:
void EncodeCenterSize(const Tensor& target_box, const Tensor& prior_box, void EncodeCenterSize(const framework::Tensor& target_box,
const Tensor& prior_box_var, T* output) const { const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var,
T* output) const {
int64_t row = target_box.dims()[0]; int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0]; int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1]; int64_t len = prior_box.dims()[1];
...@@ -76,8 +75,10 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -76,8 +75,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
} }
} }
} }
void DecodeCenterSize(const Tensor& target_box, const Tensor& prior_box, void DecodeCenterSize(const framework::Tensor& target_box,
const Tensor& prior_box_var, T* output) const { const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var,
T* output) const {
int64_t row = target_box.dims()[0]; int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0]; int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1]; int64_t len = prior_box.dims()[1];
...@@ -124,7 +125,7 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -124,7 +125,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto* prior_box = context.Input<framework::Tensor>("PriorBox"); auto* prior_box = context.Input<framework::Tensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<Tensor>("OutputBox"); auto* output_box = context.Output<framework::Tensor>("OutputBox");
if (target_box->lod().size()) { if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册