提交 5acae32b 编写于 作者: Z zhaojiaying01

resolve conflicts and adjust gemm code style

...@@ -106,6 +106,11 @@ const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand"; ...@@ -106,6 +106,11 @@ const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand";
const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool"; const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool";
const char *G_OP_TYPE_SEQUENCE_SOFTMAX = "sequence_softmax"; const char *G_OP_TYPE_SEQUENCE_SOFTMAX = "sequence_softmax";
const char *G_OP_TYPE_SLICE = "slice";
const char *G_OP_TYPE_ANCHOR_GENERATOR = "anchor_generator";
const char *G_OP_TYPE_GENERATE_PROPOSALS = "generate_proposals";
const char *G_OP_TYPE_PSROI_POOL = "psroi_pool";
std::unordered_map< std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>> std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = { op_input_output_key = {
...@@ -197,5 +202,11 @@ std::unordered_map< ...@@ -197,5 +202,11 @@ std::unordered_map<
{G_OP_TYPE_WRITE_TO_ARRAY, {{"X", "I"}, {"Out"}}}, {G_OP_TYPE_WRITE_TO_ARRAY, {{"X", "I"}, {"Out"}}},
{G_OP_TYPE_READ_FROM_ARRAY, {{"X", "I"}, {"Out"}}}, {G_OP_TYPE_READ_FROM_ARRAY, {{"X", "I"}, {"Out"}}},
{G_OP_TYPE_IS_EMPTY, {{"X"}, {"Out"}}}, {G_OP_TYPE_IS_EMPTY, {{"X"}, {"Out"}}},
{G_OP_TYPE_INCREMENT, {{"X"}, {"Out"}}}}; {G_OP_TYPE_INCREMENT, {{"X"}, {"Out"}}},
{G_OP_TYPE_SLICE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_ANCHOR_GENERATOR, {{"Input"}, {"Anchors", "Variances"}}},
{G_OP_TYPE_GENERATE_PROPOSALS,
{{"Scores", "BboxDeltas", "ImInfo", "Anchors", "Variances"},
{"RpnRois", "RpnRoiProbs"}}},
{G_OP_TYPE_PSROI_POOL, {{"X", "ROIs"}, {"Out"}}}};
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -194,6 +194,11 @@ extern const char *G_OP_TYPE_SEQUENCE_EXPAND; ...@@ -194,6 +194,11 @@ extern const char *G_OP_TYPE_SEQUENCE_EXPAND;
extern const char *G_OP_TYPE_SEQUENCE_POOL; extern const char *G_OP_TYPE_SEQUENCE_POOL;
extern const char *G_OP_TYPE_SEQUENCE_SOFTMAX; extern const char *G_OP_TYPE_SEQUENCE_SOFTMAX;
extern const char *G_OP_TYPE_SLICE;
extern const char *G_OP_TYPE_ANCHOR_GENERATOR;
extern const char *G_OP_TYPE_GENERATE_PROPOSALS;
extern const char *G_OP_TYPE_PSROI_POOL;
extern std::unordered_map< extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>> std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key; op_input_output_key;
......
...@@ -312,3 +312,12 @@ LOAD_OP1(is_empty, CPU); ...@@ -312,3 +312,12 @@ LOAD_OP1(is_empty, CPU);
#ifdef INCREMENT_OP #ifdef INCREMENT_OP
LOAD_OP1(increment, CPU); LOAD_OP1(increment, CPU);
#endif #endif
#ifdef ANCHOR_GENERATOR_OP
LOAD_OP1(anchor_generator, CPU);
#endif
#ifdef PROPOSAL_OP
LOAD_OP1(generate_proposals, CPU);
#endif
#ifdef PSROI_POOL_OP
LOAD_OP1(psroi_pool, CPU);
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/detection_ops.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
#ifdef ANCHOR_GENERATOR_OP
// Sizes both outputs (anchors and variances) from the input tensor's dims.
//
// The input is required to be rank 4 (the enforce message states NCHW). Both
// outputs receive the same 4-D shape:
//   [input_dims[2], input_dims[3], num_anchors, 4]
// where num_anchors = aspect_ratios.size() * anchor_sizes.size().
template <typename DeviceType, typename T>
void AnchorGeneratorOp<DeviceType, T>::InferShape() const {
  const auto &input_dims = this->param_.input_->dims();
  PADDLE_MOBILE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");

  // One anchor per (aspect ratio, anchor size) pair.
  const size_t anchors_per_position = this->param_.aspect_ratios_.size() *
                                      this->param_.anchor_sizes_.size();

  std::vector<int64_t> out_shape;
  out_shape.reserve(4);
  out_shape.push_back(input_dims[2]);
  out_shape.push_back(input_dims[3]);
  out_shape.push_back(anchors_per_position);
  out_shape.push_back(4);

  // Anchors and variances share one shape, so build the DDim a single time.
  const framework::DDim out_dims = framework::make_ddim(out_shape);
  this->param_.output_anchors_->Resize(out_dims);
  this->param_.output_variances_->Resize(out_dims);
}
#endif
#ifdef PROPOSAL_OP
// Assigns the static shapes of the two proposal outputs:
//   rpn_rois_  -> [-1, 4]
//   rpn_probs_ -> [-1, 1]
// NOTE(review): the leading -1 presumably marks a row count that is only
// known once proposals have actually been computed -- confirm with the kernel.
template <typename DeviceType, typename T>
void ProposalOp<DeviceType, T>::InferShape() const {
  const framework::DDim rois_shape = framework::make_ddim({-1, 4});
  const framework::DDim probs_shape = framework::make_ddim({-1, 1});

  this->param_.rpn_rois_->Resize(rois_shape);
  this->param_.rpn_probs_->Resize(probs_shape);
}
#endif
#ifdef PSROI_POOL_OP
// Derives the output shape of the PS-ROI pooling operator.
//
// The output starts as a copy of the input X dims and is then overwritten to
//   [rois_dims[0], output_channels, pooled_height, pooled_width]
// using the ROI tensor's first dimension and the operator attributes.
//
// Because indices 0..3 of the copied dims are written, the input must be a
// rank-4 tensor; this is enforced up front, mirroring the check performed by
// AnchorGeneratorOp::InferShape.
template <typename DeviceType, typename T>
void PSRoiPoolOp<DeviceType, T>::InferShape() const {
  const auto &input_dims = this->param_.input_x_->dims();
  // Guard the fixed-index writes below against inputs of the wrong rank.
  PADDLE_MOBILE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");

  const auto &rois_dims = this->param_.input_rois_->dims();
  const int pooled_height = this->param_.pooled_height_;
  const int pooled_width = this->param_.pooled_width_;
  const int output_channels = this->param_.output_channels_;

  auto out_dims = input_dims;  // copy; only the four entries below change
  out_dims[0] = rois_dims[0];
  // Taken directly from the "output_channels" attribute.
  // NOTE(review): presumably equal to
  // input_dims[1] / (pooled_height * pooled_width); not verified here.
  out_dims[1] = output_channels;
  out_dims[2] = pooled_height;
  out_dims[3] = pooled_width;
  this->param_.output_->Resize(out_dims);
}
#endif
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
#ifdef ANCHOR_GENERATOR_OP
REGISTER_OPERATOR_CPU(anchor_generator, ops::AnchorGeneratorOp);
#endif
#ifdef PROPOSAL_OP
REGISTER_OPERATOR_CPU(generate_proposals, ops::ProposalOp);
#endif
#ifdef PSROI_POOL_OP
REGISTER_OPERATOR_CPU(psroi_pool, ops::PSRoiPoolOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/detection_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#ifdef ANCHOR_GENERATOR_OP
DECLARE_OPERATOR(AnchorGenerator, AnchorGeneratorParam, AnchorGeneratorKernel);
#endif
#ifdef PROPOSAL_OP
DECLARE_OPERATOR(Proposal, ProposalParam, ProposalKernel);
#endif
#ifdef PSROI_POOL_OP
DECLARE_OPERATOR(PSRoiPool, PSRoiPoolParam, PSRoiPoolKernel);
#endif
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ANCHOR_GENERATOR_OP
#include <vector>
#include "operators/kernel/detection_kernel.h"
namespace paddle_mobile {
namespace operators {
// CPU/float specialization of AnchorGeneratorKernel::Init.
// Performs no setup: `param` is not inspected and the function always
// returns true.
template <>
bool AnchorGeneratorKernel<CPU, float>::Init(AnchorGeneratorParam<CPU> *param) {
return true;
}
// CPU/float specialization of AnchorGeneratorKernel::Compute.
// Currently an empty placeholder: `param` is ignored and no output tensor is
// written. The actual anchor generation is still to be implemented.
template <>
void AnchorGeneratorKernel<CPU, float>::Compute(
const AnchorGeneratorParam<CPU> &param) {
// TODO(hjchen2)
}
} // namespace operators
} // namespace paddle_mobile
#endif // ANCHOR_GENERATOR_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PROPOSAL_OP
#include <vector>
#include "operators/kernel/detection_kernel.h"
namespace paddle_mobile {
namespace operators {
// CPU/float specialization of ProposalKernel::Init.
// Performs no setup: `param` is not inspected and the function always
// returns true.
template <>
bool ProposalKernel<CPU, float>::Init(ProposalParam<CPU> *param) {
return true;
}
// CPU/float specialization of ProposalKernel::Compute.
// Currently an empty placeholder: `param` is ignored and no output tensor is
// written. The proposal generation logic is still to be implemented.
template <>
void ProposalKernel<CPU, float>::Compute(const ProposalParam<CPU> &param) {
// TODO(hjchen2)
}
} // namespace operators
} // namespace paddle_mobile
#endif // PROPOSAL_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PSROI_POOL_OP
#include <vector>
#include "operators/kernel/detection_kernel.h"
namespace paddle_mobile {
namespace operators {
// CPU/float specialization of PSRoiPoolKernel::Init.
// Performs no setup: `param` is not inspected and the function always
// returns true.
template <>
bool PSRoiPoolKernel<CPU, float>::Init(PSRoiPoolParam<CPU> *param) {
return true;
}
// CPU/float specialization of PSRoiPoolKernel::Compute.
// Currently an empty placeholder: `param` is ignored and no output tensor is
// written. The pooling logic is still to be implemented.
template <>
void PSRoiPoolKernel<CPU, float>::Compute(const PSRoiPoolParam<CPU> &param) {
// TODO(hjchen2)
}
} // namespace operators
} // namespace paddle_mobile
#endif // PSROI_POOL_OP
...@@ -32,11 +32,11 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) { ...@@ -32,11 +32,11 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) {
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
Tensor bias1 = *param.Bias1(); Tensor bias1 = *param.Bias1();
int axis = param.Axis();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<float>();
float *biase_data = bias.data<float>(); float *biase_data = bias.data<float>();
int axis = param.Axis();
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
......
...@@ -30,10 +30,11 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) { ...@@ -30,10 +30,11 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<float>();
float *biase_data = bias.data<float>(); float *biase_data = bias.data<float>();
int axis = param.Axis();
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
......
...@@ -32,6 +32,8 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) { ...@@ -32,6 +32,8 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale(); Tensor new_scale = *param.NewScale();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
...@@ -115,8 +117,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) { ...@@ -115,8 +117,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
template <typename P> template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) { void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
...@@ -31,10 +31,11 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) { ...@@ -31,10 +31,11 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<float>();
float *biase_data = bias.data<float>(); float *biase_data = bias.data<float>();
int axis = param.Axis();
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
......
...@@ -30,11 +30,11 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) { ...@@ -30,11 +30,11 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale(); Tensor new_scale = *param.NewScale();
Tensor *output = param.Output();
Tensor *bias1 = param.Bias(); Tensor *bias1 = param.Bias();
Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups(); int groups = param.Groups();
DLOG << "yangfei2";
DLOG << bias1->dims();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations(); std::vector<int> dilations = param.Dilations();
......
...@@ -31,6 +31,7 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) { ...@@ -31,6 +31,7 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<P>();
auto strides = param.Strides(); auto strides = param.Strides();
auto paddings = param.Paddings(); auto paddings = param.Paddings();
...@@ -76,8 +77,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) { ...@@ -76,8 +77,6 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
framework::DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]}; framework::DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
output->mutable_data<P>();
int in_step = static_cast<int>(input->dims()[1]) / groups; int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups; int out_step = static_cast<int>(output->dims()[1]) / groups;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#ifdef ANCHOR_GENERATOR_OP
template <typename Dtype>
class AnchorGeneratorParam : public OpParam {
public:
AnchorGeneratorParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<framework::Tensor>("Input", inputs, scope);
output_anchors_ =
OpParam::GetVarValue<framework::Tensor>("Anchors", outputs, scope);
output_variances_ =
OpParam::GetVarValue<framework::Tensor>("Variances", outputs, scope);
anchor_sizes_ = OpParam::GetAttr<std::vector<float>>("anchor_sizes", attrs);
aspect_ratios_ =
OpParam::GetAttr<std::vector<float>>("aspect_ratios", attrs);
variances_ = OpParam::GetAttr<std::vector<float>>("variances", attrs);
stride_ = OpParam::GetAttr<std::vector<float>>("stride", attrs);
offset_ = OpParam::GetAttr<float>("offset", attrs);
}
public:
// input
framework::Tensor *input_;
// outputs
framework::Tensor *output_anchors_;
framework::Tensor *output_variances_;
std::vector<float> anchor_sizes_;
std::vector<float> aspect_ratios_;
std::vector<float> variances_;
std::vector<float> stride_;
float offset_;
};
DECLARE_KERNEL(AnchorGenerator, AnchorGeneratorParam);
#endif
#ifdef PROPOSAL_OP
template <typename Dtype>
class ProposalParam : public OpParam {
public:
ProposalParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
scores_ = OpParam::GetVarValue<framework::Tensor>("Scores", inputs, scope);
bbox_deltas_ =
OpParam::GetVarValue<framework::Tensor>("BboxDeltas", inputs, scope);
im_info_ = OpParam::GetVarValue<framework::Tensor>("ImInfo", inputs, scope);
anchors_ =
OpParam::GetVarValue<framework::Tensor>("Anchors", inputs, scope);
variances_ =
OpParam::GetVarValue<framework::Tensor>("Variances", inputs, scope);
rpn_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("RpnRois", outputs, scope);
rpn_probs_ = OpParam::GetVarValue<framework::LoDTensor>("RpnRoiProbs",
outputs, scope);
pre_nms_topn_ = OpParam::GetAttr<int>("pre_nms_topN", attrs);
post_nms_topn_ = OpParam::GetAttr<int>("post_nms_topN", attrs);
nms_thresh_ = OpParam::GetAttr<float>("nms_thresh", attrs);
min_size_ = OpParam::GetAttr<float>("min_size", attrs);
eta_ = OpParam::GetAttr<float>("eta", attrs);
}
public:
framework::Tensor *scores_;
framework::Tensor *bbox_deltas_;
framework::Tensor *im_info_;
framework::Tensor *anchors_;
framework::Tensor *variances_;
framework::LoDTensor *rpn_rois_;
framework::LoDTensor *rpn_probs_;
int pre_nms_topn_;
int post_nms_topn_;
float nms_thresh_;
float min_size_;
float eta_;
};
DECLARE_KERNEL(Proposal, ProposalParam);
#endif
#ifdef PSROI_POOL_OP
template <typename Dtype>
class PSRoiPoolParam : public OpParam {
public:
PSRoiPoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = OpParam::GetVarValue<framework::Tensor>("X", inputs, scope);
input_rois_ =
OpParam::GetVarValue<framework::LoDTensor>("ROIs", inputs, scope);
output_ = OpParam::GetVarValue<framework::Tensor>("Out", outputs, scope);
output_channels_ = OpParam::GetAttr<int>("output_channels", attrs);
pooled_height_ = OpParam::GetAttr<int>("pooled_height", attrs);
pooled_width_ = OpParam::GetAttr<int>("pooled_width", attrs);
spatial_scale_ = OpParam::GetAttr<float>("spatial_scale", attrs);
}
public:
framework::Tensor *input_x_;
framework::LoDTensor *input_rois_;
framework::Tensor *output_;
int output_channels_;
int pooled_height_;
int pooled_width_;
float spatial_scale_;
};
DECLARE_KERNEL(PSRoiPool, PSRoiPoolParam);
#endif
} // namespace operators
} // namespace paddle_mobile
因为 它太大了无法显示 source diff 。你可以改为 查看blob
...@@ -46,15 +46,6 @@ namespace math { ...@@ -46,15 +46,6 @@ namespace math {
class Gemm { class Gemm {
public: public:
/*
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
*/
typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *); typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *);
typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *, typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
int); int);
...@@ -62,31 +53,31 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -62,31 +53,31 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
FnPack procPackB; FnPack procPackB;
FnAddDot procAddDot; FnAddDot procAddDot;
// 将 A 矩阵分块复制到连续内存(RowMajor) // 将 A\B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
#endif
// 分块矩阵乘法 // 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
...@@ -106,22 +97,16 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -106,22 +97,16 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *c, float *C, int ldc, float *p, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1); std::string mode, float *bias, float *bias1);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
/*
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float
*C, int ldc, bool relu, float *new_scale, float *new_bias);
*/
// 计算一个更小的 C 矩阵分块 // 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); #if __aarch64__
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc); void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc); void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc); void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
#endif
// 分块矩阵乘法结果回写 // 分块矩阵乘法结果回写
// C = A * B // C = A * B
...@@ -149,6 +134,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -149,6 +134,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1); float *new_scale, float *new_bias, float *bias1);
// 向量矩阵乘法 (M = 1)
#if __aarch64__
#else
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias);
// 向量矩阵乘法结果回写 // 向量矩阵乘法结果回写
// C = A * B // C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc); void VecWriteBasic(int n, float *c, float *C, int ldc);
...@@ -158,14 +155,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -158,14 +155,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void VecWriteWithAdd(int n, float *c, float *C, int ldc); void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C) // C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc); void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
/* // C = A * B, batchnorm(C)
// C = A * B, batchnorm(C) void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, float *new_bias);
float *new_bias); // C = A * B, batchnorm(C), relu(C)
// C = A * B, batchnorm(C), relu(C) void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float float *new_bias);
*new_scale, float *new_bias); #endif
*/
// 32位 float 矩阵乘法 // 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
......
...@@ -1521,33 +1521,20 @@ class SliceParam : public OpParam { ...@@ -1521,33 +1521,20 @@ class SliceParam : public OpParam {
public: public:
SliceParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SliceParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope); input_ = InputFrom<GType>(inputs, scope);
input_shape_ = InputShapeFrom<GType>(inputs, scope); output_ = OutFrom<GType>(outputs, scope);
out_ = OutFrom<GType>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs);
slice_points_ = GetAttr<vector<int>>("slice_points", attrs);
inplace_ = GetAttr<bool>("inplace", attrs);
}
const RType *InputX() const { return input_x_; }
const RType *InputShape() const { return input_shape_; }
RType *Out() const { return out_; }
const int &Axis() const { return axis_; }
const vector<int> &SlicePoints() const { return slice_points_; }
const bool &Inplace() const { return inplace_; } axes_ = GetAttr<std::vector<int>>("axes", attrs);
starts_ = GetAttr<std::vector<int>>("starts", attrs);
ends_ = GetAttr<std::vector<int>>("ends", attrs);
}
private: public:
RType *input_x_; GType *input_;
RType *input_shape_; GType *output_;
RType *out_; std::vector<int> axes_;
int axis_; std::vector<int> starts_;
vector<int> slice_points_; std::vector<int> ends_;
bool inplace_;
}; };
#endif #endif
......
...@@ -290,6 +290,9 @@ if(NOT FOUND_MATCH) ...@@ -290,6 +290,9 @@ if(NOT FOUND_MATCH)
set(READ_FROM_ARRAY_OP ON) set(READ_FROM_ARRAY_OP ON)
set(IS_EMPTY_OP ON) set(IS_EMPTY_OP ON)
set(INCREMENT_OP ON) set(INCREMENT_OP ON)
set(ANCHOR_GENERATOR_OP ON)
set(PROPOSAL_OP ON)
set(PSROI_POOL_OP ON)
endif() endif()
# option(BATCHNORM_OP "" ON) # option(BATCHNORM_OP "" ON)
...@@ -579,4 +582,14 @@ if (IS_EMPTY_OP) ...@@ -579,4 +582,14 @@ if (IS_EMPTY_OP)
endif() endif()
if (INCREMENT_OP) if (INCREMENT_OP)
add_definitions(-DINCREMENT_OP) add_definitions(-DINCREMENT_OP)
endif() endif()
\ No newline at end of file
if (ANCHOR_GENERATOR_OP)
add_definitions(-DANCHOR_GENERATOR_OP)
endif()
if (PROPOSAL_OP)
add_definitions(-DPROPOSAL_OP)
endif()
if (PSROI_POOL_OP)
add_definitions(-DPSROI_POOL_OP)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册