提交 288148fe 编写于 作者: H hjchen2

Fix bugs in multi-threads sgemm and fuse conv/add/bn/relu

上级 327ca7c6
...@@ -137,14 +137,18 @@ class CLTensor : TensorBase { ...@@ -137,14 +137,18 @@ class CLTensor : TensorBase {
: ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, : ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
size, reinterpret_cast<void *>(input), NULL)), size, reinterpret_cast<void *>(input), NULL)),
size_(size), size_(size),
capatity_(size),
type_(type), type_(type),
context_(context),
command_queue_(command_queue) {} command_queue_(command_queue) {}
PlaceholderImpl(size_t size, std::type_index type, cl_context context, PlaceholderImpl(size_t size, std::type_index type, cl_context context,
cl_command_queue command_queue) cl_command_queue command_queue)
: ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)), : ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)),
size_(size), size_(size),
capatity_(size),
type_(type), type_(type),
context_(context),
command_queue_(command_queue) {} command_queue_(command_queue) {}
virtual size_t size() const { return size_; } virtual size_t size() const { return size_; }
...@@ -155,13 +159,25 @@ class CLTensor : TensorBase { ...@@ -155,13 +159,25 @@ class CLTensor : TensorBase {
virtual void set_type(std::type_index type) { type_ = type; } virtual void set_type(std::type_index type) { type_ = type; }
virtual void resize(size_t size) {
if (size > capatity_) {
capatity_ = size;
ptr_.reset(
clCreateBuffer(context_, CL_MEM_READ_WRITE, capatity_, NULL, NULL));
}
size_ = size;
}
std::unique_ptr<_cl_mem, CLMemDeleter> ptr_; std::unique_ptr<_cl_mem, CLMemDeleter> ptr_;
size_t size_; size_t size_;
size_t capatity_;
/* the current type of memory */ /* the current type of memory */
std::type_index type_; std::type_index type_;
cl_context context_;
cl_command_queue command_queue_; cl_command_queue command_queue_;
}; };
}; };
......
...@@ -68,7 +68,8 @@ struct CPUContext { ...@@ -68,7 +68,8 @@ struct CPUContext {
}; };
inline void set_global_num_threads(int threads) { inline void set_global_num_threads(int threads) {
CPUContext::Context()->set_num_threads(threads); // CPUContext::Context()->set_num_threads(threads);
CPUContext::Context()->num_threads = threads;
} }
inline int get_global_num_threads() { inline int get_global_num_threads() {
......
...@@ -30,12 +30,14 @@ bool ConvAddBNReluKernel<CPU, float>::Init( ...@@ -30,12 +30,14 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
const Tensor *variance = param->InputVariance(); const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale(); const Tensor *scale = param->InputScale();
const Tensor *bias = param->InputBias(); const Tensor *bias = param->InputBias();
const Tensor *bias1 = param->Bias();
const float epsilon = param->Epsilon(); const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>(); auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>(); auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>(); auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>(); auto bias_ptr = bias->data<float>();
auto bias1_ptr = bias1->data<float>();
const int C = mean->numel(); const int C = mean->numel();
float inv_std_ptr[C]; float inv_std_ptr[C];
...@@ -52,7 +54,8 @@ bool ConvAddBNReluKernel<CPU, float>::Init( ...@@ -52,7 +54,8 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
auto new_bias_ptr = new_bias->mutable_data<float>({C}); auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) { for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; new_bias_ptr[i] = bias_ptr[i] + (bias1_ptr[i] - mean_ptr[i]) *
inv_std_ptr[i] * scale_ptr[i];
} }
param->SetNewScale(new_scale); param->SetNewScale(new_scale);
param->SetNewBias(new_bias); param->SetNewBias(new_bias);
......
...@@ -107,8 +107,8 @@ class GemmExecutor : public Executor { ...@@ -107,8 +107,8 @@ class GemmExecutor : public Executor {
// gettimeofday(&tv_begin,NULL); // gettimeofday(&tv_begin,NULL);
if (M_ > N_) { if (M_ > N_) {
int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width(); int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width();
lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_; lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_ * num_threads_;
rhs_worksize_ = sizeof(Itype) * K_ * nblock * num_threads_; rhs_worksize_ = sizeof(Itype) * K_ * nblock;
out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_; out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_;
ldc_ = nblock; ldc_ = nblock;
} else { } else {
...@@ -133,7 +133,7 @@ class GemmExecutor : public Executor { ...@@ -133,7 +133,7 @@ class GemmExecutor : public Executor {
if (M_ > N_) { if (M_ > N_) {
strategy_.pack_rhs(K_, N_, B, ldb, rhs_workspace_, true); strategy_.pack_rhs(K_, N_, B, ldb, rhs_workspace_, true);
#pragma omp parallel for if (M_ > 128) #pragma omp parallel for
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_);
#ifdef _OPENMP #ifdef _OPENMP
...@@ -165,7 +165,7 @@ class GemmExecutor : public Executor { ...@@ -165,7 +165,7 @@ class GemmExecutor : public Executor {
} else { } else {
strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true); strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
#pragma omp parallel for if (N_ > 128) #pragma omp parallel for
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_);
#ifdef _OPENMP #ifdef _OPENMP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册