Commit 7421f560 authored by yangfei

Implement the fusion_conv_add_prelu and fusion_conv_add_add_prelu operators

Parent a1a7b05b
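For context, both fused operators collapse a small subgraph into one kernel: a convolution, one elementwise add (the conv bias) or two (bias plus a second, residual-style input), and a PReLU activation. A scalar sketch of the fused math, illustrative only and not part of this diff:

// Illustrative scalar form of the two fusions (alpha is the PReLU slope):
//   fusion_conv_add_prelu:     out = prelu(conv(x, W) + bias, alpha)
//   fusion_conv_add_add_prelu: out = prelu(conv(x, W) + bias + bias1, alpha)
inline float prelu(float v, float alpha) { return v > 0.f ? v : alpha * v; }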
...@@ -18,33 +18,33 @@ limitations under the License. */
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {

template <typename Dtype, typename T>
void FusionConvAddAddPReluOp<Dtype, T>::InferShape() const {
  auto in_dims = this->param_.Input()->dims();
  auto filter_dims = this->param_.Filter()->dims();
  const std::vector<int> &strides = this->param_.Strides();
  std::vector<int> paddings = this->param_.Paddings();
  int groups = this->param_.Groups();
  std::vector<int> dilations = this->param_.Dilations();

  PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
                         dilations.size() == paddings.size() &&
                         paddings.size() == strides.size()),
                        "ConvParam is not suitable");

  std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
  for (size_t i = 0; i < strides.size(); ++i) {
    output_shape.push_back(
        math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
                             paddings[i], strides[i]));
  }

  framework::DDim ddim = framework::make_ddim(output_shape);
  this->param_.Output()->Resize(ddim);
}

}  // namespace operators
}  // namespace paddle_mobile

namespace ops = paddle_mobile::operators;
...
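For readers unfamiliar with math::ConvOutputSize: judging by its argument order above, it presumably computes the standard dilated-convolution output size. A sketch under that assumption, not copied from the library:

// Presumed formula behind math::ConvOutputSize: a dilated kernel spans
// dilation * (filter_size - 1) + 1 input elements.
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                          int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  return (input_size + 2 * padding - dkernel) / stride + 1;
}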
...@@ -24,62 +24,64 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {

class FusionConvAddAddPReluOpMatcher : public framework::FusionOpMatcher {
 public:
  FusionConvAddAddPReluOpMatcher() {
    node_ = framework::Node(G_OP_TYPE_CONV);
    node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
        std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
        std::make_shared<framework::Node>(G_OP_TYPE_PRELU);
  }

  void FolderNodes(
      framework::Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
    node->Folder(node_.Depth(), Type(),
                 {{G_OP_TYPE_ELEMENTWISE_ADD,
                   {{"Y", "Y"}, {"Out", "addOut"}, {"X", "addX"}}},
                  {G_OP_TYPE_PRELU, {{"Alpha", "Alpha"}}}},
                 removed_nodes);
  }

  std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU; }

  std::vector<std::pair<int, std::string>> NeedCheck() {
    DLOG << " conv add add prelu check add X ";
    return {{2, "Y"}, {2, "X"}};
  }
};
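The > chaining in the constructor above is the fusion framework's pattern DSL: it links pattern nodes into the chain conv -> elementwise_add -> elementwise_add -> prelu, and NeedCheck() asks the fuser to confirm that the node at depth 2 (the second add) really feeds both a "Y" and an "X" input before folding. A standalone sketch of that chaining idiom, where PatternNode is a hypothetical stand-in for the framework's Node class:

#include <memory>
#include <string>
#include <vector>

struct PatternNode {
  explicit PatternNode(std::string t) : type(std::move(t)) {}
  // Append a child and return it, so chains read left to right.
  PatternNode &operator>(std::shared_ptr<PatternNode> next) {
    children.push_back(next);
    return *next;
  }
  std::string type;
  std::vector<std::shared_ptr<PatternNode>> children;
};

int main() {
  PatternNode conv("conv2d");
  conv > std::make_shared<PatternNode>("elementwise_add") >
      std::make_shared<PatternNode>("elementwise_add") >
      std::make_shared<PatternNode>("prelu");
  return 0;  // conv now roots a four-op chain, mirroring the matcher above
}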
template <typename DeviceType, typename T>
class FusionConvAddAddPReluOp
    : public framework::OperatorWithKernel<
          DeviceType, FusionConvAddAddPReluParam<DeviceType>,
          operators::ConvAddAddPReluKernel<DeviceType, T>> {
 public:
  FusionConvAddAddPReluOp(const string &type, const VariableNameMap &inputs,
                          const VariableNameMap &outputs,
                          const framework::AttributeMap &attrs,
                          std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<
            DeviceType, FusionConvAddAddPReluParam<DeviceType>,
            operators::ConvAddAddPReluKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}

  using framework::OperatorWithKernel<
      DeviceType, FusionConvAddAddPReluParam<DeviceType>,
      operators::ConvAddAddPReluKernel<DeviceType, T>>::OperatorWithKernel;
  void InferShape() const override;

 protected:
};

#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_ADD_PRELU_REGISTER
#define CONV_ADD_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_add_prelu_registrar(
    new FusionConvAddAddPReluOpMatcher());
#endif
#endif
...@@ -87,7 +89,7 @@ namespace paddle_mobile {
#endif
#ifdef PADDLE_MOBILE_FPGA
#ifndef CONV_ADD_ADD_PRELU_REGISTER
#define CONV_ADD_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_add_prelu_registrar(
    new FusionConvAddAddPReluOpMatcher());
...@@ -95,7 +97,7 @@ static framework::FusionOpRegistrar fusion_conv_add_add_prelu_registrar(
#endif

}  // namespace operators
}  // namespace paddle_mobile

#ifdef PADDLE_MOBILE_CPU
...
...@@ -18,38 +18,38 @@ limitations under the License. */
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {

template <typename Dtype, typename T>
void FusionConvAddPReluOp<Dtype, T>::InferShape() const {
  auto in_dims = this->param_.Input()->dims();
  auto filter_dims = this->param_.Filter()->dims();
  const std::vector<int> &strides = this->param_.Strides();
  std::vector<int> paddings = this->param_.Paddings();
  int groups = this->param_.Groups();
  std::vector<int> dilations = this->param_.Dilations();

  PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
                         dilations.size() == paddings.size() &&
                         paddings.size() == strides.size()),
                        "ConvParam is not suitable");

  std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
  for (size_t i = 0; i < strides.size(); ++i) {
    output_shape.push_back(
        math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i],
                             paddings[i], strides[i]));
  }

  framework::DDim ddim = framework::make_ddim(output_shape);
  this->param_.Output()->Resize(ddim);
}

}  // namespace operators
}  // namespace paddle_mobile

namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_add_prelu, ops::FusionConvAddPReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
...
...@@ -24,59 +24,59 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {

class FusionConvAddPReluOpMatcher : public framework::FusionOpMatcher {
 public:
  FusionConvAddPReluOpMatcher() {
    node_ = framework::Node(G_OP_TYPE_CONV);
    node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
        std::make_shared<framework::Node>(G_OP_TYPE_PRELU);
  }

  void FolderNodes(
      framework::Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
    node->Folder(node_.Depth(), Type(),
                 {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
                  {G_OP_TYPE_PRELU, {{"Alpha", "Alpha"}}}},
                 removed_nodes);
  }

  std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_PRELU; }
};

template <typename DeviceType, typename T>
class FusionConvAddPReluOp
    : public framework::OperatorWithKernel<
          DeviceType, FusionConvAddPReluParam<DeviceType>,
          operators::ConvAddPReluKernel<DeviceType, T>> {
 public:
  FusionConvAddPReluOp(const string &type, const VariableNameMap &inputs,
                       const VariableNameMap &outputs,
                       const framework::AttributeMap &attrs,
                       std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<
            DeviceType, FusionConvAddPReluParam<DeviceType>,
            operators::ConvAddPReluKernel<DeviceType, T>>(type, inputs,
                                                          outputs, attrs,
                                                          scope) {}

  using framework::OperatorWithKernel<
      DeviceType, FusionConvAddPReluParam<DeviceType>,
      operators::ConvAddPReluKernel<DeviceType, T>>::OperatorWithKernel;
  void InferShape() const override;

 protected:
};

#ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_PRELU_REGISTER
#define CONV_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_prelu_registrar(
    new FusionConvAddPReluOpMatcher());
#endif
#endif
...@@ -84,7 +84,7 @@ namespace paddle_mobile {
#endif
#ifdef PADDLE_MOBILE_FPGA
#ifndef CONV_ADD_PRELU_REGISTER
#define CONV_ADD_PRELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_prelu_registrar(
    new FusionConvAddPReluOpMatcher());
...@@ -92,7 +92,7 @@ static framework::FusionOpRegistrar fusion_conv_add_prelu_registrar(
#endif

}  // namespace operators
}  // namespace paddle_mobile

#ifdef PADDLE_MOBILE_CPU
...
...@@ -18,21 +18,22 @@ limitations under the License. */
#include "operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h"
namespace paddle_mobile {
namespace operators {

template <>
bool ConvAddAddPReluKernel<CPU, float>::Init(
    FusionConvAddAddPReluParam<CPU> *param) {
  return true;
}

template <>
void ConvAddAddPReluKernel<CPU, float>::Compute(
    const FusionConvAddAddPReluParam<CPU> &param) const {
  ConvAddAddPReluCompute<float>(param);
}
template class ConvAddAddPReluKernel<CPU, float>;

}  // namespace operators
}  // namespace paddle_mobile
#endif
...@@ -18,21 +18,21 @@ limitations under the License. */
#include "operators/kernel/central-arm-func/conv_add_prelu_arm_func.h"
namespace paddle_mobile {
namespace operators {

template <>
bool ConvAddPReluKernel<CPU, float>::Init(FusionConvAddPReluParam<CPU> *param) {
  return true;
}

template <>
void ConvAddPReluKernel<CPU, float>::Compute(
    const FusionConvAddPReluParam<CPU> &param) const {
  ConvAddPReluCompute<float>(param);
}
template class ConvAddPReluKernel<CPU, float>;

}  // namespace operators
}  // namespace paddle_mobile
#endif
...@@ -23,115 +23,118 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {

template <typename P>
void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) {
  const Tensor *input = param.Input();
  Tensor filter = *param.Filter();
  Tensor bias = *param.Bias();
  Tensor bias1 = *param.Bias1();
  int axis = param.Axis();
  Tensor *output = param.Output();
  float *biase_data = bias.data<float>();

  int groups = param.Groups();
  std::vector<int> strides = param.Strides();
  std::vector<int> paddings = param.Paddings();
  std::vector<int> dilations = param.Dilations();
  Tensor aa = *param.InputAlpha();
  float *p = aa.data<float>();
  std::string mode = param.Mode();

  const int batch_size = static_cast<int>(input->dims()[0]);

  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = input->dims()[1] / groups;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
  }
  framework::DDim col_shape(framework::make_ddim(col_shape_vec));

  framework::DDim col_matrix_shape =
      framework::flatten_to_2d(col_shape, data_dim + 1);

  bool is_expand =
      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
  Tensor col;
  Tensor col_matrix;
  if (is_expand) {
    col.mutable_data<float>(col_shape);
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);
  }

  framework::DDim input_shape = framework::slice_ddim(
      input->dims(), 1, static_cast<int>(input->dims().size()));

  framework::DDim filter_matrix_shape = {filter.dims()[0],
                                         filter.numel() / filter.dims()[0]};
  filter.Resize(filter_matrix_shape);
  framework::DDim output_matrix_shape = {
      output->dims()[1],
      output->numel() / (output->dims()[0] * output->dims()[1])};

  // convolution operator: im2col(or vol2col) + gemm
  int in_step = static_cast<int>(input->dims()[1]) / groups;
  int out_step = static_cast<int>(output->dims()[1]) / groups;

  math::Vol2ColFunctor<CPU, float> vol2col;
  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;

  for (int i = 0; i < batch_size; i++) {
    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
    Tensor bias1_batch = bias1.Slice(i, i + 1).Resize(output_matrix_shape);
    for (int g = 0; g < groups; g++) {
      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

      if (!is_expand) {
        col.ShareDataWith(in_slice);
        col_matrix.ShareDataWith(col);
        col_matrix.Resize(col_matrix_shape);
      } else if (data_dim == 2U) {
        // im2col
        im2col(in_slice, dilations, strides,
               std::vector<int>{paddings[0], paddings[1], paddings[0],
                                paddings[1]},
               &col);
      } else if (data_dim == 3U) {
        // vol2col
        vol2col(in_slice, dilations, strides, paddings, &col);
      }

      // gemm
      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
      Tensor bias1_slice = bias1_batch.Slice(g * out_step, (g + 1) * out_step);
      float *biase_data1 = bias1_slice.data<float>();
      // int n = bias1_slice.dims()[0];
      // int m = bias1_slice.dims()[1];
      // for (int i = 0; i < n * m; i++) {
      //   if (biase_data1[i] != 0) DLOG << biase_data1[i] << ",yangfei";
      // }
      // math::matmul<float>(filter_slice, false, col_matrix, false,
      //                     static_cast<float>(1), &out_slice,
      //                     static_cast<float>(1), true, biase_data);
      math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
                            p, mode, biase_data, biase_data1);
    }
  }
}

}  // namespace operators
}  // namespace paddle_mobile
#endif
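The comment "convolution operator: im2col(or vol2col) + gemm" above is the core of this kernel: the input is unfolded so that each group's convolution becomes a single matrix multiply. A tiny self-contained illustration of that equivalence with toy sizes (a hypothetical example, unrelated to the library's actual Im2ColFunctor):

#include <cstdio>

int main() {
  // 3x3 input, one 2x2 filter, stride 1, no padding -> 2x2 output.
  const float in[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  const float w[4] = {1, 0, 0, 1};  // flattened filter row
  float col[4][4];  // im2col buffer: 4 patch entries x 4 output positions
  for (int oy = 0; oy < 2; ++oy)
    for (int ox = 0; ox < 2; ++ox)
      for (int ky = 0; ky < 2; ++ky)
        for (int kx = 0; kx < 2; ++kx)
          col[ky * 2 + kx][oy * 2 + ox] = in[oy + ky][ox + kx];
  // GEMM: (1x4 filter row) * (4x4 col matrix) = 1x4 output row.
  for (int o = 0; o < 4; ++o) {
    float acc = 0.f;
    for (int k = 0; k < 4; ++k) acc += w[k] * col[k][o];
    printf("%.0f ", acc);  // prints: 6 8 12 14
  }
  return 0;
}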
...@@ -23,105 +23,108 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {

template <typename P>
void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
  const Tensor *input = param.Input();
  Tensor filter = *param.Filter();
  Tensor bias = *param.Bias();
  // DLOG << "yangfei";
  // DLOG << bias.dims();
  int axis = param.Axis();
  Tensor *output = param.Output();
  float *biase_data = bias.data<float>();

  int groups = param.Groups();
  std::vector<int> strides = param.Strides();
  std::vector<int> paddings = param.Paddings();
  std::vector<int> dilations = param.Dilations();
  Tensor aa = *param.InputAlpha();
  float *p = aa.data<float>();
  std::string mode = param.Mode();

  const int batch_size = static_cast<int>(input->dims()[0]);

  std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
  std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = input->dims()[1] / groups;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
  }
  framework::DDim col_shape(framework::make_ddim(col_shape_vec));

  framework::DDim col_matrix_shape =
      framework::flatten_to_2d(col_shape, data_dim + 1);

  bool is_expand =
      math::IsExpand(filter_shape_vec, strides, paddings, dilations);
  Tensor col;
  Tensor col_matrix;
  if (is_expand) {
    col.mutable_data<float>(col_shape);
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);
  }

  framework::DDim input_shape = framework::slice_ddim(
      input->dims(), 1, static_cast<int>(input->dims().size()));

  framework::DDim filter_matrix_shape = {filter.dims()[0],
                                         filter.numel() / filter.dims()[0]};
  filter.Resize(filter_matrix_shape);
  framework::DDim output_matrix_shape = {
      output->dims()[1],
      output->numel() / (output->dims()[0] * output->dims()[1])};

  // convolution operator: im2col(or vol2col) + gemm
  int in_step = static_cast<int>(input->dims()[1]) / groups;
  int out_step = static_cast<int>(output->dims()[1]) / groups;

  math::Vol2ColFunctor<CPU, float> vol2col;
  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;

  for (int i = 0; i < batch_size; i++) {
    Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
    Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
    for (int g = 0; g < groups; g++) {
      Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

      if (!is_expand) {
        col.ShareDataWith(in_slice);
        col_matrix.ShareDataWith(col);
        col_matrix.Resize(col_matrix_shape);
      } else if (data_dim == 2U) {
        // im2col
        im2col(in_slice, dilations, strides,
               std::vector<int>{paddings[0], paddings[1], paddings[0],
                                paddings[1]},
               &col);
      } else if (data_dim == 3U) {
        // vol2col
        vol2col(in_slice, dilations, strides, paddings, &col);
      }

      // gemm
      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
      // math::matmul<float>(filter_slice, false, col_matrix, false,
      //                     static_cast<float>(1), &out_slice,
      //                     static_cast<float>(1), true, biase_data);
      math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
                            p, mode, biase_data, nullptr);
    }
  }
}

}  // namespace operators
}  // namespace paddle_mobile
#endif
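math::matmulWithPRelu differs from the plain matmul call it replaces (kept commented out above) by folding the bias add(s) and the PReLU activation into the GEMM write-back, so the fused ops never materialize the intermediate sums. A sketch of the assumed epilogue semantics for a channel-wise alpha; this is an illustration, not the library's implementation:

// Assumed write-back: C = A*B already sits row-major in c; rows are output
// channels, cols are spatial positions for one group.
void PReluEpilogue(float *c, const float *bias, const float *bias1,
                   const float *alpha, int rows, int cols) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      float v = c[i * cols + j] + bias[i];             // per-channel conv bias
      if (bias1 != nullptr) v += bias1[i * cols + j];  // second, elementwise add
      c[i * cols + j] = v > 0.f ? v : alpha[i] * v;    // channel-wise PReLU
    }
  }
}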
...@@ -26,20 +26,20 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {

using framework::DDim;
using framework::OpKernelBase;

template <typename DeviceType, typename T>
class ConvAddAddPReluKernel
    : public OpKernelBase<DeviceType, FusionConvAddAddPReluParam<DeviceType>> {
 public:
  void Compute(const FusionConvAddAddPReluParam<DeviceType> &param) const;
  bool Init(FusionConvAddAddPReluParam<DeviceType> *param);
};

}  // namespace operators
}  // namespace paddle_mobile
#endif
...@@ -26,20 +26,20 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {

using framework::DDim;
using framework::OpKernelBase;

template <typename DeviceType, typename T>
class ConvAddPReluKernel
    : public OpKernelBase<DeviceType, FusionConvAddPReluParam<DeviceType>> {
 public:
  void Compute(const FusionConvAddPReluParam<DeviceType> &param) const;
  bool Init(FusionConvAddPReluParam<DeviceType> *param);
};

}  // namespace operators
}  // namespace paddle_mobile
#endif
...@@ -3172,7 +3172,7 @@ void SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
  int max_threads = 1;
#endif

  int L1 = 32 * 1024;
  KC = k;
  if (m > n) {
    // tile over A
...
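The change above replaces the per-thread split 16 / max_threads * 1024 with a flat 32 KiB L1 budget. In blocked GEMMs of this style, the L1 figure usually bounds the packed A panel; a sketch of the blocking rule the surrounding code implies (an assumption, not a verbatim excerpt):

// Assumed blocking rule: with KC fixed to k, choose MC so an MC x KC panel
// of float A-data fits the 32 KiB L1 budget reserved per thread.
inline int ComputeMC(int k) {
  const int L1 = 32 * 1024;  // bytes; no longer divided across threads
  const int KC = k;
  return L1 / (KC * static_cast<int>(sizeof(float)));
}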
...@@ -110,9 +110,8 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
  int K = (!trans_a) ? dim_a[1] : dim_a[0];

#ifdef _OPENMP
  SgemmWithPRelu_omp(M, N, K, matrix_a.data<float>(), K, matrix_b.data<float>(),
                     N, matrix_out->data<float>(), N, p, mode, bias, bias1);
#else
  SgemmWithPRelu(M, N, K, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
                 matrix_out->data<float>(), N, p, mode, bias, bias1);
...