Unverified commit 02ac4e03, authored by eclipsycn, committed by GitHub

Merge branch 'develop' into develop

@@ -106,11 +106,14 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
   }
   std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
-  for (int m = 0; m < nodes.size(); ++m) {
-    auto &node = nodes[m];
-    op_descs.push_back(node->op_desc_);
+  if (add_split) {
+    GenerateOps(&op_descs, begin_node.get(), add_split);
+  } else {
+    for (int m = 0; m < nodes.size(); ++m) {
+      auto &node = nodes[m];
+      op_descs.push_back(node->op_desc_);
+    }
   }
-  //  GenerateOps(&op_descs, begin_node.get());
   block->ops_ = op_descs;
 }
@@ -267,12 +270,12 @@ void ProgramOptimize::GenerateOps(
 }
 void ProgramOptimize::GenerateOps(
-    std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
-    Node *begin_node) {
+    std::vector<std::shared_ptr<framework::OpDesc>> *op_descs, Node *begin_node,
+    bool can_add_split) {
   //  std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
   //  Node *input_node, Node *current_node, bool adding_thread, int
   //  thread_num
-  if (false) {
+  if (can_add_split) {
     this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
   } else {
     this->GenerateOps(op_descs, begin_node, begin_node);
...
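
The `if (false)` branch in GenerateOps was previously unreachable; the new `can_add_split` parameter lets callers opt into it, and FushionOptimize dispatches on the same flag. A compilable sketch of that dispatch pattern, with the framework types reduced to hypothetical stand-ins:

#include <memory>
#include <vector>

// Hypothetical stand-ins for framework::OpDesc and framework::Node.
struct OpDesc {};
struct Node {
  std::shared_ptr<OpDesc> op_desc_;
};

// Mirrors the dispatch in FushionOptimize: with add_split set, the real code
// calls the split-aware GenerateOps overload; otherwise ops are emitted in
// the nodes' existing topological order.
std::vector<std::shared_ptr<OpDesc>> EmitOps(
    const std::vector<std::shared_ptr<Node>> &nodes, bool add_split) {
  std::vector<std::shared_ptr<OpDesc>> op_descs;
  if (add_split) {
    // GenerateOps(&op_descs, begin_node.get(), add_split);  // real path
  } else {
    for (size_t m = 0; m < nodes.size(); ++m) {
      op_descs.push_back(nodes[m]->op_desc_);
    }
  }
  return op_descs;
}

The flag defaults to false at every layer, so existing callers keep the old ordered-emission path.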
@@ -34,7 +34,7 @@ class ProgramOptimize {
   int current_block_;
   std::vector<std::shared_ptr<BlockDesc>> new_blocks_;
   void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
-                   Node *begin_node);
+                   Node *begin_node, bool can_add_split);
   void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
                    Node *input_node, Node *current_node);
   void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
...
@@ -14,9 +14,11 @@ limitations under the License. */
 #include "io.h"
 #include <vector>
+#define PADDLE_MOBILE_PROFILE
 #ifdef PADDLE_MOBILE_PROFILE
+#include <algorithm>
 #include <ctime>
-#include <map>
+#include <unordered_map>
 #endif
 #include "common/enforce.h"
@@ -74,8 +76,9 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) {
 template <typename Dtype, Precision P>
 const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
-    const std::string &dirname, bool optimize) {
-  auto program = this->LoadProgram(dirname + "/__model__", optimize);
+    const std::string &dirname, bool optimize, bool can_add_split) {
+  auto program =
+      this->LoadProgram(dirname + "/__model__", optimize, can_add_split);
   program.model_path = dirname;
   return program;
 }
@@ -92,7 +95,7 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
 template <typename Dtype, Precision P>
 const framework::Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
-    const std::string &model_path, bool optimize) {
+    const std::string &model_path, bool optimize, bool can_add_split) {
   std::string model_filename = model_path;
   PaddleMobile__Framework__Proto__ProgramDesc *c_program;
   uint8_t *buf = NULL;
@@ -144,7 +147,7 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
   if (optimize) {
     framework::ProgramOptimize program_optimize;
     program.optimizeProgram =
-        program_optimize.FushionOptimize(originProgramDesc);
+        program_optimize.FushionOptimize(originProgramDesc, can_add_split);
   }
   if (optimize) {
     program.optimizeProgram->Description("optimize: ");
@@ -308,6 +311,7 @@ void Executor<Dtype, P>::InitMemory() {
 template <typename Dtype, Precision P>
 void Executor<Dtype, P>::InitCombineMemory() {
+  LOG(kLOG_INFO) << " begin init combine memory";
   char *origin_data = Get_binary_data(program_.para_path);
   char *data = origin_data;
   for (const auto &block : to_predict_program_->Blocks()) {
@@ -328,6 +332,7 @@ void Executor<Dtype, P>::InitCombineMemory() {
     }
   }
   delete origin_data;
+  LOG(kLOG_INFO) << " end init combine memory ";
 }
 template <typename Dtype, Precision P>
@@ -341,31 +346,37 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
   std::shared_ptr<framework::BlockDesc> to_predict_block =
       to_predict_program_->Block(0);
 #ifdef PADDLE_MOBILE_PROFILE
-  std::map<std::string, clock_t> _profile;
+  std::unordered_map<std::string, clock_t> _profile;
 #endif
   for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
     auto op = ops_of_block_[*to_predict_block.get()][j];
 #ifdef PADDLE_MOBILE_PROFILE
-    _profile[op->Type()] = clock();
+    _profile[op->Type()] -= clock();
 #endif
     op->Run();
 #ifdef PADDLE_MOBILE_PROFILE
-    _profile[op->Type()] = clock() - _profile[op->Type()];
+    _profile[op->Type()] += clock();
 #endif
   }
 #ifdef PADDLE_MOBILE_PROFILE
   {
-    DLOG << "========================[ profile ]==========================";
+    std::cout << "====================[ profile ]======================\n";
+    using prof_t = std::pair<std::string, clock_t>;
+    std::vector<prof_t> _tprofile(_profile.begin(), _profile.end());
     clock_t _ptotal = 0;
-    for (auto const &p : _profile) {
+    for (auto const &p : _tprofile) {
       _ptotal += p.second;
     }
-    for (auto const &p : _profile) {
-      DLOG << p.first << std::string(16 - p.first.size(), ' ') << "\t"
-           << (float)p.second << "\t\t"
-           << (float)p.second / (float)_ptotal * 100.0;
+    auto compf = [](const prof_t &a, const prof_t &b) {
+      return a.second > b.second;
+    };
+    std::sort(_tprofile.begin(), _tprofile.end(), compf);
+    _tprofile.push_back(std::make_pair("total", _ptotal));
+    for (auto const &p : _tprofile) {
+      printf("%-16s\t%-10.0f\t%-.4f\n", p.first.c_str(), (float)p.second,
+             (float)p.second / _ptotal * 100.0);
     }
-    DLOG << "========================[         ]==========================";
+    std::cout << "====================[---------]======================\n";
   }
 #endif
   auto ops = ops_of_block_[*to_predict_program_->Block(0)];
...
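
The profiling rework replaces a stored start timestamp with accumulation: `-= clock()` before `Run()` and `+= clock()` after leave, for each op type, the total CPU time summed over every invocation of that type, and the rows are then sorted descending before printing. The same pattern in isolation (a self-contained sketch; the busy-loop stands in for `op->Run()`):

#include <algorithm>
#include <cstdio>
#include <ctime>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, clock_t> profile;
  // -= before, += after: repeated calls accumulate net time per key.
  // operator[] value-initializes the entry to 0 on first access.
  for (int i = 0; i < 3; ++i) {
    profile["conv2d"] -= clock();
    for (volatile int k = 0; k < 1000000; ++k) {  // placeholder work
    }
    profile["conv2d"] += clock();
  }
  using prof_t = std::pair<std::string, clock_t>;
  std::vector<prof_t> rows(profile.begin(), profile.end());
  clock_t total = 0;
  for (const auto &p : rows) total += p.second;
  std::sort(rows.begin(), rows.end(),
            [](const prof_t &a, const prof_t &b) { return a.second > b.second; });
  rows.push_back({"total", total});
  for (const auto &p : rows)
    printf("%-16s\t%-10.0f\t%-.4f\n", p.first.c_str(), (float)p.second,
           (float)p.second / total * 100.0);
}

Note that `clock()` measures process CPU time, not wall time, so the percentages reflect CPU share per op type.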
@@ -35,7 +35,8 @@ class Loader {
    * @b load a fluid model stored as separate files
    * */
   const framework::Program<Dtype, P> Load(const std::string &dirname,
-                                          bool optimize = false);
+                                          bool optimize = false,
+                                          bool can_add_split = false);
   /*
    * @b load combined-format fluid model
@@ -47,7 +48,8 @@ class Loader {
  private:
   const framework::Program<Dtype, P> LoadProgram(const std::string &model_path,
-                                                 bool optimize = false);
+                                                 bool optimize = false,
+                                                 bool can_add_split = false);
 };
 template <typename Dtype = CPU, Precision P = Precision::FP32>
...
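
For reference, a minimal call with the new defaulted parameter (a sketch; the include path and the `g_googlenet` model-path constant are assumed from the test setup used at the bottom of this commit):

#include "io.h"
#include "../test_helper.h"  // assumed location of g_googlenet (hypothetical)

int main() {
  paddle_mobile::Loader<paddle_mobile::CPU> loader;
  // optimize = true runs FushionOptimize on the loaded ProgramDesc;
  // can_add_split = true additionally enables the split-aware GenerateOps path.
  auto program = loader.Load(g_googlenet, true, true);
  return 0;
}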
@@ -28,7 +28,6 @@ void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   Tensor filter = *param.Filter();
   Tensor *output = param.Output();
   output->mutable_data<float>();
-
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
   std::vector<int> paddings = param.Paddings();
@@ -40,7 +39,6 @@ void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
   std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
   size_t data_dim = filter_shape_vec.size() - 2;
-
   std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
   col_shape_vec[0] = input->dims()[1] / groups;
@@ -61,18 +59,13 @@ void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
-  //  DLOG << " col_shape = " << col_shape;
-  //  DLOG << " col_matrix_shape = " << col_matrix_shape;
   framework::DDim input_shape = framework::slice_ddim(
       input->dims(), 1, static_cast<int>(input->dims().size()));
-  //  DLOG << " input_shape = " << input_shape;
   framework::DDim filter_matrix_shape = {filter.dims()[0],
                                          filter.numel() / filter.dims()[0]};
   filter.Resize(filter_matrix_shape);
-  //  DLOG << " filter.dims() = " << filter.dims();
   framework::DDim output_matrix_shape = {
       output->dims()[1],
       output->numel() / (output->dims()[0] * output->dims()[1])};
@@ -87,8 +80,6 @@ void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   for (int i = 0; i < batch_size; i++) {
     Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
     Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-    //  DLOG << " in_batch.dims() = " << in_batch.dims();
-    //  DLOG << " out_batch.dims() = " << out_batch.dims();
     for (int g = 0; g < groups; g++) {
       Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
@@ -111,13 +102,9 @@ void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      //  DLOG << " out_slice " << out_slice.dims();
-      //  DLOG << " filter_slice " << filter_slice.dims();
-      //  DLOG << " col_matrix " << col_matrix.dims();
       math::matmul<float>(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
                           static_cast<float>(0));
-      auto filter_ptr = filter_slice.data<float>();
     }
   }
 }
...
@@ -11,29 +11,28 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #ifdef TRANSPOSE_OP
 #include "operators/kernel/transpose_kernel.h"
 namespace paddle_mobile {
 namespace operators {
-template <typename T>
-void TransposeFunc(const int numel, const T* input, const vector<int> axis,
-                   const vector<int> old_strides, const vector<int> new_strides,
-                   T* output) {
-  for (int i = 0; i < numel; ++i) {
-    int old_idx = 0;
-    int idx = i;
-    for (int j = 0; j < axis.size(); ++j) {
-      int order = axis[j];
-      old_idx += (idx / new_strides[j]) * old_strides[order];
-      idx %= new_strides[j];
-    }
-    output[i] = input[old_idx];
-  }
-}
+// vector<int> pos;
+// template <typename T>
+// void TransposeFunc(const int numel, const T* input, const vector<int> axis,
+//                    const vector<int> old_strides, const vector<int>
+//                    new_strides, T* output) {
+//   for (int i = 0; i < numel; ++i) {
+//     int old_idx = 0;
+//     int idx = i;
+//     for (int j = 0; j < axis.size(); ++j) {
+//       int order = axis[j];
+//       old_idx += (idx / new_strides[j]) * old_strides[order];
+//       idx %= new_strides[j];
+//     }
+//     output[i] = input[old_idx];
+//   }
+// }
 template <>
 void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
@@ -44,28 +43,38 @@ void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
   const auto* input_x_data = input_x->data<float>();
   auto* out_data = out->mutable_data<float>();
-  size_t axis_size = axis.size();
-  std::vector<int> new_dims;
-  new_dims.reserve(axis_size);
-  for (auto c : axis) {
-    new_dims.push_back(input_x_dims[c]);
+  size_t ndim = axis.size();
+  std::vector<int> xdim(ndim);
+  std::vector<int> xstride(ndim);
+  std::vector<int> xout(ndim);
+  for (int i = 0; i < ndim; i++) {
+    int j = ndim - 1 - i;
+    xdim[j] = input_x_dims[axis[i]];
+    xstride[j] = 1;
+    for (int k = axis[i] + 1; k < ndim; k++) {
+      xstride[j] *= input_x_dims[k];
+    }
+    xout[j] = xstride[j] * xdim[j];
   }
-  std::vector<int> old_strides;
-  std::vector<int> new_strides;
-  for (int i = 0; i < axis.size(); i++) {
-    int temp_old = 1;
-    int temp_new = 1;
-    for (int j = i + 1; j < axis.size(); j++) {
-      temp_old *= input_x_dims[j];
-      temp_new *= new_dims[j];
+  auto numel = input_x->numel();
+  size_t pind = 0;
+  std::vector<int> ind(ndim);
+  for (int i = 0; i < numel; i++) {
+    out_data[i] = input_x_data[pind];
+    ind[0]++;
+    pind += xstride[0];
+    for (int j = 0; j < ndim - 1; j++) {
+      if (ind[j] == xdim[j]) {
+        ind[j + 1]++;
+        ind[j] = 0;
+        pind += xstride[j + 1];
+        pind -= xout[j];
+      } else {
+        break;
+      }
     }
-    old_strides.push_back(temp_old);
-    new_strides.push_back(temp_new);
   }
-  TransposeFunc<float>(input_x->numel(), input_x_data, axis, old_strides,
-                       new_strides, out_data);
 }
 }  // namespace operators
...
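
The rewritten kernel drops the per-element division and modulo of the old TransposeFunc. It precomputes, per output dimension (fastest-varying first), the source extent `xdim`, the source stride `xstride`, and the carry correction `xout = xstride * xdim`, then walks the output linearly while stepping a digit counter `ind` like an odometer and patching the input offset `pind` on each carry. The same technique as a standalone function (hypothetical signature, row-major layout assumed):

#include <cassert>
#include <vector>

// Odometer-style transpose: *out is x permuted by `axis`, both row-major.
// dims are the dims of x; output dim d has extent dims[axis[d]].
void Transpose(const std::vector<float>& x, const std::vector<int>& dims,
               const std::vector<int>& axis, std::vector<float>* out) {
  const int ndim = static_cast<int>(axis.size());
  // Index 0 describes the fastest-varying (innermost) output dimension.
  std::vector<int> xdim(ndim), xstride(ndim), xout(ndim);
  for (int i = 0; i < ndim; i++) {
    int j = ndim - 1 - i;
    xdim[j] = dims[axis[i]];
    xstride[j] = 1;
    for (int k = axis[i] + 1; k < ndim; k++) xstride[j] *= dims[k];
    xout[j] = xstride[j] * xdim[j];
  }
  int numel = 1;
  for (int d : dims) numel *= d;
  out->assign(numel, 0.0f);
  std::vector<int> ind(ndim, 0);
  size_t pind = 0;  // current offset into the input
  for (int i = 0; i < numel; i++) {
    (*out)[i] = x[pind];
    ind[0]++;
    pind += xstride[0];
    for (int j = 0; j < ndim - 1; j++) {
      if (ind[j] == xdim[j]) {  // carry into the next digit
        ind[j + 1]++;
        ind[j] = 0;
        pind += xstride[j + 1];
        pind -= xout[j];
      } else {
        break;
      }
    }
  }
}

int main() {
  // 2x3 matrix transposed to 3x2: axis = {1, 0}.
  std::vector<float> x = {0, 1, 2, 3, 4, 5};
  std::vector<float> out;
  Transpose(x, {2, 3}, {1, 0}, &out);
  assert((out == std::vector<float>{0, 3, 1, 4, 2, 5}));
}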
@@ -114,10 +114,12 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
   for (j = 0; j < n - paddingN; j += NR) {
     for (i = 0; i < k; ++i) {
       Bij = &B(i, j);
-      *buffer++ = *Bij;
-      *buffer++ = *(Bij + 1);
-      *buffer++ = *(Bij + 2);
-      *buffer++ = *(Bij + 3);
+      asm volatile(
+          "vld1.32  {q0}, [%[Bij]]        \n\t"
+          "vst1.32  {q0}, [%[buffer]]!    \n\t"
+          : [buffer] "+r"(buffer)
+          : [Bij] "r"(Bij)
+          : "memory", "q0");
     }
   }
   if (paddingN != 0) {
...
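
The scalar four-element copy becomes a single NEON q-register load/store, with the `!` post-increment writeback advancing `buffer`. The inline asm is ARMv7-specific; a rough intrinsics equivalent (a sketch, assuming NR == 4 and float data) that also compiles on AArch64:

#include <arm_neon.h>

// Copy one NR(=4)-wide row fragment of B into the packed buffer.
// Equivalent in effect to: vld1.32 {q0}, [Bij]; vst1.32 {q0}, [buffer]!
static inline void PackRow4(const float* Bij, float*& buffer) {
  float32x4_t v = vld1q_f32(Bij);  // load 4 consecutive floats
  vst1q_f32(buffer, v);            // store them packed
  buffer += 4;                     // post-increment, like [%[buffer]]!
}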
@@ -20,9 +20,9 @@ limitations under the License. */
 #define C(i, j) C[(i)*ldc + (j)]
 // Block sizes for the tiled GEMM; MC and KC are the m and k extents of one block
-#define MC 384
-#define KC 384
-#define NC 4096
+#define MC 128
+#define KC 128
+#define NC 1024
 #define MR 4
 #define NR 4
...
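
Shrinking MC, KC, and NC reduces the GEMM working set so the packed panels fit in the smaller caches of mobile cores. The panel sizes follow directly from the constants (a rough arithmetic check, assuming 4-byte floats and the usual MC x KC A-panel / KC x NC B-panel packing; actual cache residency also depends on the C tile and associativity):

#include <cstdio>

constexpr int MC = 128, KC = 128, NC = 1024;
constexpr int kFloat = 4;  // sizeof(float)

int main() {
  // Packed A panel: MC x KC floats; packed B panel: KC x NC floats.
  printf("A panel: %d KiB\n", MC * KC * kFloat / 1024);  // 64 KiB
  printf("B panel: %d KiB\n", KC * NC * kFloat / 1024);  // 512 KiB
  // The old 384/384/4096 blocking needed 576 KiB and 6144 KiB respectively.
}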
@@ -19,9 +19,10 @@ int main() {
   paddle_mobile::Loader<paddle_mobile::CPU> loader;
   //  ../../../test/models/googlenet
   //  ../../../test/models/mobilenet
-  auto program = loader.Load(g_resnet, true);
-  loader.Load(g_googlenet_combine + "/model", g_googlenet_combine + "/params",
-              true);
+  auto program = loader.Load(g_googlenet, true, true);
+  //  loader.Load(g_googlenet_combine + "/model", g_googlenet_combine +
+  //  "/params",
+  //              true);
   program.originProgram->Description("program desc: ");
   return 0;
...