提交 ea2bdd19 编写于 作者: T Tao Luo

Merge branch 'develop' into remove_unused_code

...@@ -21,7 +21,7 @@ else ...@@ -21,7 +21,7 @@ else
fi fi
USE_TENSORRT=OFF USE_TENSORRT=OFF
if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then if [ -d "$TENSORRT_INCLUDE_DIR" -a -d "$TENSORRT_LIB_DIR" ]; then
USE_TENSORRT=ON USE_TENSORRT=ON
fi fi
......
...@@ -42,16 +42,22 @@ class Pool2dOpConverter : public OpConverter { ...@@ -42,16 +42,22 @@ class Pool2dOpConverter : public OpConverter {
boost::get<std::vector<int>>(op_desc.GetAttr("strides")); boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
std::vector<int> paddings = std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
if (global_pooling == true) { if (global_pooling == true) {
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nv_ksize.d[0] = input_shape.d[nbDims - 2]; nv_ksize.d[0] = input_shape.d[nbDims - 2];
nv_ksize.d[1] = input_shape.d[nbDims - 1]; nv_ksize.d[1] = input_shape.d[nbDims - 1];
nv_strides.h() = 1;
nv_strides.w() = 1;
nv_paddings.h() = 0;
nv_paddings.w() = 0;
} }
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL); PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL);
...@@ -64,6 +70,36 @@ class Pool2dOpConverter : public OpConverter { ...@@ -64,6 +70,36 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_THROW("TensorRT unsupported pooling type!"); PADDLE_THROW("TensorRT unsupported pooling type!");
} }
if (ceil_mode) {
nvinfer1::DimsHW pre_pad(0, 0);
nvinfer1::DimsHW post_pad(0, 0);
int input_height = input_shape.d[nbDims - 2];
int input_width = input_shape.d[nbDims - 1];
int floor_h_output_size =
(input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
int ceil_h_output_size =
(input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) /
strides[0] +
1;
int floor_w_output_size =
(input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
int ceil_w_output_size =
(input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) /
strides[1] +
1;
if (floor_h_output_size != ceil_h_output_size) {
post_pad.h() = strides[0] - 1;
}
if (floor_w_output_size != ceil_w_output_size) {
post_pad.w() = strides[1] - 1;
}
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Padding, *const_cast<nvinfer1::ITensor*>(input1), pre_pad,
post_pad);
input1 = layer->getOutput(0);
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling,
*const_cast<nvinfer1::ITensor*>(input1), *const_cast<nvinfer1::ITensor*>(input1),
nv_pool_type, nv_ksize); nv_pool_type, nv_ksize);
......
...@@ -20,18 +20,20 @@ namespace paddle { ...@@ -20,18 +20,20 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
void test_pool2d(bool global_pooling) { void test_pool2d(bool global_pooling, bool ceil_mode) {
framework::Scope scope; framework::Scope scope;
std::unordered_set<std::string> parameters; std::unordered_set<std::string> parameters;
TRTConvertValidation validator(5, parameters, scope, 1 << 15); TRTConvertValidation validator(5, parameters, scope, 1 << 15);
// The ITensor's Dims should not contain the batch size. // The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W. // So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 13, 14));
if (global_pooling) if (global_pooling)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
else if (ceil_mode)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 7));
else else
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 6));
// Prepare Op description // Prepare Op description
framework::OpDesc desc; framework::OpDesc desc;
...@@ -39,7 +41,7 @@ void test_pool2d(bool global_pooling) { ...@@ -39,7 +41,7 @@ void test_pool2d(bool global_pooling) {
desc.SetInput("X", {"pool2d-X"}); desc.SetInput("X", {"pool2d-X"});
desc.SetOutput("Out", {"pool2d-Out"}); desc.SetOutput("Out", {"pool2d-Out"});
std::vector<int> ksize({2, 2}); std::vector<int> ksize({3, 3});
std::vector<int> strides({2, 2}); std::vector<int> strides({2, 2});
std::vector<int> paddings({0, 0}); std::vector<int> paddings({0, 0});
std::string pooling_t = "max"; std::string pooling_t = "max";
...@@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) { ...@@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) {
desc.SetAttr("strides", strides); desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings); desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling); desc.SetAttr("global_pooling", global_pooling);
desc.SetAttr("ceil_mode", ceil_mode);
LOG(INFO) << "set OP"; LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
...@@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) { ...@@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) {
validator.Execute(3); validator.Execute(3);
} }
TEST(Pool2dOpConverter, normal) { test_pool2d(false); } TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); } TEST(Pool2dOpConverter, test_ceil_mode) { test_pool2d(false, true); }
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput("TargetBBox"), ctx->HasOutput("TargetBBox"),
"Output(TargetBBox) of RpnTargetAssignOp should not be null"); "Output(TargetBBox) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BBoxInsideWeight"),
"Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null");
auto anchor_dims = ctx->GetInputDim("Anchor"); auto anchor_dims = ctx->GetInputDim("Anchor");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
...@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ScoreIndex", {-1}); ctx->SetOutputDim("ScoreIndex", {-1});
ctx->SetOutputDim("TargetLabel", {-1, 1}); ctx->SetOutputDim("TargetLabel", {-1, 1});
ctx->SetOutputDim("TargetBBox", {-1, 4}); ctx->SetOutputDim("TargetBBox", {-1, 4});
ctx->SetOutputDim("BBoxInsideWeight", {-1, 4});
} }
protected: protected:
...@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
const float rpn_positive_overlap, const float rpn_positive_overlap,
const float rpn_negative_overlap, std::vector<int>* fg_inds, const float rpn_negative_overlap, std::vector<int>* fg_inds,
std::vector<int>* bg_inds, std::vector<int>* tgt_lbl, std::vector<int>* bg_inds, std::vector<int>* tgt_lbl,
std::vector<int>* fg_fake, std::vector<T>* bbox_inside_weight,
std::minstd_rand engine, bool use_random) { std::minstd_rand engine, bool use_random) {
float epsilon = 0.00001; float epsilon = 0.00001;
int anchor_num = anchor_to_gt_max.dims()[0]; int anchor_num = anchor_to_gt_max.dims()[0];
...@@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
// Reservoir Sampling // Reservoir Sampling
int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im); int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
fg_num = static_cast<int>(fg_inds_fake.size()); int fg_fake_num = static_cast<int>(fg_inds_fake.size());
for (int64_t i = 0; i < fg_num; ++i) { for (int64_t i = 0; i < fg_fake_num; ++i) {
target_label[fg_inds_fake[i]] = 1; target_label[fg_inds_fake[i]] = 1;
} }
int bg_num = rpn_batch_size_per_im - fg_num; int bg_num = rpn_batch_size_per_im - fg_fake_num;
for (int64_t i = 0; i < anchor_num; ++i) { for (int64_t i = 0; i < anchor_num; ++i) {
if (anchor_to_gt_max_data[i] < rpn_negative_overlap) { if (anchor_to_gt_max_data[i] < rpn_negative_overlap) {
bg_inds_fake.push_back(i); bg_inds_fake.push_back(i);
...@@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
} }
ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
bg_num = static_cast<int>(bg_inds_fake.size()); bg_num = static_cast<int>(bg_inds_fake.size());
int fake_num = 0;
for (int64_t i = 0; i < bg_num; ++i) { for (int64_t i = 0; i < bg_num; ++i) {
// fg fake found
if (target_label[bg_inds_fake[i]] == 1) {
fake_num++;
fg_fake->emplace_back(fg_inds_fake[0]);
for (int j = 0; j < 4; ++j) {
bbox_inside_weight->emplace_back(T(0.));
}
}
target_label[bg_inds_fake[i]] = 0; target_label[bg_inds_fake[i]] = 0;
} }
for (int64_t i = 0; i < (fg_fake_num - fake_num) * 4; ++i) {
bbox_inside_weight->emplace_back(T(1.));
}
for (int64_t i = 0; i < anchor_num; ++i) { for (int64_t i = 0; i < anchor_num; ++i) {
if (target_label[i] == 1) fg_inds->emplace_back(i); if (target_label[i] == 1) {
fg_inds->emplace_back(i);
fg_fake->emplace_back(i);
}
if (target_label[i] == 0) bg_inds->emplace_back(i); if (target_label[i] == 0) bg_inds->emplace_back(i);
} }
fg_num = fg_inds->size(); fg_num = fg_inds->size();
...@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, ...@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
std::vector<int> bg_inds; std::vector<int> bg_inds;
std::vector<int> gt_inds; std::vector<int> gt_inds;
std::vector<int> tgt_lbl; std::vector<int> tgt_lbl;
std::vector<int> fg_fake;
std::vector<T> bbox_inside_weight;
// Calculate the max IoU between anchors and gt boxes // Calculate the max IoU between anchors and gt boxes
// Map from anchor to gt box that has highest overlap // Map from anchor to gt box that has highest overlap
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, ...@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
// Follow the Faster RCNN's implementation // Follow the Faster RCNN's implementation
ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max,
rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap, rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap,
rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, engine, rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, &fg_fake,
use_random); &bbox_inside_weight, engine, use_random);
int fg_num = fg_inds.size(); int fg_num = fg_inds.size();
int bg_num = bg_inds.size(); int bg_num = bg_inds.size();
gt_inds.reserve(fg_num); int fg_fake_num = fg_fake.size();
for (int i = 0; i < fg_num; ++i) { gt_inds.reserve(fg_fake_num);
gt_inds.emplace_back(argmax[fg_inds[i]]); for (int i = 0; i < fg_fake_num; ++i) {
gt_inds.emplace_back(argmax[fg_fake[i]]);
} }
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t;
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t; int* loc_index_data = loc_index_t.mutable_data<int>({fg_fake_num}, place);
int* loc_index_data = loc_index_t.mutable_data<int>({fg_num}, place);
int* score_index_data = int* score_index_data =
score_index_t.mutable_data<int>({fg_num + bg_num}, place); score_index_t.mutable_data<int>({fg_num + bg_num}, place);
int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place); int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place);
int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_num}, place); int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_fake_num}, place);
std::copy(fg_inds.begin(), fg_inds.end(), loc_index_data); T* bbox_inside_weight_data =
bbox_inside_weight_t.mutable_data<T>({fg_fake_num, 4}, place);
std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data);
std::copy(fg_inds.begin(), fg_inds.end(), score_index_data); std::copy(fg_inds.begin(), fg_inds.end(), score_index_data);
std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num); std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num);
std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data); std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data);
std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data); std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data);
std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(),
bbox_inside_weight_data);
std::vector<Tensor> loc_score_tgtlbl_gt; std::vector<Tensor> loc_score_tgtlbl_gt;
loc_score_tgtlbl_gt.emplace_back(loc_index_t); loc_score_tgtlbl_gt.emplace_back(loc_index_t);
loc_score_tgtlbl_gt.emplace_back(score_index_t); loc_score_tgtlbl_gt.emplace_back(score_index_t);
loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t); loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t);
loc_score_tgtlbl_gt.emplace_back(gt_inds_t); loc_score_tgtlbl_gt.emplace_back(gt_inds_t);
loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t);
return loc_score_tgtlbl_gt; return loc_score_tgtlbl_gt;
} }
...@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
auto* score_index = context.Output<LoDTensor>("ScoreIndex"); auto* score_index = context.Output<LoDTensor>("ScoreIndex");
auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox"); auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel"); auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
"RpnTargetAssignOp gt_boxes needs 1 level of LoD"); "RpnTargetAssignOp gt_boxes needs 1 level of LoD");
...@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
score_index->mutable_data<int>({max_num}, place); score_index->mutable_data<int>({max_num}, place);
tgt_bbox->mutable_data<T>({max_num, 4}, place); tgt_bbox->mutable_data<T>({max_num, 4}, place);
tgt_lbl->mutable_data<int>({max_num, 1}, place); tgt_lbl->mutable_data<int>({max_num, 1}, place);
bbox_inside_weight->mutable_data<T>({max_num, 4}, place);
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>(); auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
std::random_device rnd; std::random_device rnd;
...@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
Tensor sampled_score_index = loc_score_tgtlbl_gt[1]; Tensor sampled_score_index = loc_score_tgtlbl_gt[1];
Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2]; Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2];
Tensor sampled_gt_index = loc_score_tgtlbl_gt[3]; Tensor sampled_gt_index = loc_score_tgtlbl_gt[3];
Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4];
int loc_num = sampled_loc_index.dims()[0]; int loc_num = sampled_loc_index.dims()[0];
int score_num = sampled_score_index.dims()[0]; int score_num = sampled_score_index.dims()[0];
...@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap); AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap);
AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox); AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox);
AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl); AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl);
AppendRpns<T>(bbox_inside_weight, total_loc_num * 4,
&sampled_bbox_inside_weight);
total_loc_num += loc_num; total_loc_num += loc_num;
total_score_num += score_num; total_score_num += score_num;
...@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
score_index->set_lod(loc_score); score_index->set_lod(loc_score);
tgt_bbox->set_lod(lod_loc); tgt_bbox->set_lod(lod_loc);
tgt_lbl->set_lod(loc_score); tgt_lbl->set_lod(loc_score);
bbox_inside_weight->set_lod(lod_loc);
loc_index->Resize({total_loc_num}); loc_index->Resize({total_loc_num});
score_index->Resize({total_score_num}); score_index->Resize({total_score_num});
tgt_bbox->Resize({total_loc_num, 4}); tgt_bbox->Resize({total_loc_num, 4});
tgt_lbl->Resize({total_score_num, 1}); tgt_lbl->Resize({total_score_num, 1});
bbox_inside_weight->Resize({total_loc_num, 4});
} }
}; };
...@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"TargetLabel", "TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape " "(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number."); "[F + B, 1], F and B are sampled foreground and backgroud number.");
AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number.");
AddComment(R"DOC( AddComment(R"DOC(
This operator can be, for a given set of ground truth bboxes and the This operator can be, for a given set of ground truth bboxes and the
anchors, to assign classification and regression targets to each prediction. anchors, to assign classification and regression targets to each prediction.
......
...@@ -16,10 +16,9 @@ limitations under the License. */ ...@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
} }
#define INIT_VEC_FUNC \ #define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_state; \ auto* x = ctx.Input<LoDTensor>("X"); \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \ auto x_lod = x->lod(); \
if (platform::jit::MayIUse(platform::jit::avx)) { \ auto x_dims = x->dims(); /* T x M*/ \
math::VecActivations<T, platform::jit::avx> act_functor; \ auto wh_dims = wh->dims(); /* D x 3D*/ \
act_gate = act_functor(act_gate_str); \ const int total_T = x_dims[0]; \
act_state = act_functor(act_state_str); \ const int D3 = wh_dims[1]
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \ #define INIT_OTHER_DEFINES \
math::VecActivations<T, platform::jit::isa_any> act_functor; \ auto* h0 = ctx.Input<Tensor>("H0"); \
act_gate = act_functor(act_gate_str); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
act_state = act_functor(act_state_str); \ auto* bias = ctx.Input<Tensor>("Bias"); \
cross = math::vec_cross<T, platform::jit::isa_any>; \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
} bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
#define INIT_BASE_INPUT_OUTPUT \ const int D = wh_dims[0]; \
auto* h0 = ctx.Input<Tensor>("H0"); \ const int D2 = D * 2; \
auto* wx = ctx.Input<Tensor>("WeightX"); \ const auto& ker = math::jitkernel::KernelPool::Instance() \
auto* wh = ctx.Input<Tensor>("WeightH"); \ .template Get<math::jitkernel::GRUKernel<T>, \
auto* bias = ctx.Input<Tensor>("Bias"); \ const std::string&, const std::string&>( \
auto* xx = ctx.Output<LoDTensor>("XX"); \ ctx.Attr<std::string>("gate_activation"), \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ ctx.Attr<std::string>("activation"), D); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
#define INIT_BASE_SIZES \ const T* wh_data = wh->data<T>(); \
auto x_dims = x->dims(); /* T x M*/ \ auto place = ctx.GetPlace(); \
auto wh_dims = wh->dims(); /* D x 3D*/ \ T* xx_data = xx->mutable_data<T>(place)
const int total_T = x_dims[0]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \
const int D2 = D * 2;
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT INIT_OTHER_DEFINES;
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2; const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data<T>(place);
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data, xx_data,
...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
// W: {W_update, W_reset; W_state} ker->ComputeH1(xx_data, hidden_out_data);
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
act_gate(D2, xx_data, xx_data); ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2); ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT if (x_lod[0].size() == 2) {
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3}); xx->Resize({total_T, D3});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_VEC_FUNC INIT_OTHER_DEFINES;
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut"); auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T* batched_input_data = batched_input->mutable_data<T>(place);
const T* x_data = x->data<T>(); T* batched_out_data = batched_out->mutable_data<T>(place);
const T* wx_data = wx->data<T>(); hidden_out->mutable_data<T>(place);
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* prev_hidden_data = nullptr; T* prev_hidden_data = nullptr;
if (h0) { if (h0) {
// reorder h0 // reorder h0
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace()); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
prev_hidden_data = reordered_h0_data; prev_hidden_data = reordered_h0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
// update gate ker->ComputeH1(cur_in_data, cur_out_data);
act_gate(D, cur_in_data, cur_in_data);
// state gate
act_state(D, cur_in_data + D2, cur_in_data + D2);
// out = a*b
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data); ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
// rt = rt*ht_1 inplace result cur_out_data);
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...) ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
act_state(D, cur_batched_data + D2, cur_batched_data + D2); cur_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod); batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out); to_seq(dev_ctx, *batched_out, hidden_out);
} }
#undef INIT_VEC_FUNC #undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES #undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
}; };
} // namespace operators } // namespace operators
......
...@@ -75,6 +75,6 @@ endif() ...@@ -75,6 +75,6 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
DEPS cpu_info cblas) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { ...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr) const = 0; const T *wp_data = nullptr) const = 0;
}; };
template <typename T>
class GRUKernel : public Kernel {
public:
// compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0;
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
};
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel( ...@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return nullptr; return nullptr;
} }
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
...@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \ const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \ const std::string& act_cell, int d) \
: LSTMKernel<float>() { \ : LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \ } \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \ template <> \
} else if (type == "tanh") { \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \ float* gates, const float* ct_1, float* ct, float* ht, \
} else if (type == "identity" || type == "") { \ const float* wp_data, float* checked) const { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \ /* gates: W_ch, W_ih, W_fh, W_oh */ \
} \ __m256 c, i, f, o; \
PADDLE_THROW("Not support type: %s", type); \ c = _mm256_loadu_ps(gates); \
}; \ i = _mm256_loadu_ps(gates + 8); \
avx_act_gate_ = GetAVXAct(act_gate); \ f = _mm256_loadu_ps(gates + 16); \
avx_act_cand_ = GetAVXAct(act_cand); \ o = _mm256_loadu_ps(gates + 24); \
avx_act_cell_ = GetAVXAct(act_cell); \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
} \ c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
template <> \ i = _mm256_loadu_ps(ct_1); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
float* gates, const float* ct_1, float* ct, float* ht, \ f = _mm256_add_ps(c, f); \
const float* wp_data, float* checked) const { \ _mm256_storeu_ps(ct, f); \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ /* H_t = act_cell(C_t) * ogated */ \
__m256 c, i, f, o; \ o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
c = _mm256_loadu_ps(gates); \ _mm256_storeu_ps(ht, o); \
i = _mm256_loadu_ps(gates + 8); \ } \
f = _mm256_loadu_ps(gates + 16); \ template <> \
o = _mm256_loadu_ps(gates + 24); \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ float* gates, float* ct, float* ht, const float* wp_data) const { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ __m256 c, i, o; \
i = _mm256_loadu_ps(ct_1); \ c = _mm256_loadu_ps(gates); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ i = _mm256_loadu_ps(gates + 8); \
f = _mm256_add_ps(c, f); \ o = _mm256_loadu_ps(gates + 24); \
_mm256_storeu_ps(ct, f); \ /* C_t = igated * cgated*/ \
/* H_t = act_cell(C_t) * ogated */ \ c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ _mm256_storeu_ps(ct, c); \
_mm256_storeu_ps(ht, o); \ /* H_t = act_cell(C_t) * ogated */ \
} \ o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
template <> \ _mm256_storeu_ps(ht, o); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
...@@ -354,6 +357,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, ...@@ -354,6 +357,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM #undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL #undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
public:
explicit GRUKernelImpl(const std::string& act_gate,
const std::string& act_state, int d)
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
private:
int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
const std::string& act_gate, const std::string& act_state, int d)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, ...@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k, const T* src, int lds, int dim, int k,
int grid_dim, int num) { int grid_dim, int num) {
__shared__ Pair<T> sh_topk[BlockSize]; __shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int warp = threadIdx.x / 32; const int warp = threadIdx.x / 32;
const int bid = blockIdx.x; const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) { for (int i = bid; i < num; i += grid_dim) {
output += i * output_stride; int top_num = k;
indices += i * k; __shared__ int maxid[BlockSize / 2];
T* out = output + i * output_stride;
int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength]; Pair<T> topk[MaxLength];
int beam = MaxLength; int beam = MaxLength;
Pair<T> max; Pair<T> max;
bool is_empty = false; bool is_empty = false;
bool firststep = true; bool firststep = true;
for (int k = 0; k < MaxLength; k++) { for (int j = 0; j < MaxLength; j++) {
topk[k].set(-INFINITY, -1); topk[j].set(-INFINITY, -1);
} }
while (k) { while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>( ThreadGetTopK<T, MaxLength, BlockSize>(
topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
sh_topk[tid] = topk[0]; sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output, BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
&indices, &beam, &k, tid, warp); &beam, &top_num, tid, warp);
} }
} }
} }
...@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
size_t k = static_cast<int>(ctx.Attr<int>("k")); size_t k = static_cast<int>(ctx.Attr<int>("k"));
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T? // FIXME(typhoonzero): data is always converted to type T?
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
size_t input_height = input->dims()[0]; framework::DDim inputdims = input->dims();
size_t input_width = input->dims()[1]; const size_t input_height = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t input_width = inputdims[inputdims.size() - 1];
if (k > input_width) k = input_width; if (k > input_width) k = input_width;
// NOTE: pass lds and dim same to input width. // NOTE: pass lds and dim same to input width.
...@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
switch (GetDesiredBlockDim(input_width)) { switch (GetDesiredBlockDim(input_width)) {
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5, KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>( kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, output->dims()[1], indices_data, input_data, output_data, k, indices_data, input_data, input_width,
input_width, input_width, static_cast<int>(k), gridx, input_width, static_cast<int>(k), gridx, input_height));
input_height));
default: default:
PADDLE_THROW("Error"); PADDLE_THROW("Error");
} }
......
...@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor // Get the top k elements of each row of input tensor
// FIXME: only deal with matrix(2d tensor).
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
...@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
auto eg_input = EigenMatrix<T>::From(*input);
// reshape input to a flattern matrix(like flat_inner_dims) // reshape input to a flattern matrix(like flat_inner_dims)
framework::DDim inputdims = input->dims(); framework::DDim inputdims = input->dims();
const size_t row = framework::product( const size_t row = framework::product(
...@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
const size_t col = inputdims[inputdims.size() - 1]; const size_t col = inputdims[inputdims.size() - 1];
Eigen::DSizes<int, 2> flat2dims(row, col); Eigen::DSizes<int, 2> flat2dims(row, col);
// NOTE: eigen shape doesn't affect paddle tensor. // NOTE: eigen shape doesn't affect paddle tensor.
eg_input.reshape(flat2dims); auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
......
...@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred, ...@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred,
Returns: Returns:
tuple: tuple:
A tuple(predicted_scores, predicted_location, target_label, A tuple(predicted_scores, predicted_location, target_label,
target_bbox) is returned. The predicted_scores and target_bbox, bbox_inside_weight) is returned. The predicted_scores
predicted_location is the predicted result of the RPN. and predicted_location is the predicted result of the RPN.
The target_label and target_bbox is the ground truth, The target_label and target_bbox is the ground truth,
respectively. The predicted_location is a 2D Tensor with shape respectively. The predicted_location is a 2D Tensor with shape
[F, 4], and the shape of target_bbox is same as the shape of [F, 4], and the shape of target_bbox is same as the shape of
...@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred, ...@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred,
[F + B, 1], and the shape of target_label is same as the shape [F + B, 1], and the shape of target_label is same as the shape
of the predicted_scores, B is the number of the background of the predicted_scores, B is the number of the background
anchors, the F and B is depends on the input of this operator. anchors, the F and B is depends on the input of this operator.
Bbox_inside_weight represents whether the predicted loc is fake_fg
or not and the shape is [F, 4].
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred, ...@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred,
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
gt_boxes = layers.data(name='gt_boxes', shape=[10, 4], gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target = loc_pred, score_pred, loc_target, score_target, bbox_inside_weight =
fluid.layers.rpn_target_assign(bbox_pred=bbox_pred, fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
cls_logits=cls_logits, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
...@@ -152,6 +154,8 @@ def rpn_target_assign(bbox_pred, ...@@ -152,6 +154,8 @@ def rpn_target_assign(bbox_pred,
target_label = helper.create_variable_for_type_inference(dtype='int32') target_label = helper.create_variable_for_type_inference(dtype='int32')
target_bbox = helper.create_variable_for_type_inference( target_bbox = helper.create_variable_for_type_inference(
dtype=anchor_box.dtype) dtype=anchor_box.dtype)
bbox_inside_weight = helper.create_variable_for_type_inference(
dtype=anchor_box.dtype)
helper.append_op( helper.append_op(
type="rpn_target_assign", type="rpn_target_assign",
inputs={ inputs={
...@@ -164,7 +168,8 @@ def rpn_target_assign(bbox_pred, ...@@ -164,7 +168,8 @@ def rpn_target_assign(bbox_pred,
'LocationIndex': loc_index, 'LocationIndex': loc_index,
'ScoreIndex': score_index, 'ScoreIndex': score_index,
'TargetLabel': target_label, 'TargetLabel': target_label,
'TargetBBox': target_bbox 'TargetBBox': target_bbox,
'BBoxInsideWeight': bbox_inside_weight
}, },
attrs={ attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im, 'rpn_batch_size_per_im': rpn_batch_size_per_im,
...@@ -179,13 +184,14 @@ def rpn_target_assign(bbox_pred, ...@@ -179,13 +184,14 @@ def rpn_target_assign(bbox_pred,
score_index.stop_gradient = True score_index.stop_gradient = True
target_label.stop_gradient = True target_label.stop_gradient = True
target_bbox.stop_gradient = True target_bbox.stop_gradient = True
bbox_inside_weight.stop_gradient = True
cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1)) cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4)) bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_cls_logits = nn.gather(cls_logits, score_index) predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_bbox_pred = nn.gather(bbox_pred, loc_index) predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
def detection_output(loc, def detection_output(loc,
......
...@@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase):
dtype='float32', dtype='float32',
lod_level=1, lod_level=1,
append_batch_size=False) append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign(
bbox_pred=bbox_pred, bbox_pred=bbox_pred,
cls_logits=cls_logits, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
...@@ -313,15 +313,18 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -313,15 +313,18 @@ class TestRpnTargetAssign(unittest.TestCase):
rpn_straddle_thresh=0.0, rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5, rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7, rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3) rpn_negative_overlap=0.3,
use_random=False)
self.assertIsNotNone(pred_scores) self.assertIsNotNone(pred_scores)
self.assertIsNotNone(pred_loc) self.assertIsNotNone(pred_loc)
self.assertIsNotNone(tgt_lbl) self.assertIsNotNone(tgt_lbl)
self.assertIsNotNone(tgt_bbox) self.assertIsNotNone(tgt_bbox)
self.assertIsNotNone(bbox_inside_weight)
assert pred_scores.shape[1] == 1 assert pred_scores.shape[1] == 1
assert pred_loc.shape[1] == 4 assert pred_loc.shape[1] == 4
assert pred_loc.shape[1] == tgt_bbox.shape[1] assert pred_loc.shape[1] == tgt_bbox.shape[1]
print(str(program))
class TestGenerateProposals(unittest.TestCase): class TestGenerateProposals(unittest.TestCase):
......
...@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase): ...@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False self._use_reduce = False
def test_dist_train(self): # FIXME(typhoonzero): fix async mode test later
def no_test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
......
...@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reader_alloc = False self._use_reader_alloc = False
def test_dist_train(self): #FIXME(typhoonzero): fix async mode later
def no_test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
...@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): ...@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_simnet_bow(self): #FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '0', "IS_SPARSE": '0',
...@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_simnet_bow(self): #FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
......
...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): ...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self.D = 8 self.D = 8
class TestFusionGRUOpMD3(TestFusionGRUOp):
def set_confs(self):
self.M = 17
self.D = 15
class TestFusionGRUOpBS1(TestFusionGRUOp): class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self): def set_confs(self):
self.lod = [[3]] self.lod = [[3]]
......
...@@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap, ...@@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap,
fg_inds, size=(len(fg_inds) - num_fg), replace=False) fg_inds, size=(len(fg_inds) - num_fg), replace=False)
else: else:
disable_inds = fg_inds[num_fg:] disable_inds = fg_inds[num_fg:]
labels[disable_inds] = -1 labels[disable_inds] = -1
fg_inds = np.where(labels == 1)[0] fg_inds = np.where(labels == 1)[0]
bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32)
num_bg = rpn_batch_size_per_im - np.sum(labels == 1) num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
...@@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap, ...@@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap,
enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
else: else:
enable_inds = bg_inds[:num_bg] enable_inds = bg_inds[:num_bg]
fg_fake_inds = np.array([], np.int32)
fg_value = np.array([fg_inds[0]], np.int32)
fake_num = 0
for bg_id in enable_inds:
if bg_id in fg_inds:
fake_num += 1
fg_fake_inds = np.hstack([fg_fake_inds, fg_value])
labels[enable_inds] = 0 labels[enable_inds] = 0
bbox_inside_weight[fake_num:, :] = 1
fg_inds = np.where(labels == 1)[0] fg_inds = np.where(labels == 1)[0]
bg_inds = np.where(labels == 0)[0] bg_inds = np.where(labels == 0)[0]
loc_index = np.hstack([fg_fake_inds, fg_inds])
loc_index = fg_inds score_index = np.hstack([fg_inds, bg_inds])
score_index = np.hstack((fg_inds, bg_inds))
labels = labels[score_index] labels = labels[score_index]
assert not np.any(labels == -1), "Wrong labels with -1" assert not np.any(labels == -1), "Wrong labels with -1"
gt_inds = anchor_to_gt_argmax[fg_inds] gt_inds = anchor_to_gt_argmax[loc_index]
return loc_index, score_index, labels, gt_inds return loc_index, score_index, labels, gt_inds, bbox_inside_weight
def get_anchor(n, c, h, w): def get_anchor(n, c, h, w):
...@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors,
gt_boxes_slice = gt_boxes_slice[not_crowd_inds] gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
iou = _bbox_overlaps(inside_anchors, gt_boxes_slice) iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)
loc_inds, score_inds, labels, gt_inds = rpn_target_assign( loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = \
iou, rpn_batch_size_per_im, rpn_positive_overlap, rpn_target_assign(iou, rpn_batch_size_per_im,
rpn_negative_overlap, rpn_fg_fraction, use_random) rpn_positive_overlap,
rpn_negative_overlap,
rpn_fg_fraction,
use_random)
# unmap to all anchor # unmap to all anchor
loc_inds = inds_inside[loc_inds] loc_inds = inds_inside[loc_inds]
score_inds = inds_inside[score_inds] score_inds = inds_inside[score_inds]
...@@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors,
score_indexes = score_inds score_indexes = score_inds
tgt_labels = labels tgt_labels = labels
tgt_bboxes = box_deltas tgt_bboxes = box_deltas
bbox_inside_weights = bbox_inside_weight
else: else:
loc_indexes = np.concatenate( loc_indexes = np.concatenate(
[loc_indexes, loc_inds + i * anchor_num]) [loc_indexes, loc_inds + i * anchor_num])
...@@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors,
[score_indexes, score_inds + i * anchor_num]) [score_indexes, score_inds + i * anchor_num])
tgt_labels = np.concatenate([tgt_labels, labels]) tgt_labels = np.concatenate([tgt_labels, labels])
tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
bbox_inside_weights = np.vstack([bbox_inside_weights, \
bbox_inside_weight])
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
class TestRpnTargetAssignOp(OpTest): class TestRpnTargetAssignOp(OpTest):
...@@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest): ...@@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest):
rpn_fg_fraction = 0.5 rpn_fg_fraction = 0.5
use_random = False use_random = False
loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python( loc_index, score_index, tgt_bbox, labels, bbox_inside_weights = \
all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh, rpn_target_assign_in_python(all_anchors, gt_boxes, is_crowd,
rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap, im_info, lod, rpn_straddle_thresh,
rpn_fg_fraction, use_random) rpn_batch_size_per_im, rpn_positive_overlap,
rpn_negative_overlap,
rpn_fg_fraction, use_random)
labels = labels[:, np.newaxis] labels = labels[:, np.newaxis]
self.op_type = "rpn_target_assign" self.op_type = "rpn_target_assign"
...@@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest): ...@@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest):
'LocationIndex': loc_index.astype('int32'), 'LocationIndex': loc_index.astype('int32'),
'ScoreIndex': score_index.astype('int32'), 'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'), 'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': labels.astype('int32') 'TargetLabel': labels.astype('int32'),
'BBoxInsideWeight': bbox_inside_weights.astype('float32')
} }
def test_check_output(self): def test_check_output(self):
......
...@@ -21,22 +21,27 @@ from op_test import OpTest ...@@ -21,22 +21,27 @@ from op_test import OpTest
class TestTopkOp(OpTest): class TestTopkOp(OpTest):
def setUp(self): def setUp(self):
self.set_args()
self.op_type = "top_k" self.op_type = "top_k"
k = 1 k = self.top_k
input = np.random.random((32, 84)).astype("float32") input = np.random.random((self.row, k)).astype("float32")
output = np.ndarray((32, k)) output = np.ndarray((self.row, k))
indices = np.ndarray((32, k)).astype("int64") indices = np.ndarray((self.row, k)).astype("int64")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = {'k': k} self.attrs = {'k': k}
for rowid in range(32): for rowid in range(self.row):
row = input[rowid] row = input[rowid]
output[rowid] = np.sort(row)[-k:] output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[-k:] indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
def set_args(self):
self.row = 32
self.top_k = 1
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest): ...@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest):
output = np.ndarray((64, k)) output = np.ndarray((64, k))
indices = np.ndarray((64, k)).astype("int64") indices = np.ndarray((64, k)).astype("int64")
# FIXME: should use 'X': input for a 3d input self.inputs = {'X': input}
self.inputs = {'X': input_flat_2d}
self.attrs = {'k': k} self.attrs = {'k': k}
for rowid in range(64): for rowid in range(64):
row = input_flat_2d[rowid] row = input_flat_2d[rowid]
output[rowid] = np.sort(row)[-k:] output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[-k:] indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {
'Out': output.reshape((32, 2, k)),
'Indices': indices.reshape((32, 2, k))
}
def test_check_output(self):
self.check_output()
class TestTopkOp2(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 1
m = 2056
input = np.random.random((m, 84)).astype("float32")
output = np.ndarray((m, k))
indices = np.ndarray((m, k)).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(m):
row = input[rowid]
output[rowid] = -np.sort(-row)[:k]
indices[rowid] = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
...@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest): ...@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest):
self.check_output() self.check_output()
class TestTopkOp3(TestTopkOp):
def set_args(self):
self.row = 2056
self.top_k = 3
class TestTopkOp4(TestTopkOp):
def set_args(self):
self.row = 40000
self.top_k = 1
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册