提交 9b84dc91 编写于 作者: W Wilber 提交者: GitHub

fix var_conv_2d to support cascading use. test=develop (#2766)

- 修复var_conv_2d级联使用中计算错误的bug
- x86的var_conv_2d中显示指定lod level为3
上级 974c50db
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "lite/backends/cuda/math/gemm.h" #include "lite/backends/cuda/math/gemm.h"
...@@ -38,6 +39,32 @@ inline int ConvOutputSize(int input_size, ...@@ -38,6 +39,32 @@ inline int ConvOutputSize(int input_size,
return output_size; return output_size;
} }
// Eliminate the effects of pad, support batch > 1.
template <typename dtype>
__global__ void eliminate_pad_effect(dtype* src,
const int64_t* offset,
const int num_batch,
const int batch_stride,
const int num_channel,
const int channel_stride,
const int num_height,
const int height_stride,
const int num_width,
const int width_stride,
const int count) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int thread_num = blockDim.x * gridDim.x;
for (tid = threadIdx.x + blockIdx.x * blockDim.x; tid < count;
tid += thread_num) {
int batch_id = tid / batch_stride;
int width_id = tid % num_width;
int cur_len = offset[batch_id + 1] - offset[batch_id];
if (width_id >= cur_len) {
src[tid] = 0.;
}
}
}
void VarConv2DCompute::PrepareForRun() { void VarConv2DCompute::PrepareForRun() {
auto& context = this->ctx_->template As<CUDAContext>(); auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream(); auto stream = context.exec_stream();
...@@ -102,6 +129,46 @@ void VarConv2DCompute::Run() { ...@@ -102,6 +129,46 @@ void VarConv2DCompute::Run() {
conv_param_.output->Resize({output_shape}); conv_param_.output->Resize({output_shape});
conv_impl_->create(conv_param_, &context); conv_impl_->create(conv_param_, &context);
conv_impl_->run(conv_param_); conv_impl_->run(conv_param_);
// Avoid situations where cascading conv does not support multiple batch
// calculations
float* out_data = param.Out->mutable_data<float>();
const int batch_num = output_shape[1] * output_shape[2] * output_shape[3];
std::vector<int64_t> lod(param.X->lod()[0].size(), 0);
for (size_t i = 0; i < param.X->lod()[0].size(); ++i) {
lod[i] = param.X->lod()[0][i];
}
int count = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
int width_stride = 1;
int height_stride = output_shape[3];
int channel_stride = output_shape[2] * output_shape[3];
int batch_stride = output_shape[1] * output_shape[2] * output_shape[3];
int threads = 512;
int blocks = (count + threads - 1) / threads;
offset_.Resize({static_cast<int64_t>(lod.size())});
int64_t* d_offset = offset_.mutable_data<int64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(d_offset,
lod.data(),
sizeof(int64_t) * lod.size(),
IoDirection::HtoD,
stream);
eliminate_pad_effect<float><<<blocks, threads, 0, stream>>>(out_data,
d_offset,
output_shape[0],
batch_stride,
output_shape[1],
channel_stride,
output_shape[2],
height_stride,
output_shape[3],
width_stride,
count);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
} }
} // namespace cuda } // namespace cuda
......
...@@ -33,6 +33,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { ...@@ -33,6 +33,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
private: private:
mutable operators::ConvParam conv_param_; mutable operators::ConvParam conv_param_;
std::unique_ptr<lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>> conv_impl_; std::unique_ptr<lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>> conv_impl_;
lite::Tensor offset_;
}; };
} // namespace cuda } // namespace cuda
......
...@@ -44,6 +44,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -44,6 +44,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// 2-D lod info. // 2-D lod info.
// const auto& offset_x = in_col->lod()[0]; // const auto& offset_x = in_col->lod()[0];
// const auto& offset_y = in_row->lod()[0]; // const auto& offset_y = in_row->lod()[0];
CHECK_EQ(param.X->lod().size(), 3) << "input lod size should be 3!";
const auto& offset_y = param.X->lod()[1]; const auto& offset_y = param.X->lod()[1];
const auto& offset_x = param.X->lod()[2]; const auto& offset_x = param.X->lod()[2];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册