提交 8bb0d0b3 编写于 作者: X Xinqi

Merge branch 'faster_rcnn' of https://github.com/Oneflow-Inc/oneflow into faster_rcnn


Former-commit-id: 1de744e248fc3874f7fb1fc996763c2d302a33a1
#ifndef ONEFLOW_CORE_KERNEL_ROI_POOLING_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_ROI_POOLING_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class RoIPoolingKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(RoIPoolingKernel);
RoIPoolingKernel() = default;
~RoIPoolingKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void BackwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_ROI_POOLING_KERNEL_H_
......@@ -325,12 +325,12 @@ message PolynomialDecayConf {
}
message CosineDecayConf {
required int64 decay_batches = 1;
required int64 decay_batches = 1;
optional double alpha = 2 [default = 0.0];
}
message LinearCosineDecayConf {
required int64 decay_batches = 1;
required int64 decay_batches = 1;
optional double num_periods = 2 [default = 0.5];
optional double alpha = 3 [default = 0.0];
optional double beta = 4 [default = 0.001];
......@@ -455,7 +455,7 @@ message MaximumOpConf {
message LocalResponseNormalizationOpConf {
required string in = 1;
required string out = 2;
required string data_format = 3;
required string data_format = 3;
optional int32 depth_radius = 4 [default = 5];
optional double bias = 5 [default = 1];
optional double alpha = 6 [default = 1];
......@@ -527,7 +527,7 @@ message DecodeRandomOpConf {
required ShapeProto shape = 2;
required DataType data_type = 3;
optional int32 max_sequence_size = 4 [default = 1];
required InitializerConf distribution = 7;
required InitializerConf distribution = 7;
}
message NormalizationOpConf {
......@@ -569,7 +569,7 @@ message ReduceLocalAddOpConf {
required int32 out_num = 2;
required int64 min_in_parallel_id = 3;
required int64 min_out_parallel_id = 4;
required int64 model_elem_cnt = 5;
required int64 model_elem_cnt = 5;
};
message ReduceGlobalAddOpConf {
......@@ -587,6 +587,16 @@ message AccuracyOpConf {
required string accuracy = 4;
}
message RoIPoolingOpConf {
required string in = 1;
required string rois = 2;
required string out = 3;
required int32 pooled_h = 4;
required int32 pooled_w = 5;
required float spatial_scale = 6 [default = 0.0625]; // 1/16
}
message OperatorConf {
required string name = 1;
optional string model_load_dir = 2;
......@@ -641,6 +651,7 @@ message OperatorConf {
RecordLoadOpConf record_load_conf = 404;
AccuracyOpConf accuracy_conf=405;
AccuracyPrintOpConf accuracy_print_conf = 406;
RoIPoolingOpConf roi_pooling_conf = 407;
}
}
......
#include "oneflow/core/operator/roi_pooling_op.h"
namespace oneflow {
void RoIPoolingOp::InitFromOpConf() {
EnrollInputBn("in");
EnrollInputBn("rois", false);
EnrollOutputBn("out");
EnrollDataTmpBn("argmax");
}
const PbMessage& RoIPoolingOp::GetCustomizedConf() const { return op_conf().roi_pooling_conf(); }
void RoIPoolingOp::InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
// in
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
CHECK_EQ(in_blob_desc->shape().NumAxes(), 4);
// rois
const BlobDesc* rois_blob_desc = GetBlobDesc4BnInOp("rois");
CHECK_EQ(rois_blob_desc->shape().NumAxes(), 3);
CHECK_EQ(rois_blob_desc->shape().At(0), in_blob_desc->shape().At(0));
CHECK_EQ(rois_blob_desc->shape().At(2), 4);
// out
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape(
{in_blob_desc->shape().At(0), rois_blob_desc->shape().At(1), in_blob_desc->shape().At(1),
op_conf().roi_pooling_conf().pooled_h(), op_conf().roi_pooling_conf().pooled_w()});
out_blob_desc->set_data_type(in_blob_desc->data_type());
// argmax
BlobDesc* argmax_blob_desc = GetBlobDesc4BnInOp("argmax");
argmax_blob_desc->mut_shape() = out_blob_desc->shape();
argmax_blob_desc->set_data_type(DataType::kInt32);
}
REGISTER_OP(OperatorConf::kRoiPoolingConf, RoIPoolingOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_ROI_POOLING_OP_H_
#define ONEFLOW_CORE_OPERATOR_ROI_POOLING_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class RoIPoolingOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(RoIPoolingOp);
RoIPoolingOp() = default;
virtual ~RoIPoolingOp() = default;
const PbMessage& GetCustomizedConf() const override;
void InitFromOpConf() override;
bool NeedOutWhenBackward() const override { return false; }
void InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_ROI_POOLING_OP_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册