Commit 3addd402 authored by hjchen2

Add input channels consideration referred from ncnn

Parent af2a6b22
 cmake_minimum_required(VERSION 3.0.0)
 option(USE_OPENMP "openmp support" ON)
-option(DEBUGING "enable debug mode" ON)
-option(USE_EXCEPTION "use std exception" ON)
+option(DEBUGING "enable debug mode" OFF)
+option(USE_EXCEPTION "use std exception" OFF)
 option(SYMBOL_HIDDEN "symbol hidden" OFF) # on when use jni or ios io
 option(LOG_PROFILE "log profile" OFF)
 # select the platform to build
@@ -247,6 +247,5 @@ elseif(FPGA)
   add_subdirectory(test)
 endif()
-add_subdirectory(test)
@@ -30,7 +30,6 @@ limitations under the License. */
 #ifdef PADDLE_EXECUTOR_MULTITHREAD
 #include <queue>
 #include <utility>
 #include "common/threadpool.h"
 #endif
@@ -96,13 +95,12 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
 }
 template <typename Dtype>
-static void LoadMemInternal(void **data, framework::LoDTensor *tensor,
-                            bool quant_uint8 = false) {
+void LoadMemInternal(void **data, framework::LoDTensor *tensor) {
   char **data_buf = reinterpret_cast<char **>(data);
   int64_t size = tensor->numel();
   Dtype *tensor_data = tensor->mutable_data<Dtype>();
-  if (quant_uint8) {
-    // should be moved into operator init function
+  if (0) {
+    // TODO(hjchen2) should be moved into operator init function
     float min_value;
     float max_value;
     memory::Copy(&min_value, data_buf, sizeof(float));
@@ -158,8 +156,7 @@ void Executor<Dtype, P>::LoadMemory(
   // parse tensor from stream
   switch (tensor_desc.DataType()) {
     case framework::VARTYPE_TYPE_FP32:
-      LoadMemInternal<float>(reinterpret_cast<void **>(data_buf), tensor,
-                             program_.quantification);
+      LoadMemInternal<float>(reinterpret_cast<void **>(data_buf), tensor);
       break;
     case framework::VARTYPE_TYPE_INT8:
       LoadMemInternal<int8_t>(reinterpret_cast<void **>(data_buf), tensor);
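
For context, the quant_uint8 branch being disabled above dequantizes a uint8 weight blob back to float from a per-tensor min/max header, as the min_value/max_value copies in LoadMemInternal suggest. A minimal standalone sketch of that scheme, assuming a [min][max][uint8 data] layout and a (max - min)/255 step; the function name and layout here are illustrative, not the paddle-mobile API:

#include <cstdint>
#include <cstring>

// Sketch: dequantize a uint8-quantized weight blob into float.
// Layout assumed: [float min][float max][uint8_t data...].
void DequantizeUint8(const uint8_t *src, float *dst, int64_t size) {
  float min_value, max_value;
  std::memcpy(&min_value, src, sizeof(float));
  std::memcpy(&max_value, src + sizeof(float), sizeof(float));
  const uint8_t *q = src + 2 * sizeof(float);
  const float factor = (max_value - min_value) / 255.0f;  // assumed step
  for (int64_t i = 0; i < size; ++i) {
    dst[i] = q[i] * factor + min_value;  // map [0, 255] back to [min, max]
  }
}
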
@@ -266,6 +263,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
   framework::Variable *g_feed_value = program_.scope->Var("feed");
   framework::Tensor *feed_tensor =
       g_feed_value->GetMutable<framework::LoDTensor>();
+  DLOG << "feed_tensor dim: " << feed_tensor->dims();
   feed_tensor->Resize(t.dims());
   feed_tensor->ShareDataWith(t);
   std::shared_ptr<framework::BlockDesc> to_predict_block =
@@ -300,8 +298,16 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
   for (int i = 0; i < profile.size(); i++) {
     const auto &pInfo = profile[i];
     uint64_t timeCost = pInfo.runEnd - pInfo.runBegin;
+    if (ops[i]->Type() == "conv2d") {
+      auto inputs = ops[i]->Inputs();
+      auto *filter = framework::GetVarValue<framework::LoDTensor>(
+          "Filter", inputs, *(program_.scope));
+      int kernel_size = filter->dims()[2];
+      _tp[ops[i]->Type() + "_" + std::to_string(kernel_size)] += timeCost;
+    } else {
       _tp[ops[i]->Type()] += timeCost;
+    }
   }
   printf("====================[ profile ]======================\n");
   using prof_t = std::pair<std::string, uint64_t>;
   std::vector<prof_t> _tv(_tp.begin(), _tp.end());
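
The change above buckets conv2d time by kernel size (e.g. a "conv2d_3" key for 3x3 convs) before the existing sort-and-print step. A self-contained sketch of this aggregate-then-rank pattern, with hypothetical timings standing in for the real profile data:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // op key -> accumulated cost; conv2d entries carry a kernel-size suffix
  std::map<std::string, uint64_t> tp;
  tp["conv2d_3"] += 1200;  // hypothetical 3x3 conv cost
  tp["conv2d_1"] += 400;   // hypothetical 1x1 conv cost
  tp["relu"] += 90;
  using prof_t = std::pair<std::string, uint64_t>;
  std::vector<prof_t> tv(tp.begin(), tp.end());
  // rank op buckets by total time, most expensive first
  std::sort(tv.begin(), tv.end(),
            [](const prof_t &a, const prof_t &b) { return a.second > b.second; });
  for (const auto &p : tv) {
    std::printf("%-12s %8llu\n", p.first.c_str(),
                static_cast<unsigned long long>(p.second));
  }
  return 0;
}
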
@@ -370,6 +376,14 @@ std::shared_ptr<framework::LoDTensor> Executor<Dtype, P>::PredictLod(
   for (int i = 0; i < profile.size(); i++) {
     const auto &pInfo = profile[i];
     uint64_t timeCost = pInfo.runEnd - pInfo.runBegin;
+    if (ops[i]->Type() == "conv2d") {
+      auto inputs = ops[i]->Inputs();
+      auto input_keys = ops[i]->GetInputKeys();
+      auto *filter = framework::GetVarValue<framework::LoDTensor>(
+          input_keys[1], inputs, *(program_.scope));
+      int kernel_size = filter->dims()[2];
+      printf("kernel size: %d\n", kernel_size);
+    }
     _tp[ops[i]->Type()] += timeCost;
   }
   printf("====================[ profile ]======================\n");
......
@@ -40,7 +40,8 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
       param->Dilations()[0] == param->Dilations()[1] &&
       param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
       param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
-      param->Input()->dims()[2] >= 16) {
+      param->Input()->dims()[1] >= 16 &&
+      param->Input()->dims()[2] <= 140 /* referred from ncnn */) {
     param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
     // transform weight
     framework::Tensor *transformed_weight = new framework::Tensor;
......
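
The Init() condition above gates the Winograd 3x3 path on shape: square stride/dilation, a 3x3 stride-1 dilation-1 kernel, at least 16 output channels, and now also at least 16 input channels plus an input height of at most 140, with the channel and size bounds taken from ncnn's heuristic. A condensed sketch of the predicate, assuming NCHW tensors; the Shape struct and function name are illustrative:

#include <cstdint>

// Illustrative NCHW shape holder; for a filter, n = output channels,
// c = input channels, h/w = kernel height/width.
struct Shape { int64_t n, c, h, w; };

// Mirrors the eligibility check above: 3x3 stride-1 dilation-1 conv,
// >= 16 input/output channels, input height <= 140 (ncnn-derived bound).
bool UseWinograd3x3(const Shape &input, const Shape &filter,
                    const Shape &output, int stride, int dilation) {
  return filter.h == 3 && filter.w == 3 && stride == 1 && dilation == 1 &&
         output.c >= 16 && input.c >= 16 && input.h <= 140;
}
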
@@ -3,3 +3,4 @@ set(ANDROID_PIE TRUE)
 set(ANDROID_STL "c++_static")
 set(ANDROID_PLATFORM "android-22")
 include("${CMAKE_CURRENT_LIST_DIR}/../android-cmake/android.toolchain.cmake")
+#include("/Users/chenhoujiang/Project/android-ndk-r16b/build/cmake/android.toolchain.cmake")