未验证 提交 4ae23cc3 编写于 作者: Y Yibing Liu 提交者: GitHub

Impl fp16 compute kernel for slice_op (#16206)

* Impl fp16 compute kernel for slice_op

test=develop

* Use data() to replace mutable_data()
上级 0ca6465a
...@@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <size_t D>
__global__ void Padding(const paddle::platform::float16* d_out,
const int* out_dims, const int* in_dims,
const int* offsets, int64_t n,
paddle::platform::float16* d_in) {
int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (out_idx < n) {
int coords[D] = {0};
for (int i = D - 1; i >= 0; --i) {
coords[i] = out_idx % out_dims[i];
out_idx /= out_dims[i];
coords[i] += offsets[i];
}
int64_t in_idx = 0;
for (int i = 0; i < D - 1; ++i) {
in_idx += coords[i] * in_dims[i + 1];
}
in_idx += coords[D - 1];
d_in[in_idx] = d_out[out_idx];
}
}
template <>
class SliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>
: public framework::OpKernel<paddle::platform::float16> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
d_in->mutable_data<paddle::platform::float16>(ctx.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_in->dims();
int rank = out_dims.size();
std::vector<int> offsets(rank, 0);
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
for (size_t i = 0; i < starts.size(); ++i) {
if (starts[i] < 0) {
starts[i] += in_dims[axes[i]];
}
offsets[axes[i]] = std::max(starts[i], 0);
}
math::SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::float16>
set_zero;
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
set_zero(dev_ctx, d_in, static_cast<paddle::platform::float16>(0));
int64_t numel = d_out->numel();
dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1);
dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1);
auto stream = ctx.cuda_device_context().stream();
auto out_shape = framework::vectorize2int(out_dims);
thrust::device_vector<int> out_dims_vec(out_shape.begin(), out_shape.end());
auto in_shape = framework::vectorize2int(in_dims);
thrust::device_vector<int> in_dims_vec(in_shape.begin(), in_shape.end());
thrust::device_vector<int> offsets_vec(offsets.begin(), offsets.end());
const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
switch (rank) {
case 1:
Padding<1><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 2:
Padding<2><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 3:
Padding<3><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 4:
Padding<4><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 5:
Padding<5><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 6:
Padding<6><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>, slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>, ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>, ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slice_grad, slice_grad,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>, ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
...@@ -63,5 +64,28 @@ class TestCase2(TestSliceOp): ...@@ -63,5 +64,28 @@ class TestCase2(TestSliceOp):
self.out = self.input[-3:3, 0:100, :, 2:-1] self.out = self.input[-3:3, 0:100, :, 2:-1]
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16(TestSliceOp):
def config(self):
self.dtype = "float16"
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.out = self.input[-3:3, 0:100, :, 2:-1]
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-5)
def test_check_grad_normal(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['Input'], 'Out', max_relative_error=0.006)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册