/**
 * \file dnn/src/cuda/warp_perspective/backward_data.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/cuda/warp_perspective/opr_impl.h"

#include "src/cuda/utils.h"
#include "src/cuda/warp_perspective/common.h"
#include "src/cuda/warp_perspective/helper.h"

namespace megdnn {
namespace cuda {

WorkspaceBundle WarpPerspectiveBackwardDataImpl::get_workspace_bundle(
        void* ptr, const TensorLayout& mat, const TensorLayout& mat_idx,
        const TensorLayout& diff, const TensorLayout& grad) const {
    SmallVector<size_t> sizes;
    TensorLayout fmat = mat;
    TensorLayout fdiff = diff;
    TensorLayout fgrad = grad;
    // BFloat16 tensors are computed in Float32; reserve space for the
    // converted copies.
    auto get_workspace = [&sizes](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
            sizes.push_back(layout.span().dist_byte());
        }
    };
    get_workspace(fmat);
    get_workspace(fdiff);
    get_workspace(fgrad);
    sizes.push_back(
            get_float32_workspace_in_bytes(fmat, mat_idx, fdiff, fgrad));
    return {ptr, std::move(sizes)};
}

void WarpPerspectiveBackwardDataImpl::exec(_megdnn_tensor_in smat,
                                           _megdnn_tensor_in mat_idx,
                                           _megdnn_tensor_in sdiff,
                                           _megdnn_tensor_out sgrad,
                                           _megdnn_workspace sworkspace) {
    check_exec(smat.layout, mat_idx.layout, sdiff.layout, sgrad.layout,
               sworkspace.size);
    TensorND mat = smat;
    TensorND diff = sdiff;
    TensorND grad = sgrad;
    auto bundle = get_workspace_bundle(sworkspace.raw_ptr, smat.layout,
                                       mat_idx.layout, sdiff.layout,
                                       sgrad.layout);
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            concrete_handle(this->handle()), &bundle);
    if (sgrad.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.src_to_comp_type(smat, mat)
                .src_to_comp_type(sdiff, diff)
                .src_to_comp_type(sgrad, grad);
    }
    {
        auto workspace = ctypecvt.workspace();
        auto stream = cuda_stream(this->handle());
        auto N = grad.layout.shape[0], C = grad.layout.shape[1],
             IH = grad.layout.shape[2], IW = grad.layout.shape[3],
             OH = diff.layout.shape[2], OW = diff.layout.shape[3];
        int* midx_ptr = nullptr;
        if (mat_idx.raw_ptr) {
            megdnn_assert(mat_idx.layout.ndim == 1);
            N = mat_idx.layout.shape[0];
            midx_ptr = mat_idx.ptr<int>();
        } else {
            megdnn_assert(mat_idx.layout.ndim == 0);
        }
        auto bval = param().border_val;
        auto bmode = warp_perspective::get_bmode(param().bmode);

        size_t batch_x_channel_size = N * C;
        size_t max_batch_x_channel = max_batch_x_channel_size();
        if (batch_x_channel_size <= max_batch_x_channel) {
            warp_perspective::backward_data_proxy(
                    mat.ptr<dt_float32>(), midx_ptr, diff.ptr<dt_float32>(),
                    grad.ptr<dt_float32>(),
                    reinterpret_cast<float*>(workspace.raw_ptr), N,
                    grad.layout.shape[0], C, IH, IW, OH, OW, bval, bmode,
                    stream);
        } else {
            // N * C exceeds the proxy's limit: process the batch in chunks
            // of at most max_batch_x_channel / C images.
            dt_float32* mat_ptr = mat.ptr<dt_float32>();
            dt_float32* diff_ptr = diff.ptr<dt_float32>();
            dt_float32* grad_ptr = grad.ptr<dt_float32>();
            size_t max_batch_size = max_batch_x_channel / C;
            while (N > 0) {
                size_t curr_batch_size =
                        N > max_batch_size ? max_batch_size : N;
                warp_perspective::backward_data_proxy(
                        mat_ptr, midx_ptr, diff_ptr, grad_ptr,
                        reinterpret_cast<float*>(workspace.raw_ptr),
                        curr_batch_size, grad.layout.shape[0], C, IH, IW, OH,
                        OW, bval, bmode, stream);

                if (N <= max_batch_size) {
                    break;
                } else {
                    N -= max_batch_size;
                    mat_ptr += curr_batch_size * mat.layout.stride[0];
                    diff_ptr += curr_batch_size * diff.layout.stride[0];
                    if (midx_ptr == nullptr) {
                        grad_ptr += curr_batch_size * grad.layout.stride[0];
                    } else {
                        midx_ptr += curr_batch_size;
                    }
                }
            }
        }
    }
    if (sgrad.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
        ctypecvt.comp_to_dst_type(grad, sgrad);
    }
}

size_t WarpPerspectiveBackwardDataImpl::get_float32_workspace_in_bytes(
        const TensorLayout& /* mat */, const TensorLayout& mat_idx,
        const TensorLayout& diff, const TensorLayout& grad) const {
    auto N = grad.shape[0], C = grad.shape[1], IH = grad.shape[2],
         IW = grad.shape[3];
    auto OH = diff.shape[2], OW = diff.shape[3];
    auto bmode = warp_perspective::get_bmode(param().bmode);

    size_t max_batch_size = N;
    size_t max_batch_x_channel = max_batch_x_channel_size();
    if (N * C > max_batch_x_channel) {
        /* when the batch size is too large, the workspace only contains part
           of grad, which would cause out-of-range access with mat_idx */
        megdnn_assert(mat_idx.ndim == 0,
                      "batch size is too large, it's unsupported with mat idx "
                      "backward.");
        max_batch_size = max_batch_x_channel / C;
    }

    auto res = warp_perspective::get_backward_data_workspace_in_bytes(
            max_batch_size, C, IH, IW, OH, OW, bmode);
    return res;
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen