提交 288148fe 编写于 作者: H hjchen2

Fix bugs in multi-threaded sgemm and fused conv/add/bn/relu

上级 327ca7c6
......@@ -137,14 +137,18 @@ class CLTensor : TensorBase {
: ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
size, reinterpret_cast<void *>(input), NULL)),
size_(size),
capatity_(size),
type_(type),
context_(context),
command_queue_(command_queue) {}
// Constructs a placeholder backed by a newly created OpenCL buffer.
//
// size          - size passed to clCreateBuffer; also stored as both the
//                 logical size and the initial capacity.
// type          - element type recorded for this memory.
// context       - OpenCL context the buffer is created in; kept for later use.
// command_queue - stored as given; not used by this constructor itself.
//
// The buffer is created with CL_MEM_READ_WRITE and no host pointer.
// NOTE(review): the errcode_ret argument of clCreateBuffer is NULL, so a failed
// allocation is not detected here — confirm that users of ptr_ handle that case.
PlaceholderImpl(size_t size, std::type_index type, cl_context context,
cl_command_queue command_queue)
: ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)),
size_(size),
capatity_(size),
type_(type),
context_(context),
command_queue_(command_queue) {}
// Returns the current logical size stored in size_.
virtual size_t size() const { return size_; }
......@@ -155,13 +159,25 @@ class CLTensor : TensorBase {
// Overwrites the recorded type of the memory with `type`.
virtual void set_type(std::type_index type) { type_ = type; }
// Sets the logical size to `size`, allocating a larger buffer when required.
//
// If `size` exceeds the current capacity, the capacity is raised to `size` and
// ptr_ is reset to a new CL_MEM_READ_WRITE buffer of that size; nothing in this
// method copies the previous buffer's contents into the new one. Otherwise the
// existing buffer is kept and only size_ changes.
virtual void resize(size_t size) {
if (size > capatity_) {
capatity_ = size;
// NOTE(review): errcode_ret is NULL, so an allocation failure goes unnoticed
// even though capatity_ has already been raised — confirm this is acceptable.
ptr_.reset(
clCreateBuffer(context_, CL_MEM_READ_WRITE, capatity_, NULL, NULL));
}
size_ = size;
}
// Owning handle to the OpenCL buffer; released through CLMemDeleter.
std::unique_ptr<_cl_mem, CLMemDeleter> ptr_;
// Logical size set by the constructor or the most recent resize().
size_t size_;
// Size the buffer was last created with; resize() only re-creates the buffer
// when the requested size exceeds this value.
size_t capatity_;
/* the current type of memory */
std::type_index type_;
// OpenCL context used when resize() creates a new buffer.
cl_context context_;
// Command queue supplied at construction time.
cl_command_queue command_queue_;
};
};
......
......@@ -68,7 +68,8 @@ struct CPUContext {
};
// Sets the thread count on the CPUContext returned by CPUContext::Context().
// NOTE(review): this listing appears to be a diff rendered without +/- markers;
// the set_num_threads() call below is presumably the line the commit replaced
// with the direct num_threads assignment — confirm against the repository.
inline void set_global_num_threads(int threads) {
CPUContext::Context()->set_num_threads(threads);
// CPUContext::Context()->set_num_threads(threads);
CPUContext::Context()->num_threads = threads;
}
inline int get_global_num_threads() {
......
......@@ -30,12 +30,14 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
const Tensor *bias = param->InputBias();
const Tensor *bias1 = param->Bias();
const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
auto bias1_ptr = bias1->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
......@@ -52,7 +54,8 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
auto new_bias_ptr = new_bias->mutable_data<float>({C});
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] + (bias1_ptr[i] - mean_ptr[i]) *
inv_std_ptr[i] * scale_ptr[i];
}
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
......
......@@ -107,8 +107,8 @@ class GemmExecutor : public Executor {
// gettimeofday(&tv_begin,NULL);
if (M_ > N_) {
int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width();
lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_;
rhs_worksize_ = sizeof(Itype) * K_ * nblock * num_threads_;
lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_ * num_threads_;
rhs_worksize_ = sizeof(Itype) * K_ * nblock;
out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_;
ldc_ = nblock;
} else {
......@@ -133,7 +133,7 @@ class GemmExecutor : public Executor {
if (M_ > N_) {
strategy_.pack_rhs(K_, N_, B, ldb, rhs_workspace_, true);
#pragma omp parallel for if (M_ > 128)
#pragma omp parallel for
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_);
#ifdef _OPENMP
......@@ -165,7 +165,7 @@ class GemmExecutor : public Executor {
} else {
strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
#pragma omp parallel for if (N_ > 128)
#pragma omp parallel for
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_);
#ifdef _OPENMP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册