提交 680f8b42 编写于 作者: L Liangliang He

Merge branch 'd2s-bug' into 'master'

Fix the depth_to_space OpenCL global work size bug.

See merge request !386
......@@ -108,11 +108,11 @@ struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
: block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
const int block_size_;
bool d2s_;
cl::Kernel kernel_;
uint32_t kwg_size_;
std::unique_ptr<BufferBase> kernel_error_;
const int block_size_;
bool d2s_;
std::vector<index_t> input_shape_;
};
......
......@@ -22,16 +22,31 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
const char *kernel_name = nullptr;
uint32_t gws[3];
std::stringstream ss;
index_t output_height, output_width, output_depth;
if (d2s_) { output_height = input_height * block_size_;
if (d2s_) {
output_height = input_height * block_size_;
output_width = input_width * block_size_;
output_depth = input_depth / (block_size_ * block_size_);
kernel_name = "depth_to_space";
gws[0] = static_cast<uint32_t>(RoundUpDiv4(output_depth));
gws[1] = static_cast<uint32_t>(output_width);
gws[2] = static_cast<uint32_t>(output_height * batch);
ss << "depth_to_space_opencl_kernel_" << batch << "_"
<< output_height << "_" << output_width << "_" << output_depth;
} else {
output_height = input_height / block_size_;
output_width = input_width / block_size_;
output_depth = input_depth * block_size_ * block_size_;
kernel_name = "space_to_depth";
gws[0] = static_cast<uint32_t>(RoundUpDiv4(input_depth));
gws[1] = static_cast<uint32_t>(input_width);
gws[2] = static_cast<uint32_t>(input_height * batch);
ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_"
<< input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3);
}
const index_t input_depth_blocks = RoundUpDiv4(input_depth);
const index_t output_depth_blocks = RoundUpDiv4(output_depth);
......@@ -73,23 +88,7 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
uint32_t gws[3];
std::stringstream ss;
if (!IsVecEqual(input_shape_, input->shape())) {
if (d2s_) {
gws[0] = static_cast<uint32_t>(output_depth_blocks);
gws[1] = static_cast<uint32_t>(output_width);
gws[2] = static_cast<uint32_t>(output_height * batch);
ss << "depth_to_space_opencl_kernel_" << output->dim(0) << "_"
<< output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3);
} else {
gws[0] = static_cast<uint32_t>(input_depth_blocks);
gws[1] = static_cast<uint32_t>(input_width);
gws[2] = static_cast<uint32_t>(input_height * batch);
ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_"
<< input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3);
}
uint32_t idx = 0;
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_.setArg(idx++,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册